Source code for linmult.core.transformer
"""Transformer encoder: stacked pre-norm layers with multi-head attention and FFN."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from linmult.core.attention import AttentionConfig, AttentionFactory
from linmult.core.pe import PositionalEncoding
[docs]
class TransformerEncoder(nn.Module):
    """Transformer encoder with multiple stacked layers.

    Supports both self-attention (when ``x_k`` and ``x_v`` are omitted) and
    cross-modal attention (when ``x_k`` and ``x_v`` are provided).

    Args:
        d_model (int): Input and output feature dimensionality. Defaults to ``40``.
        num_heads (int): Number of attention heads. Defaults to ``8``.
        num_layers (int): Number of stacked encoder layers. Defaults to ``6``.
        attention_config (AttentionConfig, optional): Attention type and parameters.
            Defaults to ``AttentionConfig()`` (linear attention).
        dropout_pe (float): Dropout after positional encoding. Defaults to ``0.0``.
        dropout_ffn (float): Dropout in the FFN sub-layer. Defaults to ``0.1``.
        is_cross_modal (bool): Allocate a separate layer-norm for cross-modal key input.
            Set to ``True`` for cross-modal attention encoders. Defaults to ``False``.
        name (str): Module name shown in ``repr``. Defaults to ``""``.
    """

    def __init__(
        self,
        d_model: int = 40,
        num_heads: int = 8,
        num_layers: int = 6,
        attention_config: AttentionConfig | None = None,
        dropout_pe: float = 0.0,
        dropout_ffn: float = 0.1,
        is_cross_modal: bool = False,
        name: str = "",
    ):
        super().__init__()
        self.name = name
        # Inputs are multiplied by sqrt(d_model) before positional encoding.
        self.embed_scale = math.sqrt(d_model)
        self.embed_positions = PositionalEncoding(dropout=dropout_pe)
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    attention_config=attention_config,
                    dropout=dropout_ffn,
                    is_cross_modal=is_cross_modal,
                )
                for _ in range(num_layers)
            ]
        )
        # Final norm applied to the output of the last layer (pre-norm stack).
        self.layer_norm = nn.LayerNorm(d_model)

    def extra_repr(self) -> str:
        """Return the module name for identification in repr output."""
        return f"name={self.name!r}"  # pragma: no cover

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor | None = None,
        x_v: torch.Tensor | None = None,
        query_mask: torch.Tensor | None = None,
        key_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the transformer encoder.

        When ``x_k`` and ``x_v`` are omitted the encoder runs self-attention
        (``x_q == x_k == x_v``). When provided it runs cross-modal attention.

        Args:
            x_q (torch.Tensor): Query input of shape ``(B, T_1, F)``.
            x_k (torch.Tensor, optional): Key input of shape ``(B, T_2, F)``.
            x_v (torch.Tensor, optional): Value input of shape ``(B, T_2, F)``.
                Must be given together with ``x_k``. K and V are assumed to come
                from the same source, so its contents are not read: the
                positionally-encoded ``x_k`` supplies both keys and values.
            query_mask (torch.BoolTensor, optional): Mask for queries, shape ``(B, T_1)``.
            key_mask (torch.BoolTensor, optional): Mask for keys, shape ``(B, T_2)``.
                Only used in cross-modal mode; ignored for self-attention.

        Returns:
            torch.Tensor: Encoded output of shape ``(B, T_1, F)``.

        Raises:
            ValueError: If only one of ``x_k`` / ``x_v`` is provided, or if a
                layer rejects the mask shapes or dtypes.
        """
        # Previously a lone x_k (or x_v) silently degraded to self-attention and
        # dropped the keys and key_mask; make that misuse explicit instead.
        if (x_k is None) != (x_v is None):
            raise ValueError("x_k and x_v must be provided together (or both omitted).")

        x = self.embed_positions(self.embed_scale * x_q)  # (B, T_1, d_model)

        if x_k is None:
            # Self-attention: there are no external keys, so key_mask is not forwarded.
            for layer in self.layers:
                x = layer(x, query_mask=query_mask)
        else:
            # K and V share one source tensor: apply PE once and reuse it for both.
            x_kv = self.embed_positions(self.embed_scale * x_k)  # (B, T_2, d_model)
            for layer in self.layers:
                x = layer(x, x_kv, x_kv, query_mask=query_mask, key_mask=key_mask)

        return self.layer_norm(x)
[docs]
class TransformerEncoderLayer(nn.Module):
    """Single pre-norm transformer encoder layer with attention + FFN.

    Supports self-attention and cross-modal attention. The cross-modal layer
    norm (``layer_norm_cross``) is only allocated when ``is_cross_modal=True``,
    since pure self-attention layers never receive external keys.

    Args:
        d_model (int): Feature dimensionality. Defaults to ``40``.
        num_heads (int): Number of attention heads. Defaults to ``8``.
        attention_config (AttentionConfig, optional): Attention type and parameters.
            Defaults to ``AttentionConfig()`` (linear attention).
        dropout (float): Dropout on FFN and residual connections. Defaults to ``0.1``.
        is_cross_modal (bool): Allocate a cross-modal layer-norm. Defaults to ``False``.
    """

    def __init__(
        self,
        d_model: int = 40,
        num_heads: int = 8,
        attention_config: AttentionConfig | None = None,
        dropout: float = 0.1,
        is_cross_modal: bool = False,
    ):
        super().__init__()
        # Selects the call convention in forward(): "mha" receives a key_padding_mask,
        # every other backend receives query_mask/key_mask. No config means "linear".
        self.attention_type = attention_config.type if attention_config is not None else "linear"
        self.attention = AttentionFactory.create(d_model, num_heads, attention_config)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.fc2 = nn.Linear(4 * d_model, d_model)
        # layer_norms[0] precedes attention, layer_norms[1] precedes the FFN (pre-norm).
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(2)])
        # Allocate cross-modal norm only for layers that actually do cross-modal attention;
        # pure self-attention layers (SAT) never receive x_k/x_v.
        self.is_cross_modal = is_cross_modal
        if self.is_cross_modal:
            self.layer_norm_cross = nn.LayerNorm(d_model)
        self.dropout = dropout

    def forward(
        self,
        x_q: torch.Tensor,
        x_k: torch.Tensor | None = None,
        x_v: torch.Tensor | None = None,
        query_mask: torch.Tensor | None = None,
        key_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run one transformer encoder layer.

        Args:
            x_q (torch.Tensor): Query input of shape ``(B, T_1, F)``.
            x_k (torch.Tensor, optional): Key input of shape ``(B, T_2, F)``.
            x_v (torch.Tensor, optional): Value input of shape ``(B, T_2, F)``.
                Must be given together with ``x_k``. K and V are assumed to come
                from the same source, so its contents are not read: ``x_k``
                supplies both keys and values.
            query_mask (torch.BoolTensor, optional): Mask for queries, shape ``(B, T_1)``.
                ``True`` marks valid positions; padded positions get a zero
                attention contribution in both self- and cross-attention.
            key_mask (torch.BoolTensor, optional): Mask for keys, shape ``(B, T_2)``.
                ``True`` marks valid keys.

        Returns:
            torch.Tensor: Layer output of shape ``(B, T_1, F)``.

        Raises:
            ValueError: If mask shapes or dtypes are incorrect, if ``key_mask`` is
                given without ``x_k``, or if only one of ``x_k`` / ``x_v`` is given.
        """
        if query_mask is not None and (
            query_mask.shape != x_q.shape[:2] or query_mask.dtype != torch.bool
        ):
            raise ValueError(
                f"Expected query mask has shape (B, T_1) and bool dtype, "
                f"got instead: {query_mask.shape} and {query_mask.dtype}"
            )
        # A lone x_k (or x_v) used to fall through to self-attention silently.
        if (x_k is None) != (x_v is None):
            raise ValueError("x_k and x_v must be provided together (or both omitted).")
        if key_mask is not None:
            if x_k is None:
                raise ValueError("key_mask was provided but x_k is None.")
            if key_mask.shape != x_k.shape[:2] or key_mask.dtype != torch.bool:
                raise ValueError(
                    f"Expected key_mask of shape (B, T_2) and bool dtype, "
                    f"got {key_mask.shape} {key_mask.dtype}."
                )

        # --- Attention sub-layer (pre-norm + residual) ---
        residual = x_q
        x_q = self.layer_norms[0](x_q)

        if x_k is None:  # self-attention
            if self.attention_type == "mha":
                # key_padding_mask marks positions to IGNORE, the inverse of query_mask.
                kpm = ~query_mask if query_mask is not None else None
                x_q, _ = self.attention(x_q, x_q, x_q, key_padding_mask=kpm)
            else:
                x_q, _ = self.attention(
                    x_q, x_q, x_q, query_mask=query_mask, key_mask=query_mask
                )  # (B, T_1, F)
        else:  # cross-modal attention
            # K and V are the same source tensor: normalize once, reuse for both.
            # layer_norm_cross exists only on is_cross_modal=True layers; otherwise the
            # keys are expected to arrive already normalized from their own encoder.
            x_kv = self.layer_norm_cross(x_k) if self.is_cross_modal else x_k  # (B, T_2, F)
            if self.attention_type == "mha":
                # key_padding_mask marks keys to IGNORE, the inverse of key_mask.
                kpm = ~key_mask if key_mask is not None else None
                x_q, _ = self.attention(x_q, x_kv, x_kv, key_padding_mask=kpm)
            else:
                x_q, _ = self.attention(
                    x_q, x_kv, x_kv, query_mask=query_mask, key_mask=key_mask
                )  # (B, T_1, F)
            if key_mask is not None:
                # Samples where all keys are masked should produce zero output.
                # masked_fill overwrites unconditionally (including any NaN) and, unlike
                # `if mask.any(): torch.where(...)`, needs no host sync or extra tensor.
                no_valid_keys = ~key_mask.any(dim=1)  # (B,)
                x_q = x_q.masked_fill(no_valid_keys[:, None, None], 0.0)

        if query_mask is not None:
            # Zero out ALL padded query positions, in both branches. Softmax/BigBird
            # produce NaN where the combined attn_mask row is all -inf
            # (query_mask[b,i]=False -> combined mask row all-False -> softmax -> NaN),
            # and the cross-modal branch passes query_mask to the backend just like the
            # self-attention branch does.
            # Must use masked_fill (not multiplication): NaN * 0 = NaN in IEEE 754, so a
            # simple x * mask would still propagate NaN. masked_fill writes 0.0 regardless.
            x_q = x_q.masked_fill(~query_mask.unsqueeze(-1), 0.0)

        x_q = F.dropout(x_q, p=self.dropout, training=self.training)
        x_q = residual + x_q

        # --- Feed-forward sub-layer (pre-norm + residual) ---
        residual = x_q
        x_q = self.layer_norms[1](x_q)
        x_q = F.gelu(self.fc1(x_q))
        x_q = F.dropout(x_q, p=self.dropout, training=self.training)
        x_q = self.fc2(x_q)
        x_q = F.dropout(x_q, p=self.dropout, training=self.training)
        x_q = residual + x_q
        return x_q