linmult.models.LinT
LinT: unimodal linear-complexity transformer.
Classes
- LinT: Linear-complexity Transformer for a single input modality.
Module Contents
- class linmult.models.LinT.LinT(config: linmult.core.config.LinTConfig | str | pathlib.Path)
Bases: torch.nn.Module
Linear-complexity Transformer for a single input modality.
Processes one time-series input through a projection + self-attention pipeline and applies configurable output heads.
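The defining property of LinT is attention whose cost grows linearly, not quadratically, with sequence length. The source does not say which linear attention variant LinT uses, so the NumPy sketch below is illustrative only. It implements the kernelized variant with the feature map elu(x) + 1, which reorders the attention product so that a (D × E) key/value summary is built once and reused by every query. `feature_map` and `linear_attention` are hypothetical names, not linmult functions. Zeroing padded keys with a (B, T) boolean mask is also an assumption made for the illustration.

```python
from typing import Optional

import numpy as np


def feature_map(x: np.ndarray) -> np.ndarray:
    """elu(x) + 1: a strictly positive feature map (x + 1 if x > 0, else exp(x))."""
    # np.minimum avoids overflow warnings in the branch np.where discards.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def linear_attention(
    q: np.ndarray,                      # (B, T, D) queries
    k: np.ndarray,                      # (B, T, D) keys
    v: np.ndarray,                      # (B, T, E) values
    mask: Optional[np.ndarray] = None,  # (B, T) bool, True = valid timestep
    eps: float = 1e-6,
) -> np.ndarray:
    """Non-causal kernelized attention in O(T * D * E) instead of O(T^2)."""
    phi_q = feature_map(q)
    phi_k = feature_map(k)
    if mask is not None:
        # Padded keys get a zero feature vector, so they add nothing below.
        phi_k = phi_k * mask[..., None]
    kv = np.einsum("btd,bte->bde", phi_k, v)    # (B, D, E) key/value summary
    z = phi_k.sum(axis=1)                       # (B, D) normalizer summary
    num = np.einsum("btd,bde->bte", phi_q, kv)  # (B, T, E)
    den = np.einsum("btd,bd->bt", phi_q, z)     # (B, T)
    return num / (den[..., None] + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 6, 8))
print(linear_attention(x, x, x).shape)  # (2, 6, 8)
```

Because the summary `kv` has shape (B, D, E) regardless of T, no (B, T, T) score matrix is ever materialized, which is what a softmax-attention version would need.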
Initialize LinT.
- Parameters:
config (LinTConfig | str | Path) – Configuration object or path to a YAML file. When a path is given, the file is loaded with LinTConfig.from_yaml().
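The constructor accepts either a ready-made config object or a path. A minimal plain-Python sketch of that dispatch follows. `ToyConfig` (with its hypothetical `source` field) and `resolve_config` are stand-ins; the real LinTConfig fields are not documented here. The loader is passed in explicitly so the sketch runs without a YAML file; in LinT that role is played by LinTConfig.from_yaml().

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Union


@dataclass
class ToyConfig:
    """Hypothetical stand-in for LinTConfig; `source` exists only for this demo."""
    source: str = "inline"


def resolve_config(
    config: Union[ToyConfig, str, Path],
    from_yaml: Callable[[Path], ToyConfig],
) -> ToyConfig:
    """Return a config object, loading it from disk only when given a path."""
    if isinstance(config, (str, Path)):
        # Both str and Path are normalized to Path before loading.
        return from_yaml(Path(config))
    return config  # already a config object: used as-is, never copied


loaded = resolve_config("model.yaml", lambda p: ToyConfig(source=str(p)))
print(loaded.source)  # model.yaml
```

Checking for `str`/`Path` first means a config object passes through untouched, so callers can build one programmatically without touching the filesystem.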
- forward(x: torch.Tensor, mask: torch.Tensor | None = None, name: str | None = None) → dict[str, torch.Tensor]
Run the forward pass.
- Parameters:
x (torch.Tensor) – Input of shape (B, T, F). May also be a single-element list[tensor].
mask (torch.Tensor, optional) – Bool mask of shape (B, T). True = valid timestep. A fully-False mask is treated as None.
name (str, optional) – Key used for special-handling lookup (e.g. weighted-sum of layer activations). May also be a single-element list.
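The argument rules above can be summarized in a small sketch. `normalize_inputs` and `_unwrap_single` are hypothetical helpers, not linmult functions, and NumPy arrays stand in for torch.Tensor. Raising on multi-element lists and on shape mismatches is an assumption; the source only documents the single-element and fully-False cases.

```python
from typing import Optional, Tuple

import numpy as np


def _unwrap_single(value, what: str):
    """Return the only element of a one-element list; pass anything else through."""
    if isinstance(value, list):
        if len(value) != 1:
            raise ValueError(f"{what}: expected a single-element list, got {len(value)}")
        return value[0]
    return value


def normalize_inputs(
    x, mask: Optional[np.ndarray] = None, name=None
) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[str]]:
    """Apply the documented forward() argument rules to (x, mask, name)."""
    x = _unwrap_single(x, "x")
    name = _unwrap_single(name, "name")
    if x.ndim != 3:
        raise ValueError(f"x must have shape (B, T, F), got {x.shape}")
    if mask is not None:
        if mask.shape != x.shape[:2]:
            raise ValueError(f"mask must have shape {x.shape[:2]}, got {mask.shape}")
        if not mask.any():
            # A mask with no valid timestep is dropped, i.e. treated as None.
            mask = None
    return x, mask, name


x = np.zeros((2, 3, 4))
x_out, mask_out, name_out = normalize_inputs([x], np.zeros((2, 3), dtype=bool), ["audio"])
print(x_out.shape, mask_out, name_out)  # (2, 3, 4) None audio
```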
- Returns:
- Mapping from head name to output tensor. Each tensor has shape (B, output_dim) when time_dim_reducer is set, otherwise (B, T, output_dim).
- Return type:
dict[str, torch.Tensor]
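When time_dim_reducer is set, each head collapses the time axis, so its output goes from (B, T, output_dim) to (B, output_dim). The source does not name the available reducers. The sketch below assumes a masked mean over time, purely to show the shape change and how the (B, T) mask could exclude padded steps. `masked_mean_over_time`, the head names, and the use of raw features as head outputs are all hypothetical.

```python
from typing import Optional

import numpy as np


def masked_mean_over_time(h: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
    """Reduce (B, T, D) to (B, D) by averaging only the valid timesteps."""
    if mask is None:
        return h.mean(axis=1)
    w = mask[..., None].astype(h.dtype)    # (B, T, 1): 1.0 = valid, 0.0 = padded
    count = np.maximum(w.sum(axis=1), 1.0)  # (B, 1), clamped to avoid 0 / 0
    return (h * w).sum(axis=1) / count


B, T, D = 2, 4, 3
h = np.arange(B * T * D, dtype=float).reshape(B, T, D)  # stand-in head features
mask = np.array([[True, True, False, False],
                 [True, True, True, True]])

outputs = {
    "sequence_head": h,                             # no reducer: (B, T, D)
    "pooled_head": masked_mean_over_time(h, mask),  # reducer applied: (B, D)
}
print({k: v.shape for k, v in outputs.items()})
# {'sequence_head': (2, 4, 3), 'pooled_head': (2, 3)}
```

Clamping the count to 1 keeps a sample whose timesteps are all padded at zero rather than NaN.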