Source code for linmult.core.fusion

"""Fusion-stage module for LinMulT.

FusionModule groups all post-branch fusion steps into a single nn.Module:
  - Optional per-branch temporal reduction (TRM)
  - Optional TAM-based cross-branch alignment and fusion
  - Optional self-attention over the fused representation (SAT)
  - Optional feed-forward layer (FFN)
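
The two fusion paths (TAM-based vs. concatenation) are mutually exclusive. A
minimal sketch of the stage order, where bracketed stages are optional::

  TAM path:     x_list --> TAM ---------------> [SAT] --> [FFN] --> fused
  Concat path:  x_list --> [TRM] --> concat --> [SAT] --> [FFN] --> fused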
"""

from typing import cast

import torch
from torch import Tensor, nn

from linmult.core.attention import AttentionConfig
from linmult.core.ffn import FFNResidual
from linmult.core.temporal import TAM, TRM
from linmult.core.transformer import TransformerEncoder


def _combine_masks(x_list: list[Tensor], mask_list: list[Tensor | None]) -> Tensor | None:
    """AND-combine per-branch masks into a single joint mask.

    Returns ``None`` when all masks are ``None`` (no masking needed).
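
    Example (illustrative sketch; shapes and values are hypothetical)::

        >>> x = torch.zeros(2, 3, 8)                    # (B, T, F)
        >>> m = torch.tensor([[True, True, False],
        ...                   [True, False, False]])    # (B, T)
        >>> _combine_masks([x, x], [m, None]).tolist()
        [[True, True, False], [True, False, False]]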
    """
    if all(m is None for m in mask_list):
        return None
    effective = [
        m if m is not None else torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
        for x, m in zip(x_list, mask_list)
    ]
    return torch.stack(effective, dim=0).all(dim=0)


[docs]
class FusionModule(nn.Module):
    """Fuses multi-branch representations into a single representation.

    Builds all internal sub-modules (TRM, TAM, SAT, FFN) from primitive
    parameters. Two primary fusion paths (mutually exclusive):

    - **TAM path** (when ``add_tam_fusion=True``): the TAM receives all branch
      tensors, aligns them temporally, and projects to a common dimension.
    - **Concat path**: each branch is optionally reduced along the time axis
      with TRM (when ``time_dim_reducer`` is set), then all branches are
      concatenated along the feature dimension.

    Either path is followed by an optional self-attention transformer (SAT)
    and an optional feed-forward residual layer (FFN).

    Args:
        input_dim: Per-branch feature dimension (from ``CrossModalModule.output_dim``).
        n_branches: Number of branches to fuse.
        d_model: Base model dimension (used for TAM output target).
        num_heads: Number of attention heads for internal transformers.
        attention_config: Attention configuration for internal transformers.
        time_dim_reducer: Temporal reducer type (``"attentionpool"``, ``"gap"``,
            ``"gmp"``, ``"last"``). ``None`` to skip.
        add_tam_fusion: Whether to use TAM-based fusion.
        tam_aligner: Temporal alignment method for TAM fusion.
        tam_time_dim: Target time dimension for TAM fusion.
        fusion_num_layers: Depth of the TAM fusion transformer.
        add_sat_fusion: Whether to add a self-attention transformer after fusion.
        fusion_sat_num_layers: Depth of the fusion SAT.
        add_ffn_fusion: Whether to add a feed-forward residual layer.
        dropout_tam: Dropout for TAM fusion.
        dropout_output: Dropout for FFN.
        dropout_pe: Positional encoding dropout for internal transformers.
        dropout_ffn: FFN dropout within transformers.
    """

    def __init__(
        self,
        input_dim: int,
        n_branches: int,
        d_model: int,
        num_heads: int = 8,
        attention_config: AttentionConfig | None = None,
        *,
        time_dim_reducer: str | None = None,
        add_tam_fusion: bool = False,
        tam_aligner: str = "aap",
        tam_time_dim: int = 0,
        fusion_num_layers: int = 6,
        dropout_tam: float = 0.1,
        add_sat_fusion: bool = False,
        fusion_sat_num_layers: int = 6,
        add_ffn_fusion: bool = False,
        dropout_output: float = 0.0,
        dropout_pe: float = 0.0,
        dropout_ffn: float = 0.1,
    ) -> None:
        super().__init__()

        fusion_dim = input_dim * n_branches

        self.trm: TRM | None = None
        if time_dim_reducer:
            self.trm = TRM(d_model=input_dim, reducer=time_dim_reducer)

        self.tam: TAM | None = None
        if add_tam_fusion:
            self.tam = TAM(
                input_dim=fusion_dim,
                output_dim=n_branches * d_model,
                aligner=tam_aligner,
                time_dim=tam_time_dim,
                dropout_out=dropout_tam,
                num_layers=fusion_num_layers,
                num_heads=num_heads,
                attention_config=attention_config,
                dropout_pe=dropout_pe,
                dropout_ffn=dropout_ffn,
                name="TAM fusion",
            )
            fusion_dim = n_branches * d_model

        self.sat: TransformerEncoder | None = None
        if add_sat_fusion:
            self.sat = TransformerEncoder(
                d_model=fusion_dim,
                num_heads=num_heads,
                num_layers=fusion_sat_num_layers,
                attention_config=attention_config,
                dropout_pe=dropout_pe,
                dropout_ffn=dropout_ffn,
                name="Fusion SAT",
            )

        self.ffn: FFNResidual | None = None
        if add_ffn_fusion:
            self.ffn = FFNResidual(dim=fusion_dim, dropout=dropout_output)

        self._output_dim = fusion_dim

    @property
    def output_dim(self) -> int:
        """Final fused dimension after all fusion stages."""
        return self._output_dim
[docs]
    def forward(
        self,
        x_list: list[Tensor],
        mask_list: list[Tensor | None],
    ) -> tuple[Tensor, Tensor | None]:
        """Fuse branch representations into one tensor.

        Args:
            x_list: One tensor per branch, each ``(B, T, input_dim)``.
            mask_list: Boolean mask per branch, each ``(B, T)`` or ``None``.

        Returns:
            Tuple of:

            - Fused tensor ``(B, [T,] output_dim)``. Time axis is present when
              the TAM path is used or when TRM is not applied.
            - Joint boolean mask ``(B, T)``, or ``None`` when all masks are
              ``None`` (or after temporal reduction, which removes the time axis).
        """
        mask: Tensor | None
        if self.tam is not None:
            x, mask = self.tam(x_list, mask_list)
        else:
            if self.trm is not None:
                x_list = self.trm.apply_to_list(x_list, mask_list)  # (B, F) each
                mask_list = cast("list[Tensor | None]", [None] * len(mask_list))
            mask = _combine_masks(x_list, mask_list)
            x = torch.cat(x_list, dim=-1)  # (B, T, combined) or (B, combined)

        if self.sat is not None:
            x = self.sat(x, query_mask=mask)
        if self.ffn is not None:
            x = self.ffn(x)

        return x, mask
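

# A minimal usage sketch, assuming the concat path with a "gap" reducer; the
# shapes and hyper-parameters below are hypothetical and for illustration only.
if __name__ == "__main__":
    fusion = FusionModule(
        input_dim=40,
        n_branches=2,
        d_model=40,
        time_dim_reducer="gap",  # reduce each branch over time before concatenation
        add_ffn_fusion=True,
    )
    x_list = [torch.randn(4, 300, 40), torch.randn(4, 500, 40)]  # (B, T, F) per branch
    mask_list: list[Tensor | None] = [None, None]
    fused, mask = fusion(x_list, mask_list)
    print(fused.shape)  # torch.Size([4, 80]), i.e. (B, fusion.output_dim); mask is None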