linmult.models.LinT

LinT: unimodal linear-complexity transformer.

Classes

LinT

Linear-complexity Transformer for a single input modality.

Module Contents

class linmult.models.LinT.LinT(config: linmult.core.config.LinTConfig | str | pathlib.Path)

Bases: torch.nn.Module

Linear-complexity Transformer for a single input modality.

Processes one time-series input through a projection and self-attention pipeline, then applies the configurable output heads.
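This page does not say which linear-complexity attention mechanism LinT uses. To illustrate the general idea only, the sketch below implements kernelized linear attention with the elu(x) + 1 feature map. This technique never builds the T × T attention matrix, so its cost grows linearly in T. It is a swapped-in example and not necessarily LinT's implementation. `LinearSelfAttention`, the `proj` layer, and all dimensions are hypothetical.

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearSelfAttention(nn.Module):
    """Hypothetical kernelized self-attention with O(T) cost in sequence length."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.eps = eps

    @staticmethod
    def feature_map(x: torch.Tensor) -> torch.Tensor:
        # elu(x) + 1 keeps features positive, so the normalizer stays positive.
        return F.elu(x) + 1

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        q = self.feature_map(self.q(x))  # (B, T, D)
        k = self.feature_map(self.k(x))  # (B, T, D)
        v = self.v(x)                    # (B, T, D)
        if mask is not None:
            # Invalid timesteps (False) contribute nothing as keys/values.
            k = k * mask.unsqueeze(-1).to(k.dtype)
        kv = torch.einsum("btd,bte->bde", k, v)          # (B, D, D): no T x T matrix
        z = torch.einsum("btd,bd->bt", q, k.sum(dim=1))  # per-query normalizer
        out = torch.einsum("btd,bde->bte", q, kv)        # (B, T, D)
        return out / (z.unsqueeze(-1) + self.eps)


# Usage: hypothetical projection from F=16 input features to D=32 model dims.
proj = nn.Linear(16, 32)
attn = LinearSelfAttention(32)
x = torch.randn(4, 100, 16)                   # (B, T, F)
mask = torch.ones(4, 100, dtype=torch.bool)   # all timesteps valid
h = attn(proj(x), mask)                       # (B, T, D) = (4, 100, 32)
```

Reordering the products as φ(Q)(φ(K)ᵀV) instead of (φ(Q)φ(K)ᵀ)V gives the same result as the quadratic form. The only intermediate it keeps is D × D.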


Initialize LinT.

Parameters:

config (LinTConfig | str | Path) – Configuration object or path to a YAML file. When a path is given, the file is loaded with LinTConfig.from_yaml().
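The constructor's "config object or path" dispatch can be sketched as a small helper. This is a minimal sketch under assumptions. `resolve_config` is hypothetical. `FakeConfig` is a stand-in for LinTConfig that only records the path it receives and does not parse YAML. The file name is made up.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any


def resolve_config(config: Any, config_cls: type) -> Any:
    """Hypothetical helper: load str/Path inputs via config_cls.from_yaml().

    Any other value is assumed to already be a config instance and is
    returned unchanged.
    """
    if isinstance(config, (str, Path)):
        return config_cls.from_yaml(config)
    return config


class FakeConfig:
    """Hypothetical stand-in for LinTConfig that records its source path."""

    def __init__(self, source: str) -> None:
        self.source = source

    @classmethod
    def from_yaml(cls, path: str | Path) -> FakeConfig:
        # A real config class would parse the YAML file here.
        return cls(source=str(path))


cfg = resolve_config("configs/lint.yaml", FakeConfig)  # "loaded" from a path
same = resolve_config(cfg, FakeConfig)                  # passed through unchanged
```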

extra_repr() -> str

Return the model name for identification in repr output.
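In PyTorch, `nn.Module.__repr__` inserts the string returned by `extra_repr()` inside the module's parentheses. The toy module below shows that hook. `NamedModel` and its `name` attribute are hypothetical, and the exact string LinT returns may differ.

```python
import torch.nn as nn


class NamedModel(nn.Module):
    """Hypothetical module whose repr includes a name via the extra_repr() hook."""

    def __init__(self, name: str) -> None:
        super().__init__()
        self.name = name

    def extra_repr(self) -> str:
        # With no child modules, nn.Module.__repr__ prints this inside "(...)".
        return f"name={self.name!r}"


print(repr(NamedModel("audio_lint")))  # NamedModel(name='audio_lint')
```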

forward(x: torch.Tensor, mask: torch.Tensor | None = None, name: str | None = None) -> dict[str, torch.Tensor]

Run the forward pass.

Parameters:
  • x (torch.Tensor) – Input of shape (B, T, F), i.e. batch size, time steps, and feature dimension. May also be a single-element list [tensor].

  • mask (torch.Tensor, optional) – Boolean mask of shape (B, T), where True marks a valid timestep. A mask with no True entries is treated as if mask were None.

  • name (str, optional) – Key used to look up special input handling (e.g. a weighted sum of layer activations). May also be a single-element list.
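The input conventions listed above can be sketched as a single normalization step. In this minimal sketch, `normalize_inputs` is hypothetical, and raising on lists longer than one element is this sketch's choice rather than documented behavior.

```python
from __future__ import annotations

import torch


def normalize_inputs(
    x: torch.Tensor | list[torch.Tensor],
    mask: torch.Tensor | None = None,
    name: str | list[str] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, str | None]:
    """Hypothetical helper applying the documented forward() input conventions."""
    # A single-element list [tensor] is unwrapped to the tensor itself.
    if isinstance(x, list):
        if len(x) != 1:
            raise ValueError("expected a single-element list for x")
        x = x[0]
    # The same single-element unwrapping applies to name.
    if isinstance(name, list):
        if len(name) != 1:
            raise ValueError("expected a single-element list for name")
        name = name[0]
    # A mask with no valid timestep is treated as if no mask were given.
    if mask is not None and not bool(mask.any()):
        mask = None
    return x, mask, name


x = torch.randn(2, 10, 8)                     # (B, T, F)
empty = torch.zeros(2, 10, dtype=torch.bool)  # no valid timesteps
x_out, m_out, n_out = normalize_inputs([x], empty, ["audio"])
```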

Returns:

Mapping from head name to output tensor.

Each tensor has shape (B, output_dim) when time_dim_reducer is set; otherwise it has shape (B, T, output_dim).

Return type:

dict[str, torch.Tensor]
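The return contract, a dict from head name to tensor whose shape depends on whether the time axis is reduced, can be sketched as follows. The sketch makes several assumptions. Masked mean pooling stands in for whatever time_dim_reducer LinT actually uses. The boolean `reduce_time` flag stands in for "time_dim_reducer is set". `apply_heads`, `masked_mean`, and the head names are all hypothetical.

```python
from __future__ import annotations

import torch
import torch.nn as nn


def masked_mean(h: torch.Tensor, mask: torch.Tensor | None) -> torch.Tensor:
    """Stand-in reducer: (B, T, D) -> (B, D), averaging over valid timesteps."""
    if mask is None:
        return h.mean(dim=1)
    w = mask.unsqueeze(-1).to(h.dtype)                  # (B, T, 1)
    return (h * w).sum(dim=1) / w.sum(dim=1).clamp(min=1.0)


def apply_heads(
    h: torch.Tensor,
    heads: nn.ModuleDict,
    mask: torch.Tensor | None = None,
    reduce_time: bool = False,
) -> dict[str, torch.Tensor]:
    """Hypothetical: map each head name to its output, optionally pooling time first."""
    if reduce_time:
        h = masked_mean(h, mask)                        # (B, D)
    # nn.Linear acts on the last dimension, so it works on (B, T, D) or (B, D).
    return {name: head(h) for name, head in heads.items()}


heads = nn.ModuleDict({"sentiment": nn.Linear(32, 3), "intensity": nn.Linear(32, 1)})
h = torch.randn(4, 100, 32)                             # encoder output (B, T, D)
per_step = apply_heads(h, heads)                        # "sentiment": (4, 100, 3)
pooled = apply_heads(h, heads, reduce_time=True)        # "sentiment": (4, 3)
```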