linmult.core.heads

Output head types, factory, and HeadModule for LinMulT/LinT models.

Classes

BaseHead

Abstract base class for all output heads.

SequenceAggregationHead

Output head that aggregates a sequence to a single vector.

SequenceHead

Output head that preserves the time dimension.

VectorHead

Output head for vector (already-aggregated) inputs.

SimpleHead

Lightweight linear head with optional time-dimension pooling.

UpsampleHead

Output head with learnable temporal upsampling.

DownsampleHead

Output head with learnable temporal downsampling.

HeadFactory

Registry and factory for output head types.

HeadModule

Self-contained output head container.

Module Contents

class linmult.core.heads.BaseHead(_input_dim: int, _output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: torch.nn.Module

Abstract base class for all output heads.

Subclasses must implement forward(). Use from_config() as the standard factory entry point; it simply delegates to __init__.

Parameters:
  • _input_dim (int) – Input feature dimensionality (stored for subclass use).

  • _output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) – Head configuration.

Initialize BaseHead.

extra_repr() → str[source]

Return the head name for identification in repr output.

classmethod from_config(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) → BaseHead[source]

Construct a head from keyword arguments.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) – Head configuration.

Returns:

A new instance of this head class.

Return type:

BaseHead
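The subclassing contract can be sketched as follows. `IdentityHead` is an illustrative stand-in (a plain `nn.Module` with `config=None`), not a linmult class:

```python
import torch
import torch.nn as nn

# Minimal sketch of the BaseHead contract: subclasses implement forward(),
# and from_config() simply delegates to __init__. IdentityHead and the
# bare nn.Module base are stand-ins, not linmult classes.
class IdentityHead(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, config=None):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        return self.proj(x)

    @classmethod
    def from_config(cls, input_dim, output_dim, config):
        # Delegates straight to __init__, as documented.
        return cls(input_dim, output_dim, config)

head = IdentityHead.from_config(16, 4, config=None)
out = head(torch.randn(2, 16))
print(tuple(out.shape))  # (2, 4)
```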

class linmult.core.heads.SequenceAggregationHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Output head that aggregates a sequence to a single vector.

Maps (B, T, F) → (B, output_dim) by normalizing, projecting to a hidden dimension, pooling along the time axis, and projecting to the output dimension.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Relevant attributes:

    • norm (str): Normalization type, "bn" or "in". Default "bn".

    • pooling (str): Pooling type, "gap", "gmp", or "attentionpool". Default "gap".

    • hidden_dim (int): Hidden projection size. Default 256.

    • dropout (float): Dropout in the first projection. Default 0.1.

Initialize SequenceAggregationHead.

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → torch.Tensor[source]

Aggregate and project.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T, F).

  • mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid.

Returns:

Output of shape (B, output_dim).

Return type:

torch.Tensor
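The pipeline described above can be sketched as follows. LayerNorm stands in for the configured "bn"/"in" norm, and a masked mean implements "gap" pooling; layer choices and dimensions are illustrative, not LinMulT's exact ones:

```python
import torch
import torch.nn as nn

# Sketch of norm -> hidden projection -> masked mean pooling ("gap") ->
# output projection, mapping (B, T, F) to (B, output_dim).
class GapAggregationSketch(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)  # stand-in for "bn"/"in"
        self.proj_in = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)
        )
        self.proj_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, mask=None):
        h = self.proj_in(self.norm(x))            # (B, T, hidden_dim)
        if mask is None:
            pooled = h.mean(dim=1)                # (B, hidden_dim)
        else:
            m = mask.unsqueeze(-1).float()        # (B, T, 1)
            pooled = (h * m).sum(1) / m.sum(1).clamp(min=1.0)
        return self.proj_out(pooled)              # (B, output_dim)

x = torch.randn(4, 10, 32)
mask = torch.ones(4, 10, dtype=torch.bool)
out = GapAggregationSketch(32, 256, 8)(x, mask)
print(tuple(out.shape))  # (4, 8)
```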

class linmult.core.heads.SequenceHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Output head that preserves the time dimension.

Maps (B, T, F) → (B, T, output_dim) by normalizing and projecting each timestep independently.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Relevant attributes:

    • norm (str): Normalization type, "bn" or "in". Default "bn".

    • hidden_dim (int): Hidden projection size. Default 256.

    • dropout (float): Dropout in the projection. Default 0.1.

Initialize SequenceHead.

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → torch.Tensor[source]

Normalize and project each timestep.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T, F).

  • mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid.

Returns:

Output of shape (B, T, output_dim).

Return type:

torch.Tensor

class linmult.core.heads.VectorHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Output head for vector (already-aggregated) inputs.

Maps (B, F) → (B, output_dim) by normalizing and projecting.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Relevant attributes:

    • norm (str): Normalization type, "bn" or "in". Default "bn".

    • hidden_dim (int): Hidden projection size. Default 256.

    • dropout (float): Dropout in the projection. Default 0.1.

Initialize VectorHead.

forward(x: torch.Tensor, **_kwargs) → torch.Tensor[source]

Normalize and project a vector.

Parameters:

x (torch.Tensor) – Input of shape (B, F).

Returns:

Output of shape (B, output_dim).

Return type:

torch.Tensor

class linmult.core.heads.SimpleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Lightweight linear head with optional time-dimension pooling.

Applies an optional pooling step followed by a single linear projection. Depending on the pooling config attribute, the mapping is:

  • No pooling (None): (B, T, F) → (B, T, output_dim)

  • With pooling ("gap" / "gmp" / "attentionpool"):

    (B, T, F) → (B, output_dim)

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Relevant attribute:

    • pooling (str, optional): One of "gap", "gmp", "attentionpool", or None (no pooling).

Initialize SimpleHead.

forward(x: torch.Tensor, mask: torch.Tensor | None = None, **_kwargs) → torch.Tensor[source]

Apply optional pooling then linear projection.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T, F) or (B, F).

  • mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid. Passed through to pooling layers when pool is configured.

Returns:

Output of shape (B, output_dim) if pooled,

otherwise (B, T, output_dim).

Return type:

torch.Tensor
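The two modes can be illustrated with a single shared projection. A plain `mean()` stands in for "gap" here; the real head also supports "gmp" and "attentionpool" and honours the mask:

```python
import torch
import torch.nn as nn

# Sketch of SimpleHead's two modes: without pooling, the projection is
# applied per timestep; with pooling, the time axis collapses first.
proj = nn.Linear(32, 8)
x = torch.randn(4, 10, 32)

seq_out = proj(x)              # pooling=None:  (B, T, F) -> (B, T, output_dim)
vec_out = proj(x.mean(dim=1))  # pooling="gap": (B, T, F) -> (B, output_dim)
print(tuple(seq_out.shape), tuple(vec_out.shape))  # (4, 10, 8) (4, 8)
```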

class linmult.core.heads.UpsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Output head with learnable temporal upsampling.

Maps (B, T_in, F) → (B, output_time_dim, output_dim) by projecting the feature dimension, applying a stack of transposed convolutions (each doubling the time axis), then a final adaptive pool to hit the exact target.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Required attributes:

    • output_time_dim (int): Target time dimension.

    • input_time_dim (int): Source time dimension.

    • dropout (float): Dropout probability. Default 0.1.

Initialize UpsampleHead.

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → torch.Tensor[source]

Upsample and project.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T_in, F).

  • mask (torch.Tensor, optional) – Bool mask of shape (B, T_in). True = valid. Masked positions are zeroed before processing.

Returns:

Output of shape (B, output_time_dim, output_dim).

Return type:

torch.Tensor
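The upsampling path can be sketched as below; channel counts, kernel sizes, and the target length of 16 are assumptions for illustration:

```python
import torch
import torch.nn as nn

# Sketch of the upsampling path: a stride-2 ConvTranspose1d doubles the
# time axis (here 10 -> 20), then adaptive pooling lands on the exact
# output_time_dim (here 16), followed by a 1x1 conv feature projection.
x = torch.randn(4, 10, 32)                 # (B, T_in, F)
h = x.transpose(1, 2)                      # (B, F, T_in) for Conv1d layers
h = nn.ConvTranspose1d(32, 32, kernel_size=4, stride=2, padding=1)(h)  # T: 10 -> 20
h = nn.AdaptiveAvgPool1d(16)(h)            # exact target time dim: 16
out = nn.Conv1d(32, 8, kernel_size=1)(h).transpose(1, 2)
print(tuple(out.shape))  # (4, 16, 8)
```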

class linmult.core.heads.DownsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]

Bases: BaseHead

Output head with learnable temporal downsampling.

Maps (B, T_in, F) → (B, output_time_dim, output_dim) by projecting the feature dimension, applying strided convolutions (each halving the time axis), then a final adaptive average pool to hit the exact target.

Parameters:
  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) –

    Head configuration. Required attributes:

    • output_time_dim (int): Target time dimension.

    • input_time_dim (int): Source time dimension.

    • dropout (float): Dropout probability. Default 0.1.

Initialize DownsampleHead.

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → torch.Tensor[source]

Downsample and project.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T_in, F).

  • mask (torch.Tensor, optional) – Bool mask of shape (B, T_in). True = valid. Masked positions are zeroed before processing.

Returns:

Output of shape (B, output_time_dim, output_dim).

Return type:

torch.Tensor
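The mirror-image downsampling path can be sketched the same way; again, channel counts, kernel sizes, and the target length of 7 are illustrative:

```python
import torch
import torch.nn as nn

# Sketch of the downsampling path: a stride-2 Conv1d halves the time axis
# (here 20 -> 10), then adaptive average pooling lands on the exact
# output_time_dim (here 7), followed by a 1x1 conv feature projection.
x = torch.randn(4, 20, 32)                 # (B, T_in, F)
h = x.transpose(1, 2)                      # (B, F, T_in)
h = nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1)(h)  # T: 20 -> 10
h = nn.AdaptiveAvgPool1d(7)(h)             # exact target time dim: 7
out = nn.Conv1d(32, 8, kernel_size=1)(h).transpose(1, 2)
print(tuple(out.shape))  # (4, 7, 8)
```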

class linmult.core.heads.HeadFactory[source]

Registry and factory for output head types.

New head classes can be registered at runtime with register_head(), then instantiated by name with create_head().

Built-in types: "sequence_aggregation", "sequence", "vector", "simple", "upsample", "downsample".

classmethod register_head(name: str, head_cls: type[BaseHead]) → None[source]

Register a custom head class under a given name.

Parameters:
  • name (str) – Registry key used in config["type"].

  • head_cls (type[BaseHead]) – Head class to register.

classmethod create_head(type: str, input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) → BaseHead[source]

Instantiate a registered head by type name.

Parameters:
  • type (str) – Registered head type name.

  • input_dim (int) – Input feature dimensionality.

  • output_dim (int) – Output feature dimensionality.

  • config (HeadConfig) – Head configuration.

Returns:

The constructed head module.

Return type:

BaseHead

Raises:

ValueError – If type is not registered.
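The register/create pattern can be sketched with a plain dict as the backing store. `FactorySketch` and `DoublingHead` are illustrative names; the real factory dispatches to BaseHead subclasses via from_config():

```python
# Sketch of the registry pattern: register_head() records a class under a
# name, create_head() looks it up and constructs via from_config(),
# raising ValueError for unregistered names.
class FactorySketch:
    _registry: dict = {}

    @classmethod
    def register_head(cls, name: str, head_cls: type) -> None:
        cls._registry[name] = head_cls

    @classmethod
    def create_head(cls, type: str, input_dim: int, output_dim: int, config):
        if type not in cls._registry:
            raise ValueError(f"unregistered head type: {type!r}")
        return cls._registry[type].from_config(input_dim, output_dim, config)

class DoublingHead:
    def __init__(self, input_dim, output_dim, config):
        self.input_dim, self.output_dim = input_dim, output_dim

    @classmethod
    def from_config(cls, input_dim, output_dim, config):
        return cls(input_dim, output_dim, config)

FactorySketch.register_head("doubling", DoublingHead)
head = FactorySketch.create_head("doubling", 16, 32, config=None)
print(type(head).__name__, head.output_dim)  # DoublingHead 32
```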

class linmult.core.heads.HeadModule(input_dim: int, head_configs: list[linmult.core.config.HeadConfig])[source]

Bases: torch.nn.Module

Self-contained output head container.

Builds all output heads from a list of HeadConfig using HeadFactory, and applies them in the forward pass.

Parameters:
  • input_dim – Input feature dimension fed to each head.

  • head_configs – List of head configurations.

Initialize HeadModule.

forward(x: torch.Tensor, mask: torch.Tensor | None = None) → dict[str, torch.Tensor][source]

Apply all heads to the input.

Parameters:
  • x – Input tensor (B, [T,] input_dim).

  • mask – Optional boolean mask (B, T).

Returns:

Dict mapping head name to output tensor.
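The container behaviour can be sketched with an `nn.ModuleDict`: several heads are built up front, each is applied to the same input, and outputs are collected by head name. The names "cls" and "reg" are made up for the example:

```python
import torch
import torch.nn as nn

# Sketch of HeadModule's forward pass: every head sees the same input,
# and the results come back as a name -> tensor dict.
heads = nn.ModuleDict({
    "cls": nn.Linear(32, 5),   # e.g. a classification head
    "reg": nn.Linear(32, 1),   # e.g. a regression head
})
x = torch.randn(4, 32)
outputs = {name: head(x) for name, head in heads.items()}
print({name: tuple(t.shape) for name, t in outputs.items()})
# {'cls': (4, 5), 'reg': (4, 1)}
```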