linmult.core.attention

Attention mechanisms: linear, softmax, BigBird, Performer (FAVOR+), and GAU (flash).

Classes

AttentionConfig

Attention mechanism selection and its type-specific hyperparameters.

AttentionFactory

Factory for creating attention layers from an AttentionConfig.

AttentionLayer

Multi-head attention wrapper that projects inputs and reprojects the output.

BigBirdAttention

BigBird sparse attention: global + local-block + random tokens.

SoftmaxAttention

Standard scaled dot-product softmax attention with O(N² D) complexity.

LinearAttention

Linear-complexity attention via kernel feature maps — O(N D²).

FeatureMap

Abstract base class defining the feature map interface for linear attention.

ActivationFunctionFeatureMap

Feature map defined by an element-wise activation function.

EluFeatureMap

ELU+1 feature map — default for LinearAttention.

PerformerFeatureMap

Performer (FAVOR+) positive random feature map (Choromanski et al., ICLR 2021).

GatedAttentionUnit

Gated Attention Unit (GAU) from FLASH (Hua et al., ICML 2022).

Module Contents

class linmult.core.attention.AttentionConfig[source]

Attention mechanism selection and its type-specific hyperparameters.

This is an internal construction spec — created from user-facing config fields in LinT or LinMulT, then passed directly to TransformerEncoder and TAM. Each field is only relevant when type matches.

Parameters:
  • type (str) – Attention mechanism. One of "linear" (default), "performer", "flash", "softmax", "bigbird", "mha".

  • dropout (float) – Dropout probability on attention weights. Defaults to 0.0.

  • flash_query_key_dim (int | None) – Scoring dimension for "flash" (GAU). Defaults to None (computed as max(d_model // 2, 16)).

  • performer_num_random_features (int | None) – Random feature count for "performer". Defaults to None (computed as max(head_dim * 4, 32)).

  • bigbird_block_size (int) – Local block size for "bigbird". Defaults to 64.

  • bigbird_num_global_tokens (int) – Global tokens for "bigbird". Defaults to 16.

  • bigbird_num_random_tokens (int) – Random tokens for "bigbird". Defaults to 10.
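
For orientation, a hypothetical user-facing config fragment that would feed this spec (field names taken from the "Select via config" / "Tune via config" notes elsewhere in this module; values illustrative):

```yaml
# Only the fields matching the selected attention type are read.
attention_type: performer
performer_num_random_features: 64
dropout_attention: 0.1
```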

class linmult.core.attention.AttentionFactory[source]

Factory for creating attention layers from an AttentionConfig.

static create(d_model: int, num_heads: int, attention_config: AttentionConfig | None = None) torch.nn.Module[source]

Create and return an attention layer.

Parameters:
  • d_model (int) – Input feature dimensionality.

  • num_heads (int) – Number of attention heads.

  • attention_config (AttentionConfig, optional) – Attention configuration. Defaults to AttentionConfig() (linear attention).

Returns:

An attention module. "mha" returns nn.MultiheadAttention; "flash" (GAU) returns GatedAttentionUnit; all others return AttentionLayer.

Return type:

nn.Module

Raises:

ValueError – If attention_config.type is not one of the supported values.

class linmult.core.attention.AttentionLayer(attention: torch.nn.Module, d_model: int, num_heads: int, d_keys: int | None = None, d_values: int | None = None)[source]

Bases: torch.nn.Module

Multi-head attention wrapper that projects inputs and reprojects the output.

Projects queries, keys, and values to multi-head representations, delegates the actual attention computation to an inner attention module, then reprojects the concatenated heads back to d_model.

Parameters:
  • attention (nn.Module) – Inner attention module (e.g. LinearAttention, SoftmaxAttention).

  • d_model (int) – Input and output feature dimensionality. Must be divisible by num_heads.

  • num_heads (int) – Number of attention heads.

  • d_keys (int, optional) – Per-head key/query dimensionality. Defaults to d_model // num_heads.

  • d_values (int, optional) – Per-head value dimensionality. Defaults to d_model // num_heads.

Raises:

ValueError – If d_model is not divisible by num_heads.

forward(queries, keys, values, query_mask=None, key_mask=None) tuple[torch.Tensor, torch.Tensor | None][source]

Apply multi-head attention.

Parameters:
  • queries (torch.Tensor) – Shape (B, T_1, D).

  • keys (torch.Tensor) – Shape (B, T_2, D).

  • values (torch.Tensor) – Shape (B, T_2, D).

  • query_mask (torch.BoolTensor, optional) – Shape (B, T_1). True = valid.

  • key_mask (torch.BoolTensor, optional) – Shape (B, T_2). True = valid.

Returns:

Attended output of shape (B, T_1, D) and optional attention weights.

Return type:

tuple[torch.Tensor, torch.Tensor | None]

class linmult.core.attention.BigBirdAttention(num_heads: int, block_size: int, num_global_tokens: int, num_random_tokens: int, dropout: float = 0.0)[source]

Bases: torch.nn.Module

BigBird sparse attention: global + local-block + random tokens.

For self-attention (tgt_len == src_len):

  • Global queries (first G positions): full attention over all keys.

  • Non-global queries: each block attends to local + global + random keys with a single softmax, matching the BigBird paper's sparse attention pattern.

For cross-attention (tgt_len != src_len): falls back to full softmax attention, as the local-block pattern is undefined across different-length sequences.

Note

Random key indices are sampled without duplicates (torch.randperm), but are not filtered to exclude local-block or global positions. Overlapping positions receive slightly higher attention weight. This is a standard approximation in BigBird implementations with negligible practical impact.

Parameters:
  • num_heads (int) – Number of attention heads.

  • block_size (int) – Size of each local attention block.

  • num_global_tokens (int) – Number of global tokens (first G positions attend everywhere).

  • num_random_tokens (int) – Number of randomly sampled key positions per block.

  • dropout (float) – Dropout probability on attention weights. Defaults to 0.0.

forward(q, k, v, attn_mask=None, **_kwargs)[source]

Compute BigBird sparse attention.

Parameters:
  • q (torch.Tensor) – Queries of shape (B, T, H, D).

  • k (torch.Tensor) – Keys of shape (B, S, H, D).

  • v (torch.Tensor) – Values of shape (B, S, H, D).

  • attn_mask (torch.Tensor, optional) – Additive mask of shape (B, 1, T, S). -inf at positions to mask out.

Returns:

Output of shape (B, T, H, D) and None (no attention weight tensor is returned).

Return type:

tuple[torch.Tensor, None]
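
The sparse key set for one non-global query block can be sketched in plain Python (an illustrative reconstruction, not the library's implementation; per the note above, random indices are not filtered against local or global positions):

```python
import random

def bigbird_key_indices(block_idx, seq_len, block_size=64,
                        num_global=16, num_random=10, rng=None):
    """Key positions attended by one non-global query block (illustrative)."""
    rng = rng or random.Random(0)
    # Global keys: the first `num_global` positions.
    global_keys = list(range(min(num_global, seq_len)))
    # Local keys: the query block's own window.
    start = block_idx * block_size
    local_keys = list(range(start, min(start + block_size, seq_len)))
    # Random keys: unique among themselves, but may overlap global/local.
    random_keys = rng.sample(range(seq_len), min(num_random, seq_len))
    return global_keys, local_keys, random_keys

g, l, r = bigbird_key_indices(block_idx=3, seq_len=512)
# Each block scores at most num_global + block_size + num_random keys
# instead of all 512 -- the source of BigBird's sparsity.
```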

class linmult.core.attention.SoftmaxAttention(d_model: int, num_heads: int, dropout: float = 0.0)[source]

Bases: torch.nn.Module

Standard scaled dot-product softmax attention with O(N² D) complexity.

Computes:

V’ = dropout(softmax(Q Kᵀ / √d)) V

Parameters:
  • d_model (int) – Total model dimensionality.

  • num_heads (int) – Number of attention heads.

  • dropout (float) – Dropout probability on attention weights. Defaults to 0.0.

forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, attn_mask: torch.Tensor | None = None, **_kwargs) tuple[torch.Tensor, torch.Tensor][source]

Compute softmax attention.

Parameters:
  • queries (torch.Tensor) – Shape (B, T_1, H, D).

  • keys (torch.Tensor) – Shape (B, T_2, H, D).

  • values (torch.Tensor) – Shape (B, T_2, H, D).

  • attn_mask (torch.Tensor, optional) – Additive mask of shape (B, 1, T_1, T_2). -inf at positions to mask out.

Returns:

Output of shape (B, T_1, H, D) and attention weights of shape (B, H, T_1, T_2).

Return type:

tuple[torch.Tensor, torch.Tensor]
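
The formula above can be checked with a minimal NumPy sketch (single head, no dropout or mask; shapes simplified to (T, D)):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_attention(Q, K, V):
    """V' = softmax(Q K^T / sqrt(d)) V for a single head."""
    d = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d))  # (T1, T2) attention weights
    return A @ V, A

rng = np.random.default_rng(0)
Q = rng.normal(size=(5, 8))
K = rng.normal(size=(7, 8))
V = rng.normal(size=(7, 8))
out, A = softmax_attention(Q, K, V)  # out: (5, 8), A rows sum to 1
```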

class linmult.core.attention.LinearAttention(d_model: int, num_heads: int, feature_map: collections.abc.Callable | None = None)[source]

Bases: torch.nn.Module

Linear-complexity attention via kernel feature maps — O(N D²).

Instead of computing the full N×N softmax attention matrix, uses a feature map Φ(·) to decompose the kernel and compute:

V’ = normalize(Φ(Q) · Φ(K)ᵀ) · V

This allows reordering the computation to avoid materializing the attention matrix, giving O(N D²) cost where D is the feature-map output dimensionality.
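
The reordering is just associativity: Φ(Q)(Φ(K)ᵀV) equals (Φ(Q)Φ(K)ᵀ)V but only ever forms a D×D summary. A minimal single-head NumPy sketch with the ELU+1 map (unnormalized; the module additionally divides by the normalizer Φ(Q)·ΣΦ(K)):

```python
import numpy as np

def elu_plus_one(x):
    # elu(x) + 1: x + 1 for x > 0, exp(x) otherwise -- always positive
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
N, D = 100, 8
Q, K, V = (rng.normal(size=(N, D)) for _ in range(3))
phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)

# Quadratic order: materializes the N x N matrix -- O(N^2 D)
quadratic = (phi_q @ phi_k.T) @ V
# Linear order: D x D summary first -- O(N D^2), no N x N matrix
linear = phi_q @ (phi_k.T @ V)
```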

Masking is handled by zeroing Q at masked query positions and K at masked key positions, so there is no NaN risk (unlike softmax attention, where a fully masked row of -inf scores yields NaN).

Attribution:

Angelos Katharopoulos, Apoorv Vyas — Idiap Research Institute. “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention”, ICML 2020. https://github.com/idiap/fast-transformers

Parameters:
  • d_model (int) – Total model dimensionality.

  • num_heads (int) – Number of attention heads (head_dim = d_model // num_heads).

  • feature_map (callable, optional) – Factory that takes query_dims and returns a FeatureMap instance. Defaults to EluFeatureMap (elu(x)+1).

forward(queries, keys, values, query_mask=None, key_mask=None, **_kwargs)[source]

Compute linear attention.

Parameters:
  • queries (torch.Tensor) – Shape (B, T_1, H, D).

  • keys (torch.Tensor) – Shape (B, T_2, H, D).

  • values (torch.Tensor) – Shape (B, T_2, H, D).

  • query_mask (torch.BoolTensor, optional) – Shape (B, T_1). True = valid.

  • key_mask (torch.BoolTensor, optional) – Shape (B, T_2). True = valid.

Returns:

Output of shape (B, T_1, H, D) and None (no attention weight tensor is returned for linear attention).

Return type:

tuple[torch.Tensor, None]

class linmult.core.attention.FeatureMap(query_dims: int)[source]

Bases: torch.nn.Module

Abstract base class defining the feature map interface for linear attention.

Subclasses implement Φ(·) such that Φ(Q)ᵀΦ(K) approximates (or equals) the desired attention kernel.

Parameters:

query_dims (int) – Head dimensionality (d_model // n_heads).

abstractmethod new_feature_map(device: torch.device) None[source]

Reinitialize (re-sample) the feature map parameters for this forward pass.

Called once per forward pass by LinearAttention. For random feature maps this samples a fresh projection matrix.

Parameters:

device (torch.device) – The torch device to create tensors on.

Raises:

NotImplementedError – Must be implemented by subclasses.

forward_queries(x: torch.Tensor) torch.Tensor[source]

Encode queries using this feature map.

Parameters:

x (torch.Tensor) – Query tensor of shape (B, T, H, D).

Returns:

Encoded queries of the same leading shape.

Return type:

torch.Tensor

forward_keys(x: torch.Tensor) torch.Tensor[source]

Encode keys using this feature map.

Parameters:

x (torch.Tensor) – Key tensor of shape (B, T, H, D).

Returns:

Encoded keys of the same leading shape.

Return type:

torch.Tensor

abstractmethod forward(x: torch.Tensor) torch.Tensor[source]

Encode x using this feature map.

For symmetric feature maps it suffices to define this method. For asymmetric maps, override forward_queries and forward_keys separately.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Encoded output.

Return type:

torch.Tensor

Raises:

NotImplementedError – Must be implemented by subclasses.

classmethod factory(*args, **kwargs) collections.abc.Callable[source]

Return a factory callable for constructing this feature map.

The returned callable accepts query_dims and returns an instance of this class. Inherited by all subclasses, enabling use with LinearAttention’s feature_map argument.

Returns:

A factory function mapping query_dims to a FeatureMap instance.

Return type:

Callable[[int], FeatureMap]
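
The pattern can be sketched in plain Python (a simplified stand-in, not the library's code — the real classes subclass nn.Module): factory(*args, **kwargs) pre-binds everything except query_dims, which LinearAttention supplies later with the head dimensionality.

```python
class FeatureMap:
    def __init__(self, query_dims):
        self.query_dims = query_dims

    @classmethod
    def factory(cls, *args, **kwargs):
        # Returns Callable[[int], FeatureMap]: query_dims is filled in later.
        return lambda query_dims: cls(query_dims, *args, **kwargs)

class ActivationFunctionFeatureMap(FeatureMap):
    def __init__(self, query_dims, activation_function):
        super().__init__(query_dims)
        self.activation_function = activation_function

# Pre-bind the activation now; LinearAttention calls make_map(head_dim) later.
make_map = ActivationFunctionFeatureMap.factory(abs)
fmap = make_map(16)
```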

class linmult.core.attention.ActivationFunctionFeatureMap(query_dims: int, activation_function: collections.abc.Callable)[source]

Bases: FeatureMap

Feature map defined by an element-wise activation function.

Parameters:
  • query_dims (int) – Head dimensionality.

  • activation_function (callable) – Applied element-wise to the input tensor.

new_feature_map(device: torch.device) None[source]

No-op: activation-based feature maps have no random parameters.

forward(x: torch.Tensor) torch.Tensor[source]

Apply the activation function element-wise.

Parameters:

x (torch.Tensor) – Input tensor of any shape.

Returns:

Activated tensor of the same shape.

Return type:

torch.Tensor

class linmult.core.attention.EluFeatureMap(query_dims: int)[source]

Bases: ActivationFunctionFeatureMap

ELU+1 feature map — default for LinearAttention.

Implements Φ(x) = elu(x) + 1, which satisfies Φ(x) ≥ 0 everywhere and yields a valid positive-definite kernel without random projections.

Parameters:

query_dims (int) – Head dimensionality (d_model // n_heads).
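
The positivity claim is easy to check numerically (a minimal NumPy sketch of Φ(x) = elu(x) + 1):

```python
import numpy as np

def elu_plus_one(x):
    # elu(x) = x for x > 0, exp(x) - 1 otherwise; adding 1 keeps Phi(x) > 0
    return np.where(x > 0, x + 1.0, np.exp(x))

x = np.linspace(-10, 10, 1001)
phi = elu_plus_one(x)
# Phi is strictly positive and monotonically increasing, so Phi(q) . Phi(k) > 0:
# the linear-attention denominator can never vanish.
```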

class linmult.core.attention.PerformerFeatureMap(query_dims: int, num_features: int | None = None)[source]

Bases: FeatureMap

Performer (FAVOR+) positive random feature map (Choromanski et al., ICLR 2021).

Provides an unbiased estimator of the softmax attention kernel using orthogonal random features. Unlike ELU+1, it uses r ≫ head_dim random features, directly addressing the capacity limitation of small head dimensions:

Φ(x)ᵢ = exp(x·ωᵢ − ‖x‖²/2) / √r for r orthogonal vectors ωᵢ

E[Φ(x)ᵀΦ(y)] = exp(xᵀy), an unbiased estimator of the softmax kernel. new_feature_map() resamples the projection each forward pass, reducing variance across training steps.

Select via config: attention_type: performer

Tune via config: performer_num_random_features: 64 (default: max(head_dim*4, 32))

Parameters:
  • query_dims (int) – Head dimensionality (d_model // n_heads).

  • num_features (int, optional) – Number of random features r. Defaults to max(query_dims * 4, 32).
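
The unbiasedness can be demonstrated with a self-contained NumPy sketch (using i.i.d. Gaussian ω rather than the orthogonal features the class actually samples, which changes only the variance of the estimate, not its mean):

```python
import numpy as np

def favor_features(x, omega):
    """Phi(x)_i = exp(x . omega_i - ||x||^2 / 2) / sqrt(r)."""
    r = omega.shape[0]
    return np.exp(x @ omega.T - 0.5 * np.sum(x**2)) / np.sqrt(r)

rng = np.random.default_rng(0)
d, r = 4, 50_000
q = 0.3 * rng.normal(size=d)
k = 0.3 * rng.normal(size=d)
omega = rng.normal(size=(r, d))          # i.i.d. Gaussian projections

estimate = favor_features(q, omega) @ favor_features(k, omega)
exact = np.exp(q @ k)                    # the softmax kernel value
# estimate converges to exact as r grows; every feature is positive.
```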

new_feature_map(device: torch.device) None[source]

Sample a new orthogonal random projection matrix.

Parameters:

device (torch.device) – Target device for the projection tensor.

forward(x: torch.Tensor) torch.Tensor[source]

Apply the FAVOR+ feature map.

Parameters:

x (torch.Tensor) – Input of shape (..., query_dims).

Returns:

Positive random features of shape (..., num_features).

Return type:

torch.Tensor

class linmult.core.attention.GatedAttentionUnit(d_model: int, query_key_dim: int | None = None, dropout: float = 0.0)[source]

Bases: torch.nn.Module

Gated Attention Unit (GAU) from FLASH (Hua et al., ICML 2022).

Replaces multi-head attention with single-head gated linear attention:

u = SiLU(W_u · queries)       # gate from query stream
v = W_v · values              # values from key/value stream
q = relu(W_q · queries)²      # scoring query (always ≥ 0)
k = relu(W_k · keys)²         # scoring key (always ≥ 0)
a = linear_attn(q, k, v)      # O(N·s) single-head attention
output = W_o · (u ⊙ a)        # gated output

relu² ensures k·q ≥ 0 everywhere, keeping the linear attention denominator positive without a learned feature map. Supports cross-attention: gate and scoring query come from the query (target) stream; scoring key and value come from the key/value (source) stream.

The forward interface matches AttentionLayer, so it is a drop-in replacement in TransformerEncoderLayer without any changes to transformer.py.
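
The data flow above can be sketched end-to-end in NumPy (single example, random weights scaled for stability, masks omitted; an illustration of the computation, not the module's code):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
T_q, T_k, d_model, s = 6, 9, 32, 16      # s = query_key_dim
queries = rng.normal(size=(T_q, d_model))
keys    = rng.normal(size=(T_k, d_model))
values  = rng.normal(size=(T_k, d_model))
scale = 1.0 / np.sqrt(d_model)           # keep pre-activations moderate
W_u = rng.normal(size=(d_model, d_model)) * scale
W_v = rng.normal(size=(d_model, d_model)) * scale
W_q = rng.normal(size=(d_model, s)) * scale
W_k = rng.normal(size=(d_model, s)) * scale
W_o = rng.normal(size=(d_model, d_model)) * scale

u = silu(queries @ W_u)                  # gate from the query stream
v = values @ W_v                         # values from the key/value stream
q = relu(queries @ W_q) ** 2             # scoring query, >= 0
k = relu(keys @ W_k) ** 2                # scoring key, >= 0
num = q @ (k.T @ v)                      # linear attention: never T_q x T_k
den = q @ k.sum(axis=0)                  # positive denominator (relu^2)
a = num / den[:, None]
output = (u * a) @ W_o                   # gated, reprojected output
```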

Select via config: attention_type: flash

Tune via config: flash_query_key_dim: 32 (default: max(d_model // 2, 16))

dropout_attention: 0.1

Parameters:
  • d_model (int) – Input and output feature dimensionality.

  • query_key_dim (int, optional) – Scoring dimension s. Defaults to max(d_model // 2, 16).

  • dropout (float) – Dropout on the gated pre-projection tensor u ⊙ a. Defaults to 0.0.

forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None, **_kwargs) tuple[torch.Tensor, None][source]

Compute gated linear attention.

Parameters:
  • queries (torch.Tensor) – Shape (B, T_q, d_model).

  • keys (torch.Tensor) – Shape (B, T_k, d_model).

  • values (torch.Tensor) – Shape (B, T_k, d_model).

  • query_mask (torch.BoolTensor, optional) – Shape (B, T_q). True = valid. Masked output positions are set to zero.

  • key_mask (torch.BoolTensor, optional) – Shape (B, T_k). True = valid. Masked key/value positions are zeroed before accumulation.

Returns:

Output of shape (B, T_q, d_model) and None (no attention weight tensor is returned).

Return type:

tuple[torch.Tensor, None]