linmult.core.attention¶
Attention mechanisms: linear, softmax, BigBird, Performer (FAVOR+), and GAU (flash).
Classes¶
AttentionConfig | Attention mechanism selection and its type-specific hyperparameters.
AttentionFactory | Factory for creating attention layers from an AttentionConfig.
AttentionLayer | Multi-head attention wrapper that projects inputs and reprojects the output.
BigBirdAttention | BigBird sparse attention: global + local-block + random tokens.
SoftmaxAttention | Standard scaled dot-product softmax attention with O(N² D) complexity.
LinearAttention | Linear-complexity attention via kernel feature maps — O(N D²).
FeatureMap | Abstract base class defining the feature map interface for linear attention.
ActivationFunctionFeatureMap | Feature map defined by an element-wise activation function.
EluFeatureMap | ELU+1 feature map — default for LinearAttention.
PerformerFeatureMap | Performer (FAVOR+) positive random feature map (Choromanski et al., ICLR 2021).
GatedAttentionUnit | flash (GAU) — Hua et al., ICML 2022.
Module Contents¶
- class linmult.core.attention.AttentionConfig[source]¶
Attention mechanism selection and its type-specific hyperparameters.
This is an internal construction spec — created from user-facing config fields in LinT or LinMulT, then passed directly to TransformerEncoder and TAM. Each field is only relevant when type matches.
- Parameters:
type (str) – Attention mechanism. One of "linear" (default), "performer", "flash", "softmax", "bigbird", "mha".
dropout (float) – Dropout probability on attention weights. Defaults to 0.0.
flash_query_key_dim (int | None) – Scoring dimension for "flash" (GAU). Defaults to None (computed as max(d_model // 2, 16)).
performer_num_random_features (int | None) – Random feature count for "performer". Defaults to None (computed as max(head_dim * 4, 32)).
bigbird_block_size (int) – Local block size for "bigbird". Defaults to 64.
bigbird_num_global_tokens (int) – Global tokens for "bigbird". Defaults to 16.
bigbird_num_random_tokens (int) – Random tokens for "bigbird". Defaults to 10.
- class linmult.core.attention.AttentionFactory[source]¶
Factory for creating attention layers from an
AttentionConfig.
- static create(d_model: int, num_heads: int, attention_config: AttentionConfig | None = None) → torch.nn.Module[source]¶
Create and return an attention layer.
- Parameters:
d_model (int) – Input feature dimensionality.
num_heads (int) – Number of attention heads.
attention_config (AttentionConfig, optional) – Attention configuration. Defaults to AttentionConfig() (linear attention).
- Returns:
An attention module. "mha" returns nn.MultiheadAttention; "flash" (GAU) returns GatedAttentionUnit; all others return AttentionLayer.
- Return type:
nn.Module
- Raises:
ValueError – If attention_config.type is not one of the supported values.
- class linmult.core.attention.AttentionLayer(attention: torch.nn.Module, d_model: int, num_heads: int, d_keys: int | None = None, d_values: int | None = None)[source]¶
Bases:
torch.nn.Module
Multi-head attention wrapper that projects inputs and reprojects the output.
Projects queries, keys, and values to multi-head representations, delegates the actual attention computation to an inner attention module, then reprojects the concatenated heads back to
d_model.
- Parameters:
attention (nn.Module) – Inner attention module (e.g. LinearAttention, SoftmaxAttention).
d_model (int) – Input and output feature dimensionality. Must be divisible by num_heads.
num_heads (int) – Number of attention heads.
d_keys (int, optional) – Per-head key/query dimensionality. Defaults to d_model // num_heads.
d_values (int, optional) – Per-head value dimensionality. Defaults to d_model // num_heads.
- Raises:
ValueError – If
d_model is not divisible by num_heads.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(queries, keys, values, query_mask=None, key_mask=None) → tuple[torch.Tensor, torch.Tensor | None][source]¶
Apply multi-head attention.
- Parameters:
queries (torch.Tensor) – Shape (B, T_1, D).
keys (torch.Tensor) – Shape (B, T_2, D).
values (torch.Tensor) – Shape (B, T_2, D).
query_mask (torch.BoolTensor, optional) – Shape (B, T_1). True = valid.
key_mask (torch.BoolTensor, optional) – Shape (B, T_2). True = valid.
- Returns:
Attended output of shape (B, T_1, D) and optional attention weights.
- Return type:
tuple[torch.Tensor, torch.Tensor | None]
- class linmult.core.attention.BigBirdAttention(num_heads: int, block_size: int, num_global_tokens: int, num_random_tokens: int, dropout: float = 0.0)[source]¶
Bases:
torch.nn.Module
BigBird sparse attention: global + local-block + random tokens.
For self-attention (tgt_len == src_len):
- Global queries (first G positions): full attention over all keys.
- Non-global queries: each block attends to local ∪ global ∪ random keys with a single softmax — matching the BigBird paper's sparse attention.
For cross-attention (tgt_len != src_len): falls back to full softmax attention, as the local-block pattern is undefined across different-length sequences.
Note
Random key indices are sampled without duplicates (torch.randperm), but are not filtered to exclude local-block or global positions. Overlapping positions receive slightly higher attention weight. This is a standard approximation in BigBird implementations with negligible practical impact.
- Parameters:
num_heads (int) – Number of attention heads.
block_size (int) – Size of each local attention block.
num_global_tokens (int) – Number of global tokens (first G positions attend everywhere).
num_random_tokens (int) – Number of randomly sampled key positions per block.
dropout (float) – Dropout probability on attention weights. Defaults to 0.0.
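The local ∪ global ∪ random union described above can be sketched as plain index arithmetic. This is a simplified pure-Python illustration of which key positions one non-global query block sees, not the library's tensor implementation:

```python
import random

def bigbird_key_indices(block_idx: int, seq_len: int, block_size: int,
                        num_global: int, num_random: int,
                        rng: random.Random) -> set[int]:
    """Key positions one non-global query block attends to: global + local + random."""
    global_idx = set(range(num_global))
    start = block_idx * block_size
    local_idx = set(range(start, min(start + block_size, seq_len)))
    # Random keys are drawn without duplicates but not filtered against the
    # local/global sets, mirroring the approximation noted above.
    random_idx = set(rng.sample(range(seq_len), num_random))
    return global_idx | local_idx | random_idx

rng = random.Random(0)
keys = bigbird_key_indices(block_idx=2, seq_len=256, block_size=64,
                           num_global=16, num_random=10, rng=rng)
```

With the defaults above (block_size=64, 16 global, 10 random), each block scores at most 90 keys instead of 256.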
- forward(q, k, v, attn_mask=None, **_kwargs)[source]¶
Compute BigBird sparse attention.
- Parameters:
q (torch.Tensor) – Queries of shape (B, T, H, D).
k (torch.Tensor) – Keys of shape (B, S, H, D).
v (torch.Tensor) – Values of shape (B, S, H, D).
attn_mask (torch.Tensor, optional) – Additive mask of shape (B, 1, T, S). -inf at positions to mask out.
- Returns:
Output of shape (B, T, H, D) and None (no attention weight tensor is returned).
- Return type:
tuple[torch.Tensor, None]
- class linmult.core.attention.SoftmaxAttention(d_model: int, num_heads: int, dropout: float = 0.0)[source]¶
Bases:
torch.nn.Module
Standard scaled dot-product softmax attention with O(N² D) complexity.
Computes:
V’ = dropout(softmax(Q Kᵀ / √d)) V
- Parameters:
d_model (int) – Total model dimensionality.
num_heads (int) – Number of attention heads.
dropout (float) – Dropout probability on attention weights. Defaults to 0.0.
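The formula above can be checked on a tiny example. A single-head, no-dropout sketch in plain Python, independent of the library:

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_attention(Q, K, V):
    """V' = softmax(Q Kᵀ / √d) V for one head, on nested lists."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out, weights = [], []
    for q in Q:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        attn = softmax(scores)  # one row of the T_1 x T_2 weight matrix
        weights.append(attn)
        out.append([sum(a * v[j] for a, v in zip(attn, V))
                    for j in range(len(V[0]))])
    return out, weights

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out, w = softmax_attention(Q, K, V)
```

Every row of the weight matrix sums to 1, and forming that T_1 × T_2 matrix explicitly is exactly the O(N² D) cost the linear variants below avoid.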
- forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, attn_mask: torch.Tensor | None = None, **_kwargs) → tuple[torch.Tensor, torch.Tensor][source]¶
Compute softmax attention.
- Parameters:
queries (torch.Tensor) – Shape (B, T_1, H, D).
keys (torch.Tensor) – Shape (B, T_2, H, D).
values (torch.Tensor) – Shape (B, T_2, H, D).
attn_mask (torch.Tensor, optional) – Additive mask of shape (B, 1, T_1, T_2). -inf at positions to mask out.
- Returns:
Output of shape (B, T_1, H, D) and attention weights of shape (B, H, T_1, T_2).
- Return type:
tuple[torch.Tensor, torch.Tensor]
- class linmult.core.attention.LinearAttention(d_model: int, num_heads: int, feature_map: collections.abc.Callable | None = None)[source]¶
Bases:
torch.nn.Module
Linear-complexity attention via kernel feature maps — O(N D²).
Instead of computing the full N×N softmax attention matrix, uses a feature map Φ(·) to decompose the kernel and compute:
V’ = normalize(Φ(Q) · Φ(K)ᵀ) · V
This allows reordering the computation to avoid materializing the attention matrix, giving O(N D²) cost where D is the feature-map output dimensionality.
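The reordering relies on associativity of matrix multiplication: (Φ(Q)Φ(K)ᵀ)V = Φ(Q)(Φ(K)ᵀV), so the N×N matrix never has to be materialized. A small unnormalized pure-Python check of the identity (illustrative only):

```python
def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

# Feature-mapped queries Phi(Q), keys Phi(K), and values V (tiny example).
PQ = [[1.0, 2.0], [0.5, 1.0], [2.0, 0.0]]  # N x D
PK = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N x D
V  = [[1.0], [2.0], [3.0]]                 # N x 1

quadratic = matmul(matmul(PQ, transpose(PK)), V)  # O(N^2 D): forms the N x N matrix
linear    = matmul(PQ, matmul(transpose(PK), V))  # O(N D^2): D x 1 summary first
```

Both orderings give identical results; only the intermediate sizes (N×N versus D×1) differ.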
Masking is handled by zeroing Q at masked query positions and K at masked key positions — no NaN risk (unlike softmax with all -inf rows).
- Attribution:
Angelos Katharopoulos, Apoorv Vyas — Idiap Research Institute. “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention”, ICML 2020. https://github.com/idiap/fast-transformers
- Parameters:
d_model (int) – Total model dimensionality.
num_heads (int) – Number of attention heads (head_dim = d_model // num_heads).
feature_map (callable, optional) – Factory that takes query_dims and returns a FeatureMap instance. Defaults to EluFeatureMap (elu(x)+1).
- forward(queries, keys, values, query_mask=None, key_mask=None, **_kwargs)[source]¶
Compute linear attention.
- Parameters:
queries (torch.Tensor) – Shape (B, T_1, H, D).
keys (torch.Tensor) – Shape (B, T_2, H, D).
values (torch.Tensor) – Shape (B, T_2, H, D).
query_mask (torch.BoolTensor, optional) – Shape (B, T_1). True = valid.
key_mask (torch.BoolTensor, optional) – Shape (B, T_2). True = valid.
- Returns:
Output of shape (B, T_1, H, D) and None (no attention weight tensor is returned for linear attention).
- Return type:
tuple[torch.Tensor, None]
- class linmult.core.attention.FeatureMap(query_dims: int)[source]¶
Bases:
torch.nn.Module
Abstract base class defining the feature map interface for linear attention.
Subclasses implement Φ(·) such that Φ(Q)ᵀΦ(K) approximates (or equals) the desired attention kernel.
- Parameters:
query_dims (int) – Head dimensionality (d_model // n_heads).
- abstractmethod new_feature_map(device: torch.device) → None[source]¶
Reinitialize (re-sample) the feature map parameters for this forward pass.
Called once per forward pass by LinearAttention. For random feature maps this samples a fresh projection matrix.
- Parameters:
device (torch.device) – The torch device to create tensors on.
- Raises:
NotImplementedError – Must be implemented by subclasses.
- forward_queries(x: torch.Tensor) → torch.Tensor[source]¶
Encode queries using this feature map.
- Parameters:
x (torch.Tensor) – Query tensor of shape (B, T, H, D).
- Returns:
Encoded queries of the same leading shape.
- Return type:
torch.Tensor
- forward_keys(x: torch.Tensor) → torch.Tensor[source]¶
Encode keys using this feature map.
- Parameters:
x (torch.Tensor) – Key tensor of shape (B, T, H, D).
- Returns:
Encoded keys of the same leading shape.
- Return type:
torch.Tensor
- abstractmethod forward(x: torch.Tensor) → torch.Tensor[source]¶
Encode x using this feature map.
For symmetric feature maps it suffices to define this method. For asymmetric maps, override forward_queries and forward_keys separately.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Encoded output.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – Must be implemented by subclasses.
- classmethod factory(*args, **kwargs) → collections.abc.Callable[source]¶
Return a factory callable for constructing this feature map.
The returned callable accepts query_dims and returns an instance of this class. Inherited by all subclasses, enabling use with LinearAttention's feature_map argument.
- Returns:
A factory function query_dims → instance.
Callable[[int], FeatureMap]
- class linmult.core.attention.ActivationFunctionFeatureMap(query_dims: int, activation_function: collections.abc.Callable)[source]¶
Bases:
FeatureMap
Feature map defined by an element-wise activation function.
- Parameters:
query_dims (int) – Head dimensionality.
activation_function (callable) – Applied element-wise to the input tensor.
- class linmult.core.attention.EluFeatureMap(query_dims: int)[source]¶
Bases:
ActivationFunctionFeatureMap
ELU+1 feature map — default for LinearAttention.
Implements Φ(x) = elu(x) + 1, which satisfies Φ(x) ≥ 0 everywhere and yields a valid positive-definite kernel without random projections.
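The positivity claim is easy to verify. A plain-Python version of elu(x) + 1, assuming α = 1 (the default of torch.nn.functional.elu):

```python
import math

def elu_plus_one(x: float) -> float:
    # elu(x) = x if x > 0 else exp(x) - 1  (alpha = 1)
    return (x if x > 0 else math.exp(x) - 1.0) + 1.0

# Strictly positive for every input, since elu(x) > -1:
values = [elu_plus_one(x) for x in (-10.0, -1.0, 0.0, 1.0, 5.0)]
```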
- Parameters:
query_dims (int) – Head dimensionality (d_model // n_heads).
- class linmult.core.attention.PerformerFeatureMap(query_dims: int, num_features: int | None = None)[source]¶
Bases:
FeatureMap
Performer (FAVOR+) positive random feature map (Choromanski et al., ICLR 2021).
Provides an unbiased estimator of the softmax attention kernel using orthogonal random features. Unlike elu+1, uses r >> head_dim features, directly addressing the capacity limitation of small head dimensions:
Φ(x)ᵢ = exp(x·ωᵢ − ‖x‖²/2) / √r for r orthogonal vectors ωᵢ
E[Φ(x)ᵀΦ(y)] ≈ exp(xᵀy) — unbiased estimator of the softmax kernel.
new_feature_map() resamples the projection each forward pass, reducing variance across training steps.
Select via config: attention_type: performer
Tune via config: performer_num_random_features: 64 (default: max(head_dim*4, 32))
- Parameters:
query_dims (int) – Head dimensionality (d_model // n_heads).
num_features (int, optional) – Number of random features r. Defaults to max(query_dims * 4, 32).
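The unbiasedness E[Φ(x)ᵀΦ(y)] = exp(x·y) can be checked with a small Monte-Carlo sketch. This draws i.i.d. Gaussian ω rather than the orthogonal features the class uses (orthogonality only reduces variance), so it is an illustration of the estimator, not the implementation:

```python
import math
import random

rng = random.Random(42)
x = [0.3, -0.2]
y = [0.1, 0.4]

def phi(v, omegas):
    # Phi(v)_i = exp(v . omega_i - ||v||^2 / 2) / sqrt(r)
    r = len(omegas)
    half_sq = sum(c * c for c in v) / 2.0
    return [math.exp(sum(a * b for a, b in zip(v, w)) - half_sq) / math.sqrt(r)
            for w in omegas]

r = 20000
omegas = [[rng.gauss(0.0, 1.0) for _ in x] for _ in range(r)]
estimate = sum(a * b for a, b in zip(phi(x, omegas), phi(y, omegas)))
exact = math.exp(sum(a * b for a, b in zip(x, y)))  # softmax kernel exp(x . y)
```

The estimate converges to exp(x·y) as r grows, and every Φ entry is positive, which is what keeps the attention normalization stable.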
- class linmult.core.attention.GatedAttentionUnit(d_model: int, query_key_dim: int | None = None, dropout: float = 0.0)[source]¶
Bases:
torch.nn.Module
flash (GAU) — Hua et al., ICML 2022.
Replaces multi-head attention with single-head gated linear attention:
u = SiLU(W_u · queries)   # gate from query stream
v = W_v · values          # values from key/value stream
q = relu(W_q · queries)²  # scoring query (always ≥ 0)
k = relu(W_k · keys)²     # scoring key (always ≥ 0)
a = linear_attn(q, k, v)  # O(N·s) single-head attention
output = W_o · (u ⊙ a)    # gated output
relu² ensures k·q ≥ 0 everywhere, keeping the linear attention denominator positive without a learned feature map. Supports cross-attention: gate and scoring query come from the query (target) stream; scoring key and value come from the key/value (source) stream.
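The relu² property can be illustrated directly: every scoring entry is a squared relu output, so each q·k dot product is nonnegative and the linear-attention denominator Σ q·k stays positive whenever any key overlaps the query's active dimensions. A sketch with hypothetical values, not the library code:

```python
def relu_squared(v):
    # relu(x)^2 elementwise: every entry ends up >= 0
    return [max(c, 0.0) ** 2 for c in v]

q = relu_squared([-1.5, 0.3, 2.0])  # negative entries become exactly 0
keys = [relu_squared(k) for k in ([0.5, -2.0, 1.0], [-0.1, 0.4, -0.3])]

scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
denominator = sum(scores)  # nonnegative by construction, no learned feature map needed
```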
The forward interface matches AttentionLayer, so it is a drop-in replacement in TransformerEncoderLayer without any changes to transformer.py.
Select via config: attention_type: flash
Tune via config: flash_query_key_dim: 32 (default: max(d_model // 2, 16)), dropout_attention: 0.1
- Parameters:
d_model (int) – Input and output feature dimensionality.
query_key_dim (int, optional) – Scoring dimension s. Defaults to max(d_model // 2, 16).
dropout (float) – Dropout on the gated pre-projection tensor u ⊙ a. Defaults to 0.0.
- forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None, **_kwargs) → tuple[torch.Tensor, None][source]¶
Compute gated linear attention.
- Parameters:
queries (torch.Tensor) – Shape (B, T_q, d_model).
keys (torch.Tensor) – Shape (B, T_k, d_model).
values (torch.Tensor) – Shape (B, T_k, d_model).
query_mask (torch.BoolTensor, optional) – Shape (B, T_q). True = valid. Masked output positions are set to zero.
key_mask (torch.BoolTensor, optional) – Shape (B, T_k). True = valid. Masked key/value positions are zeroed before accumulation.
- Returns:
Output of shape (B, T_q, d_model) and None (no attention weight tensor is returned).
- Return type:
tuple[torch.Tensor, None]