linmult.core.transformer

Transformer encoder: stacked pre-norm layers with multi-head attention and FFN.

Classes

TransformerEncoder

Transformer encoder with multiple stacked layers.

TransformerEncoderLayer

Single pre-norm transformer encoder layer with attention + FFN.

Module Contents

class linmult.core.transformer.TransformerEncoder(d_model: int = 40, num_heads: int = 8, num_layers: int = 6, attention_config: linmult.core.attention.AttentionConfig | None = None, dropout_pe: float = 0.0, dropout_ffn: float = 0.1, is_cross_modal: bool = False, name: str = '')[source]

Bases: torch.nn.Module

Transformer encoder with multiple stacked layers.

Supports both self-attention (when x_k and x_v are omitted) and cross-modal attention (when x_k and x_v are provided).

Parameters:
  • d_model (int) – Input and output feature dimensionality. Defaults to 40.

  • num_heads (int) – Number of attention heads. Defaults to 8.

  • num_layers (int) – Number of stacked encoder layers. Defaults to 6.

  • attention_config (AttentionConfig, optional) – Attention type and parameters. Defaults to AttentionConfig() (linear attention).

  • dropout_pe (float) – Dropout after positional encoding. Defaults to 0.0.

  • dropout_ffn (float) – Dropout in the FFN sub-layer. Defaults to 0.1.

  • is_cross_modal (bool) – Allocate a separate layer-norm for cross-modal key input. Set to True for cross-modal attention encoders. Defaults to False.

  • name (str) – Module name shown in repr. Defaults to "".

extra_repr() str[source]

Return the module name for identification in repr output.

forward(x_q: torch.Tensor, x_k: torch.Tensor | None = None, x_v: torch.Tensor | None = None, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None) torch.Tensor[source]

Run the transformer encoder.

When x_k and x_v are omitted, the encoder runs self-attention (x_q == x_k == x_v). When they are provided, it runs cross-modal attention.

Parameters:
  • x_q (torch.Tensor) – Query input of shape (B, T_1, F).

  • x_k (torch.Tensor, optional) – Key input of shape (B, T_2, F).

  • x_v (torch.Tensor, optional) – Value input of shape (B, T_2, F).

  • query_mask (torch.BoolTensor, optional) – Mask for queries, shape (B, T_1).

  • key_mask (torch.BoolTensor, optional) – Mask for keys, shape (B, T_2).

Returns:

Encoded output of shape (B, T_1, F).

Return type:

torch.Tensor

class linmult.core.transformer.TransformerEncoderLayer(d_model: int = 40, num_heads: int = 8, attention_config: linmult.core.attention.AttentionConfig | None = None, dropout: float = 0.1, is_cross_modal: bool = False)[source]

Bases: torch.nn.Module

Single pre-norm transformer encoder layer with attention + FFN.

Supports self-attention and cross-modal attention. The cross-modal layer norm (layer_norm_cross) is only allocated when is_cross_modal=True, since pure self-attention layers never receive external keys.

Parameters:
  • d_model (int) – Feature dimensionality. Defaults to 40.

  • num_heads (int) – Number of attention heads. Defaults to 8.

  • attention_config (AttentionConfig, optional) – Attention type and parameters. Defaults to AttentionConfig() (linear attention).

  • dropout (float) – Dropout on FFN and residual connections. Defaults to 0.1.

  • is_cross_modal (bool) – Allocate a cross-modal layer-norm. Defaults to False.

forward(x_q: torch.Tensor, x_k: torch.Tensor | None = None, x_v: torch.Tensor | None = None, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None) torch.Tensor[source]

Run one transformer encoder layer.

Parameters:
  • x_q (torch.Tensor) – Query input of shape (B, T_1, F).

  • x_k (torch.Tensor, optional) – Key input of shape (B, T_2, F).

  • x_v (torch.Tensor, optional) – Value input of shape (B, T_2, F).

  • query_mask (torch.BoolTensor, optional) – Mask for queries, shape (B, T_1).

  • key_mask (torch.BoolTensor, optional) – Mask for keys, shape (B, T_2).

Returns:

Layer output of shape (B, T_1, F).

Return type:

torch.Tensor

Raises:

ValueError – If mask shapes or dtypes are incorrect.