linmult.core.heads¶
Output head types, factory, and HeadModule for LinMulT/LinT models.
Classes¶

| BaseHead | Abstract base class for all output heads. |
| SequenceAggregationHead | Output head that aggregates a sequence to a single vector. |
| SequenceHead | Output head that preserves the time dimension. |
| VectorHead | Output head for vector (already-aggregated) inputs. |
| SimpleHead | Lightweight linear head with optional time-dimension pooling. |
| UpsampleHead | Output head with learnable temporal upsampling. |
| DownsampleHead | Output head with learnable temporal downsampling. |
| HeadFactory | Registry and factory for output head types. |
| HeadModule | Self-contained output head container. |
Module Contents¶
- class linmult.core.heads.BaseHead(_input_dim: int, _output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: torch.nn.Module

Abstract base class for all output heads.

Subclasses must implement forward(). Use from_config() as the standard factory entry point; it simply delegates to __init__.

- Parameters:
_input_dim (int) – Input feature dimensionality (stored for subclass use).
_output_dim (int) – Output feature dimensionality.
config (HeadConfig) – Head configuration.
Initialize BaseHead.
- classmethod from_config(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) BaseHead[source]¶
Construct a head from keyword arguments.
- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) – Head configuration.
- Returns:
A new instance of this head class.
- Return type:
BaseHead
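As an illustration of the subclassing pattern, a minimal custom head might look like the sketch below. It uses plain torch.nn rather than the real base class, and DummyConfig is a hypothetical stand-in for linmult's HeadConfig:

```python
import torch
from torch import nn


class DummyConfig:
    """Hypothetical stand-in for linmult.core.config.HeadConfig."""
    dropout = 0.1


class MeanHead(nn.Module):
    """Toy head: mean-pool over time, then a single linear projection."""

    def __init__(self, input_dim: int, output_dim: int, config) -> None:
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    @classmethod
    def from_config(cls, input_dim: int, output_dim: int, config) -> "MeanHead":
        # Mirrors BaseHead.from_config: simply delegates to __init__.
        return cls(input_dim, output_dim, config)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        # (B, T, F) -> (B, output_dim)
        return self.proj(x.mean(dim=1))


head = MeanHead.from_config(16, 4, DummyConfig())
out = head(torch.randn(2, 10, 16))
print(out.shape)  # torch.Size([2, 4])
```

The factory classmethod adds nothing over the constructor here, which matches the documented behavior: it exists so that all heads share one uniform construction entry point.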
- class linmult.core.heads.SequenceAggregationHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Output head that aggregates a sequence to a single vector.

Maps (B, T, F) → (B, output_dim) by normalizing, projecting to a hidden dimension, pooling along the time axis, and projecting to the output dimension.

- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Relevant attributes:
- norm (str): Normalisation type, "bn" or "in". Default "bn".
- pooling (str): Pooling type, "gap", "gmp", or "attentionpool". Default "gap".
- hidden_dim (int): Hidden projection size. Default 256.
- dropout (float): Dropout in the first projection. Default 0.1.
Initialize SequenceAggregationHead.
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]¶
Aggregate and project.
- Parameters:
x (torch.Tensor) – Input of shape (B, T, F).
mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid.
- Returns:
Output of shape (B, output_dim).
- Return type:
torch.Tensor
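The masked "gap" pooling step referenced above could plausibly be implemented as below. This is an assumed sketch of masked global average pooling, not the library's actual code:

```python
import torch


def masked_gap(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Masked global average pooling: (B, T, F) with bool mask (B, T) -> (B, F)."""
    m = mask.unsqueeze(-1).to(x.dtype)                      # (B, T, 1)
    # Sum only the valid timesteps, divide by the number of valid timesteps.
    return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])   # (1, 3, 2)
mask = torch.tensor([[True, True, False]])                  # last step is padding
print(masked_gap(x, mask))  # tensor([[2., 3.]])
```

With the padding step masked out, the result is the mean of the first two timesteps only, which is exactly why "True = valid" masks matter for variable-length sequences.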
- class linmult.core.heads.SequenceHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Output head that preserves the time dimension.

Maps (B, T, F) → (B, T, output_dim) by normalizing and projecting each timestep independently.

- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Relevant attributes:
- norm (str): Normalisation type, "bn" or "in". Default "bn".
- hidden_dim (int): Hidden projection size. Default 256.
- dropout (float): Dropout in the projection. Default 0.1.
Initialize SequenceHead.
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]¶
Normalize and project each timestep.
- Parameters:
x (torch.Tensor) – Input of shape (B, T, F).
mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid.
- Returns:
Output of shape (B, T, output_dim).
- Return type:
torch.Tensor
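The per-timestep projection comes essentially for free in PyTorch, since nn.Linear broadcasts over all leading dimensions. A minimal shape check of the time-preserving mapping:

```python
import torch
from torch import nn

# nn.Linear applies to the last dimension only, so one projection maps each
# timestep of (B, T, F) independently — the core of a time-preserving head.
proj = nn.Linear(8, 3)
x = torch.randn(2, 5, 8)   # (B=2, T=5, F=8)
y = proj(x)                # (B, T, output_dim)
print(y.shape)             # torch.Size([2, 5, 3])
```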
- class linmult.core.heads.VectorHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Output head for vector (already-aggregated) inputs.

Maps (B, F) → (B, output_dim) by normalizing and projecting.

- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Relevant attributes:
- norm (str): Normalisation type, "bn" or "in". Default "bn".
- hidden_dim (int): Hidden projection size. Default 256.
- dropout (float): Dropout in the projection. Default 0.1.
Initialize VectorHead.
- class linmult.core.heads.SimpleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Lightweight linear head with optional time-dimension pooling.

Applies an optional pooling step followed by a single linear projection. Depending on the pooling config attribute, the mapping is:

- No pooling (None): (B, T, F) → (B, T, output_dim)
- With pooling ("gap" / "gmp" / "attentionpool"): (B, T, F) → (B, output_dim)
- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Relevant attribute:
- pooling (str, optional): One of "gap", "gmp", "attentionpool", or None (no pooling).
Initialize SimpleHead.
- forward(x: torch.Tensor, mask: torch.Tensor | None = None, **_kwargs) torch.Tensor[source]¶
Apply optional pooling then linear projection.
- Parameters:
x (torch.Tensor) – Input of shape (B, T, F) or (B, F).
mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid. Passed through to the pooling layer when pooling is configured.
- Returns:
Output of shape (B, output_dim) if pooled, otherwise (B, T, output_dim).
- Return type:
torch.Tensor
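The two mappings can be demonstrated with plain torch, using unmasked mean pooling as a stand-in for the configured "gap" pooling (an illustrative sketch, not SimpleHead's actual code):

```python
import torch
from torch import nn

proj = nn.Linear(6, 2)
x = torch.randn(4, 7, 6)          # (B, T, F)

# With "gap"-style pooling: average away the time axis, then project.
pooled = proj(x.mean(dim=1))      # (B, output_dim)

# With pooling=None: the projection is applied per timestep.
unpooled = proj(x)                # (B, T, output_dim)

print(pooled.shape, unpooled.shape)
```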
- class linmult.core.heads.UpsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Output head with learnable temporal upsampling.

Maps (B, T_in, F) → (B, output_time_dim, output_dim) by projecting the feature dimension, applying a stack of transposed convolutions (each doubling the time axis), then a final adaptive pool to hit the exact target.

- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Required attributes:
- output_time_dim (int): Target time dimension.
- input_time_dim (int): Source time dimension.
- dropout (float): Dropout probability. Default 0.1.
Initialize UpsampleHead.
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]¶
Upsample and project.
- Parameters:
x (torch.Tensor) – Input of shape (B, T_in, F).
mask (torch.Tensor, optional) – Bool mask of shape (B, T_in). True = valid. Masked positions are zeroed before processing.
- Returns:
Output of shape (B, output_time_dim, output_dim).
- Return type:
torch.Tensor
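The doubling-then-snap recipe can be sketched with raw torch layers. This assumes a stride-2 ConvTranspose1d (kernel 4, padding 1), which exactly doubles the time axis, and an adaptive average pool for the final step; the real head's layer choices may differ:

```python
import torch
from torch import nn

B, T_in, F, target = 2, 25, 8, 96
x = torch.randn(B, F, T_in)        # conv layout: (B, C, T)

# ConvTranspose1d with stride 2, kernel 4, padding 1 doubles the length:
# L_out = (L_in - 1) * 2 - 2 * 1 + 4 = 2 * L_in
up = nn.ConvTranspose1d(F, F, kernel_size=4, stride=2, padding=1)
y = up(up(x))                      # 25 -> 50 -> 100

# Final adaptive pool snaps to the exact target time dimension.
y = nn.functional.adaptive_avg_pool1d(y, target)  # 100 -> 96
print(y.shape)  # torch.Size([2, 8, 96])
```

Two doublings overshoot 96 slightly (25 → 100), and the adaptive pool absorbs the remainder, which is why the docstring describes the pool as hitting "the exact target".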
- class linmult.core.heads.DownsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)[source]¶
Bases: BaseHead

Output head with learnable temporal downsampling.

Maps (B, T_in, F) → (B, output_time_dim, output_dim) by projecting the feature dimension, applying strided convolutions (each halving the time axis), then a final adaptive average pool to hit the exact target.

- Parameters:
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) –
Head configuration. Required attributes:
- output_time_dim (int): Target time dimension.
- input_time_dim (int): Source time dimension.
- dropout (float): Dropout probability. Default 0.1.
Initialize DownsampleHead.
- forward(x: torch.Tensor, mask: torch.Tensor | None = None) torch.Tensor[source]¶
Downsample and project.
- Parameters:
x (torch.Tensor) – Input of shape (B, T_in, F).
mask (torch.Tensor, optional) – Bool mask of shape (B, T_in). True = valid. Masked positions are zeroed before processing.
- Returns:
Output of shape (B, output_time_dim, output_dim).
- Return type:
torch.Tensor
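The mirror-image recipe for downsampling, sketched with an assumed stride-2 Conv1d (kernel 3, padding 1) that roughly halves the time axis per layer; again, the real head's exact layers are not shown here:

```python
import torch
from torch import nn

B, T_in, F, target = 2, 100, 8, 20
x = torch.randn(B, F, T_in)        # conv layout: (B, C, T)

# Conv1d with stride 2, kernel 3, padding 1 roughly halves the length:
# L_out = floor((L_in + 2 - 3) / 2) + 1
down = nn.Conv1d(F, F, kernel_size=3, stride=2, padding=1)
y = down(down(x))                  # 100 -> 50 -> 25

# Final adaptive average pool hits the exact output_time_dim.
y = nn.functional.adaptive_avg_pool1d(y, target)  # 25 -> 20
print(y.shape)  # torch.Size([2, 8, 20])
```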
- class linmult.core.heads.HeadFactory[source]¶
Registry and factory for output head types.
New head classes can be registered at runtime with register_head(), then instantiated by name with create_head().

Built-in types: "sequence_aggregation", "sequence", "vector", "simple", "upsample", "downsample".

- classmethod register_head(name: str, head_cls: type[BaseHead]) None[source]¶
Register a custom head class under a given name.
- Parameters:
name (str) – Registry key used in config["type"].
head_cls (type[BaseHead]) – Head class to register.
- classmethod create_head(type: str, input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) BaseHead[source]¶
Instantiate a registered head by type name.
- Parameters:
type (str) – Registered head type name.
input_dim (int) – Input feature dimensionality.
output_dim (int) – Output feature dimensionality.
config (HeadConfig) – Head configuration.
- Returns:
The constructed head module.
- Return type:
BaseHead
- Raises:
ValueError – If type is not registered.
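The register/create pattern HeadFactory implements can be shown with a minimal pure-Python sketch (the real factory holds BaseHead subclasses and passes a HeadConfig, both replaced with stand-ins here):

```python
class Factory:
    """Toy name -> class registry mirroring the HeadFactory pattern."""
    _registry: dict = {}

    @classmethod
    def register_head(cls, name, head_cls):
        cls._registry[name] = head_cls

    @classmethod
    def create_head(cls, type, input_dim, output_dim, config):
        if type not in cls._registry:
            raise ValueError(f"Unknown head type: {type!r}")
        return cls._registry[type](input_dim, output_dim, config)


class MyHead:
    """Stand-in head; a real registrant would subclass BaseHead."""
    def __init__(self, input_dim, output_dim, config):
        self.shape = (input_dim, output_dim)


Factory.register_head("my_head", MyHead)
head = Factory.create_head("my_head", 16, 4, config=None)
print(head.shape)  # (16, 4)
```

Registering under the name used in config["type"] is what lets configuration files select custom heads without code changes at the call site.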
- class linmult.core.heads.HeadModule(input_dim: int, head_configs: list[linmult.core.config.HeadConfig])[source]¶
Bases: torch.nn.Module

Self-contained output head container.

Builds all output heads from a list of HeadConfig using HeadFactory, and applies them in the forward pass.

- Parameters:
input_dim – Input feature dimension fed to each head.
head_configs – List of head configurations.
Initialize HeadModule.
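The container idea — build one module per config entry and apply them all in forward — can be sketched as follows. Here the "configs" are simple (name, output_dim) pairs standing in for the real list of HeadConfig objects, and nn.Linear stands in for the factory-built heads:

```python
import torch
from torch import nn


class Heads(nn.Module):
    """Toy multi-head container mirroring the HeadModule pattern."""

    def __init__(self, input_dim, configs):
        super().__init__()
        # One submodule per config entry, registered so parameters are tracked.
        self.heads = nn.ModuleDict(
            {name: nn.Linear(input_dim, out_dim) for name, out_dim in configs}
        )

    def forward(self, x):
        # Apply every head to the shared input features.
        return {name: head(x) for name, head in self.heads.items()}


module = Heads(8, [("cls", 3), ("reg", 1)])
outs = module(torch.randn(2, 8))
print({k: tuple(v.shape) for k, v in outs.items()})
```

Using an nn.ModuleDict (rather than a plain dict) ensures every head's parameters show up in module.parameters() and move with .to(device).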