linmult.core.attention ====================== .. py:module:: linmult.core.attention .. autoapi-nested-parse:: Attention mechanisms: linear, softmax, BigBird, Performer (FAVOR+), and GAU (flash). Classes ------- .. autoapisummary:: linmult.core.attention.AttentionConfig linmult.core.attention.AttentionFactory linmult.core.attention.AttentionLayer linmult.core.attention.BigBirdAttention linmult.core.attention.SoftmaxAttention linmult.core.attention.LinearAttention linmult.core.attention.FeatureMap linmult.core.attention.ActivationFunctionFeatureMap linmult.core.attention.EluFeatureMap linmult.core.attention.PerformerFeatureMap linmult.core.attention.GatedAttentionUnit Module Contents --------------- .. py:class:: AttentionConfig Attention mechanism selection and its type-specific hyperparameters. This is an internal construction spec — created from user-facing config fields in ``LinT`` or ``LinMulT``, then passed directly to :class:`TransformerEncoder` and :class:`TAM`. Each field is only relevant when ``type`` matches. :param type: Attention mechanism. One of ``"linear"`` (default), ``"performer"``, ``"flash"``, ``"softmax"``, ``"bigbird"``, ``"mha"``. :type type: str :param dropout: Dropout probability on attention weights. Defaults to ``0.0``. :type dropout: float :param flash_query_key_dim: Scoring dimension for ``"flash"`` (GAU). Defaults to ``None`` (computed as ``max(d_model // 2, 16)``). :type flash_query_key_dim: int | None :param performer_num_random_features: Random feature count for ``"performer"``. Defaults to ``None`` (computed as ``max(head_dim * 4, 32)``). :type performer_num_random_features: int | None :param bigbird_block_size: Local block size for ``"bigbird"``. Defaults to ``64``. :type bigbird_block_size: int :param bigbird_num_global_tokens: Global tokens for ``"bigbird"``. Defaults to ``16``. :type bigbird_num_global_tokens: int :param bigbird_num_random_tokens: Random tokens for ``"bigbird"``. Defaults to ``10``. 
:type bigbird_num_random_tokens: int .. py:class:: AttentionFactory Factory for creating attention layers from an :class:`AttentionConfig`. .. py:method:: create(d_model: int, num_heads: int, attention_config: AttentionConfig | None = None) -> torch.nn.Module :staticmethod: Create and return an attention layer. :param d_model: Input feature dimensionality. :type d_model: int :param num_heads: Number of attention heads. :type num_heads: int :param attention_config: Attention configuration. Defaults to ``AttentionConfig()`` (linear attention). :type attention_config: AttentionConfig, optional :returns: An attention module. ``"mha"`` returns ``nn.MultiheadAttention``; ``"flash"`` (GAU) returns :class:`GatedAttentionUnit`; all others return :class:`AttentionLayer`. :rtype: nn.Module :raises ValueError: If ``attention_config.type`` is not one of the supported values. .. py:class:: AttentionLayer(attention: torch.nn.Module, d_model: int, num_heads: int, d_keys: int | None = None, d_values: int | None = None) Bases: :py:obj:`torch.nn.Module` Multi-head attention wrapper that projects inputs and reprojects the output. Projects queries, keys, and values to multi-head representations, delegates the actual attention computation to an inner attention module, then reprojects the concatenated heads back to ``d_model``. :param attention: Inner attention module (e.g. ``LinearAttention``, ``SoftmaxAttention``). :type attention: nn.Module :param d_model: Input and output feature dimensionality. Must be divisible by ``num_heads``. :type d_model: int :param num_heads: Number of attention heads. :type num_heads: int :param d_keys: Per-head key/query dimensionality. Defaults to ``d_model // num_heads``. :type d_keys: int, optional :param d_values: Per-head value dimensionality. Defaults to ``d_model // num_heads``. :type d_values: int, optional :raises ValueError: If ``d_model`` is not divisible by ``num_heads``. 
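The head split/merge bookkeeping that ``AttentionLayer`` performs around its inner attention module can be sketched as follows. This is an illustrative NumPy sketch of the shape flow only, not the module's actual torch code; the weight here is a random placeholder:

```python
import numpy as np

# Illustrative sketch of AttentionLayer's shape bookkeeping: project the
# input, split into heads for the inner attention, then merge heads back.
B, T, d_model, num_heads = 2, 5, 16, 4
d_keys = d_model // num_heads  # per-head dimensionality (the default)

rng = np.random.default_rng(0)
x = rng.standard_normal((B, T, d_model))
W_q = rng.standard_normal((d_model, num_heads * d_keys))  # stand-in projection

# (B, T, D) -> (B, T, H, D/H): split the projected features into heads,
# the layout the inner attention modules below expect.
q = (x @ W_q).reshape(B, T, num_heads, d_keys)
assert q.shape == (2, 5, 4, 4)

# The inner attention returns (B, T, H, D/H); heads are concatenated back
# to (B, T, D) before the output projection maps to d_model.
out = q.reshape(B, T, num_heads * d_keys)
assert out.shape == (2, 5, 16)
```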
Initialize internal Module state, shared by both nn.Module and ScriptModule. .. py:method:: forward(queries, keys, values, query_mask=None, key_mask=None) -> tuple[torch.Tensor, torch.Tensor | None] Apply multi-head attention. :param queries: Shape ``(B, T_1, D)``. :type queries: torch.Tensor :param keys: Shape ``(B, T_2, D)``. :type keys: torch.Tensor :param values: Shape ``(B, T_2, D)``. :type values: torch.Tensor :param query_mask: Shape ``(B, T_1)``. True = valid. :type query_mask: torch.BoolTensor, optional :param key_mask: Shape ``(B, T_2)``. True = valid. :type key_mask: torch.BoolTensor, optional :returns: Attended output of shape ``(B, T_1, D)`` and optional attention weights. :rtype: tuple[torch.Tensor, torch.Tensor | None] .. py:class:: BigBirdAttention(num_heads: int, block_size: int, num_global_tokens: int, num_random_tokens: int, dropout: float = 0.0) Bases: :py:obj:`torch.nn.Module` BigBird sparse attention: global + local-block + random tokens. For self-attention (``tgt_len == src_len``): - Global queries (first G positions): full attention over all keys. - Non-global queries: each block attends to ``local ∪ global ∪ random`` keys with a single softmax — matching the BigBird paper's sparse attention. For cross-attention (``tgt_len != src_len``): falls back to full softmax attention, as the local-block pattern is undefined across different-length sequences. .. note:: Random key indices are sampled without duplicates (``torch.randperm``), but are not filtered to exclude local-block or global positions. Overlapping positions receive slightly higher attention weight. This is a standard approximation in BigBird implementations with negligible practical impact. :param num_heads: Number of attention heads. :type num_heads: int :param block_size: Size of each local attention block. :type block_size: int :param num_global_tokens: Number of global tokens (first G positions attend everywhere). 
:type num_global_tokens: int :param num_random_tokens: Number of randomly sampled key positions per block. :type num_random_tokens: int :param dropout: Dropout probability on attention weights. Defaults to ``0.0``. :type dropout: float .. py:method:: forward(q, k, v, attn_mask=None, **_kwargs) Compute BigBird sparse attention. :param q: Queries of shape ``(B, T, H, D)``. :type q: torch.Tensor :param k: Keys of shape ``(B, S, H, D)``. :type k: torch.Tensor :param v: Values of shape ``(B, S, H, D)``. :type v: torch.Tensor :param attn_mask: Additive mask of shape ``(B, 1, T, S)``. ``-inf`` at positions to mask out. :type attn_mask: torch.Tensor, optional :returns: Output of shape ``(B, T, H, D)`` and ``None`` (no attention weight tensor is returned). :rtype: tuple[torch.Tensor, None] .. py:class:: SoftmaxAttention(d_model: int, num_heads: int, dropout: float = 0.0) Bases: :py:obj:`torch.nn.Module` Standard scaled dot-product softmax attention with O(N² D) complexity. Computes: V' = dropout(softmax(Q Kᵀ / √d)) V :param d_model: Total model dimensionality. :type d_model: int :param num_heads: Number of attention heads. :type num_heads: int :param dropout: Dropout probability on attention weights. Defaults to ``0.0``. :type dropout: float .. py:method:: forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, attn_mask: torch.Tensor | None = None, **_kwargs) -> tuple[torch.Tensor, torch.Tensor] Compute softmax attention. :param queries: Shape ``(B, T_1, H, D)``. :type queries: torch.Tensor :param keys: Shape ``(B, T_2, H, D)``. :type keys: torch.Tensor :param values: Shape ``(B, T_2, H, D)``. :type values: torch.Tensor :param attn_mask: Additive mask of shape ``(B, 1, T_1, T_2)``. ``-inf`` at positions to mask out.
:type attn_mask: torch.Tensor, optional :returns: Output of shape ``(B, T_1, H, D)`` and attention weights of shape ``(B, H, T_1, T_2)``. :rtype: tuple[torch.Tensor, torch.Tensor] .. py:class:: LinearAttention(d_model: int, num_heads: int, feature_map: collections.abc.Callable | None = None) Bases: :py:obj:`torch.nn.Module` Linear-complexity attention via kernel feature maps — O(N D²). Instead of computing the full N×N softmax attention matrix, uses a feature map Φ(·) to decompose the kernel and compute: V' = normalize(Φ(Q) · Φ(K)ᵀ) · V This allows reordering the computation to avoid materializing the attention matrix, giving O(N D²) cost where D is the feature-map output dimensionality. Masking is handled by zeroing Q at masked query positions and K at masked key positions — no NaN risk (unlike softmax with all-``-inf`` rows). Attribution: Angelos Katharopoulos, Apoorv Vyas — Idiap Research Institute. "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention", ICML 2020. https://github.com/idiap/fast-transformers :param d_model: Total model dimensionality. :type d_model: int :param num_heads: Number of attention heads (``head_dim = d_model // num_heads``). :type num_heads: int :param feature_map: Factory that takes ``query_dims`` and returns a ``FeatureMap`` instance. Defaults to ``EluFeatureMap`` (elu(x)+1). :type feature_map: callable, optional .. py:method:: forward(queries, keys, values, query_mask=None, key_mask=None, **_kwargs) Compute linear attention. :param queries: Shape ``(B, T_1, H, D)``. :type queries: torch.Tensor :param keys: Shape ``(B, T_2, H, D)``. :type keys: torch.Tensor :param values: Shape ``(B, T_2, H, D)``. :type values: torch.Tensor :param query_mask: Shape ``(B, T_1)``. True = valid. :type query_mask: torch.BoolTensor, optional :param key_mask: Shape ``(B, T_2)``. True = valid.
:type key_mask: torch.BoolTensor, optional :returns: Output of shape ``(B, T_1, H, D)`` and ``None`` (no attention weight tensor is returned for linear attention). :rtype: tuple[torch.Tensor, None] .. py:class:: FeatureMap(query_dims: int) Bases: :py:obj:`torch.nn.Module` Abstract base class defining the feature map interface for linear attention. Subclasses implement Φ(·) such that Φ(Q)ᵀΦ(K) approximates (or equals) the desired attention kernel. :param query_dims: Head dimensionality (``d_model // n_heads``). :type query_dims: int .. py:method:: new_feature_map(device: torch.device) -> None :abstractmethod: Reinitialize (re-sample) the feature map parameters for this forward pass. Called once per forward pass by ``LinearAttention``. For random feature maps this samples a fresh projection matrix. :param device: The torch device to create tensors on. :type device: torch.device :raises NotImplementedError: Must be implemented by subclasses. .. py:method:: forward_queries(x: torch.Tensor) -> torch.Tensor Encode queries using this feature map. :param x: Query tensor of shape ``(B, T, H, D)``. :type x: torch.Tensor :returns: Encoded queries of the same leading shape. :rtype: torch.Tensor .. py:method:: forward_keys(x: torch.Tensor) -> torch.Tensor Encode keys using this feature map. :param x: Key tensor of shape ``(B, T, H, D)``. :type x: torch.Tensor :returns: Encoded keys of the same leading shape. :rtype: torch.Tensor .. py:method:: forward(x: torch.Tensor) -> torch.Tensor :abstractmethod: Encode ``x`` using this feature map. For symmetric feature maps it suffices to define this method. For asymmetric maps, override ``forward_queries`` and ``forward_keys`` separately. :param x: Input tensor. :type x: torch.Tensor :returns: Encoded output. :rtype: torch.Tensor :raises NotImplementedError: Must be implemented by subclasses. ..
py:method:: factory(*args, **kwargs) -> collections.abc.Callable :classmethod: Return a factory callable for constructing this feature map. The returned callable accepts ``query_dims`` and returns an instance of this class. Inherited by all subclasses, enabling use with ``LinearAttention``'s ``feature_map`` argument. :returns: A factory function ``query_dims → instance``. :rtype: Callable[[int], FeatureMap] .. py:class:: ActivationFunctionFeatureMap(query_dims: int, activation_function: collections.abc.Callable) Bases: :py:obj:`FeatureMap` Feature map defined by an element-wise activation function. :param query_dims: Head dimensionality. :type query_dims: int :param activation_function: Applied element-wise to the input tensor. :type activation_function: callable .. py:method:: new_feature_map(device: torch.device) -> None No-op: activation-based feature maps have no random parameters. .. py:method:: forward(x: torch.Tensor) -> torch.Tensor Apply the activation function element-wise. :param x: Input tensor of any shape. :type x: torch.Tensor :returns: Activated tensor of the same shape. :rtype: torch.Tensor .. py:class:: EluFeatureMap(query_dims: int) Bases: :py:obj:`ActivationFunctionFeatureMap` ELU+1 feature map — default for ``LinearAttention``. Implements Φ(x) = elu(x) + 1, which satisfies Φ(x) ≥ 0 everywhere and yields a valid positive-definite kernel without random projections. :param query_dims: Head dimensionality (``d_model // n_heads``). :type query_dims: int .. py:class:: PerformerFeatureMap(query_dims: int, num_features: int | None = None) Bases: :py:obj:`FeatureMap` Performer (FAVOR+) positive random feature map (Choromanski et al., ICLR 2021). Provides an unbiased estimator of the softmax attention kernel using orthogonal random features.
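The FAVOR+ estimator can be checked numerically. A minimal sketch follows, using plain i.i.d. Gaussian projection vectors for simplicity (the library additionally orthogonalizes them to reduce variance):

```python
import numpy as np

# Sketch of the FAVOR+ positive feature map:
#   phi(x)_i = exp(x . w_i - ||x||^2 / 2) / sqrt(r)
# Its inner product is an unbiased estimate of the softmax kernel exp(x . y),
# so for large r the estimate lands close to the exact value.
rng = np.random.default_rng(0)
d, r = 4, 200_000
x = rng.uniform(-0.3, 0.3, d)
y = rng.uniform(-0.3, 0.3, d)

w = rng.standard_normal((r, d))                      # random projections
phi = lambda v: np.exp(w @ v - v @ v / 2) / np.sqrt(r)

estimate = phi(x) @ phi(y)
exact = np.exp(x @ y)
assert abs(estimate - exact) / exact < 0.05          # close for large r
```

Note that the features are strictly positive, which is what keeps the linear-attention denominator well behaved.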
Unlike ``elu+1``, FAVOR+ uses ``r >> head_dim`` features, directly addressing the limited capacity of small head dimensions: Φ(x)ᵢ = exp(x·ωᵢ − ‖x‖²/2) / √r for r orthogonal vectors ωᵢ E[Φ(x)ᵀΦ(y)] = exp(xᵀy) — an unbiased estimator of the softmax kernel. ``new_feature_map()`` resamples the projection each forward pass, reducing variance across training steps. Select via config: ``attention_type: performer`` Tune via config: ``performer_num_random_features: 64`` (default: ``max(head_dim*4, 32)``) :param query_dims: Head dimensionality (``d_model // n_heads``). :type query_dims: int :param num_features: Number of random features ``r``. Defaults to ``max(query_dims * 4, 32)``. :type num_features: int, optional .. py:method:: new_feature_map(device: torch.device) -> None Sample a new orthogonal random projection matrix. :param device: Target device for the projection tensor. :type device: torch.device .. py:method:: forward(x: torch.Tensor) -> torch.Tensor Apply the FAVOR+ feature map. :param x: Input of shape ``(..., query_dims)``. :type x: torch.Tensor :returns: Positive random features of shape ``(..., num_features)``. :rtype: torch.Tensor .. py:class:: GatedAttentionUnit(d_model: int, query_key_dim: int | None = None, dropout: float = 0.0) Bases: :py:obj:`torch.nn.Module` Gated Attention Unit (GAU), the ``flash`` attention option (Hua et al., ICML 2022). Replaces multi-head attention with single-head gated linear attention: u = SiLU(W_u · queries) # gate from query stream v = W_v · values # values from key/value stream q = relu(W_q · queries)² # scoring query (always ≥ 0) k = relu(W_k · keys)² # scoring key (always ≥ 0) a = linear_attn(q, k, v) # O(N·s) single-head attention output = W_o · (u ⊙ a) # gated output relu² ensures k·q ≥ 0 everywhere, keeping the linear attention denominator positive without a learned feature map.
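The data flow above can be sketched in NumPy. This is illustrative only: random stand-in weights, self-attention, no masking, and ``W_u`` etc. are placeholders for the module's learned projections:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(x, 0.0)

# Illustrative GAU sketch: gate, values, relu^2 scoring, linear attention.
B, T, d_model, s = 2, 6, 8, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((B, T, d_model))
W_u, W_v = rng.standard_normal((2, d_model, d_model))
W_q, W_k = rng.standard_normal((2, d_model, s))
W_o = rng.standard_normal((d_model, d_model))

u = silu(x @ W_u)          # gate from the query stream, (B, T, d_model)
v = x @ W_v                # values, (B, T, d_model)
q = relu(x @ W_q) ** 2     # scoring query, always >= 0, (B, T, s)
k = relu(x @ W_k) ** 2     # scoring key, always >= 0, (B, T, s)

# Linear attention, reordered so no (T, T) matrix is materialized: O(T * s).
kv = np.einsum("bts,btd->bsd", k, v)             # (B, s, d_model)
z = 1.0 / (np.einsum("bts,bs->bt", q, k.sum(1)) + 1e-6)
a = np.einsum("bts,bsd,bt->btd", q, kv, z)       # (B, T, d_model)

out = (u * a) @ W_o        # gated output, (B, T, d_model)
assert out.shape == (2, 6, 8)
```

Because ``q`` and ``k`` are elementwise nonnegative, the normalizer ``z`` never involves a negative denominator, which is the point of the relu² scoring.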
Supports cross-attention: gate and scoring query come from the query (target) stream; scoring key and value come from the key/value (source) stream. The forward interface matches ``AttentionLayer``, so it is a drop-in replacement in ``TransformerEncoderLayer`` without any changes to ``transformer.py``. Select via config: ``attention_type: flash`` Tune via config: ``flash_query_key_dim: 32`` (default: ``max(d_model // 2, 16)``) ``dropout_attention: 0.1`` :param d_model: Input and output feature dimensionality. :type d_model: int :param query_key_dim: Scoring dimension ``s``. Defaults to ``max(d_model // 2, 16)``. :type query_key_dim: int, optional :param dropout: Dropout on the gated pre-projection tensor ``u ⊙ a``. Defaults to ``0.0``. :type dropout: float .. py:method:: forward(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None, **_kwargs) -> tuple[torch.Tensor, None] Compute gated linear attention. :param queries: Shape ``(B, T_q, d_model)``. :type queries: torch.Tensor :param keys: Shape ``(B, T_k, d_model)``. :type keys: torch.Tensor :param values: Shape ``(B, T_k, d_model)``. :type values: torch.Tensor :param query_mask: Shape ``(B, T_q)``. True = valid. Masked output positions are set to zero. :type query_mask: torch.BoolTensor, optional :param key_mask: Shape ``(B, T_k)``. True = valid. Masked key/value positions are zeroed before accumulation. :type key_mask: torch.BoolTensor, optional :returns: Output of shape ``(B, T_q, d_model)`` and ``None`` (no attention weight tensor is returned). :rtype: tuple[torch.Tensor, None]
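The zero-masking convention shared by the linear-attention family (masked key positions are zeroed before accumulation, rather than set to ``-inf``) can be verified with a small standalone sketch, here using the ``elu(x)+1`` feature map outside the library:

```python
import numpy as np

def elu_plus_one(x):
    # elu(x) + 1: positive everywhere, matching the default feature map
    return np.where(x > 0, x + 1.0, np.exp(x))

# Sketch: zeroing masked keys makes padded timesteps contribute nothing to
# the accumulated sums -- no -inf rows, hence no NaN risk.
rng = np.random.default_rng(0)
B, T, D = 1, 5, 4
q = elu_plus_one(rng.standard_normal((B, T, D)))
k = elu_plus_one(rng.standard_normal((B, T, D)))
v = rng.standard_normal((B, T, D))

key_mask = np.array([[True, True, True, False, False]])  # last 2 are padding
k = k * key_mask[..., None]           # zero the masked keys

kv = np.einsum("btd,bte->bde", k, v)  # (B, D, D) summary, O(T D^2)
z = np.einsum("btd,bd->bt", q, k.sum(1))
out = np.einsum("btd,bde,bt->bte", q, kv, 1.0 / z)

# Identical to running on the valid prefix alone:
kv2 = np.einsum("btd,bte->bde", k[:, :3], v[:, :3])
z2 = np.einsum("btd,bd->bt", q, k[:, :3].sum(1))
out2 = np.einsum("btd,bde,bt->bte", q, kv2, 1.0 / z2)
assert np.allclose(out, out2)
```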