linmult.core.fusion

Fusion-stage module for LinMulT.

FusionModule groups all post-branch fusion steps into a single nn.Module:
  • Optional per-branch temporal reduction (TRM)

  • Optional TAM-based cross-branch alignment and fusion

  • Optional self-attention over the fused representation (SAT)

  • Optional feed-forward layer (FFN)

Classes

FusionModule

Fuses multi-branch representations into a single representation.

Module Contents

class linmult.core.fusion.FusionModule(input_dim: int, n_branches: int, d_model: int, num_heads: int = 8, attention_config: linmult.core.attention.AttentionConfig | None = None, *, time_dim_reducer: str | None = None, add_tam_fusion: bool = False, tam_aligner: str = 'aap', tam_time_dim: int = 0, fusion_num_layers: int = 6, dropout_tam: float = 0.1, add_sat_fusion: bool = False, fusion_sat_num_layers: int = 6, add_ffn_fusion: bool = False, dropout_output: float = 0.0, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)[source]

Bases: torch.nn.Module

Fuses multi-branch representations into a single representation.

Builds all internal sub-modules (TRM, TAM, SAT, FFN) from primitive parameters. Two mutually exclusive fusion paths are available:

  • TAM path (when add_tam_fusion=True): the TAM receives all branch tensors, aligns them temporally, and projects to a common dimension.

  • Concat path: each branch is optionally reduced along the time axis with TRM (when time_dim_reducer is set), then all branches are concatenated along the feature dimension.

Either path is followed by an optional self-attention transformer (SAT) and an optional feed-forward residual layer (FFN).
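A minimal construction sketch of the two paths follows; the dimensions and flag values are illustrative choices, not canonical defaults.

    from linmult.core.fusion import FusionModule

    # TAM path: branch tensors are aligned temporally and projected to d_model.
    tam_fusion = FusionModule(
        input_dim=40,           # per-branch feature dim (CrossModalModule.output_dim)
        n_branches=3,
        d_model=40,
        add_tam_fusion=True,    # selects the TAM path
        tam_aligner="aap",
        tam_time_dim=300,       # illustrative target time dimension
        fusion_num_layers=4,
    )

    # Concat path: optional per-branch TRM, then concatenation along features.
    concat_fusion = FusionModule(
        input_dim=40,
        n_branches=3,
        d_model=40,
        time_dim_reducer="gap",  # collapse each branch's time axis before concat
        add_ffn_fusion=True,
    )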

Parameters:
  • input_dim – Per-branch feature dimension (from CrossModalModule.output_dim).

  • n_branches – Number of branches to fuse.

  • d_model – Base model dimension (used for TAM output target).

  • num_heads – Number of attention heads for internal transformers.

  • attention_config – Attention configuration for internal transformers.

  • time_dim_reducer – Temporal reducer type: "attentionpool", "gap", "gmp", or "last". Pass None to skip temporal reduction (see the sketch after this list).

  • add_tam_fusion – Whether to use TAM-based fusion.

  • tam_aligner – Temporal alignment method for TAM fusion.

  • tam_time_dim – Target time dimension for TAM fusion.

  • fusion_num_layers – Depth of the TAM fusion transformer.

  • dropout_tam – Dropout for TAM fusion.

  • add_sat_fusion – Whether to add a self-attention transformer after fusion.

  • fusion_sat_num_layers – Depth of the fusion SAT.

  • add_ffn_fusion – Whether to add a feed-forward residual layer.

  • dropout_output – Dropout for FFN.

  • dropout_pe – Positional encoding dropout for internal transformers.

  • dropout_ffn – FFN dropout within transformers.
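For intuition about the reducer names, here is a rough sketch of what each plausibly computes over a (B, T, D) tensor. This illustrates the reduction semantics only and is not the library's implementation; the "attentionpool" variant is omitted because its internals are not documented here.

    import torch

    x = torch.randn(2, 50, 40)   # (B, T, D)
    gap = x.mean(dim=1)          # "gap": global average pooling -> (B, D)
    gmp = x.amax(dim=1)          # "gmp": global max pooling -> (B, D)
    last = x[:, -1, :]           # "last": final timestep -> (B, D)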


property output_dim: int

Final fused dimension after all fusion stages.
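The concrete value depends on the configured path (presumably d_model after TAM fusion and n_branches * input_dim after plain concatenation), so downstream layers should be sized from this property rather than recomputed by hand. A sketch, assuming a default concat configuration:

    import torch
    from linmult.core.fusion import FusionModule

    fusion = FusionModule(input_dim=40, n_branches=3, d_model=40)
    # Hypothetical 5-class head sized from the fused dimension.
    head = torch.nn.Linear(fusion.output_dim, 5)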

forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → tuple[torch.Tensor, torch.Tensor | None][source]

Fuse branch representations into one tensor.

Parameters:
  • x_list – One tensor per branch, each (B, T, input_dim).

  • mask_list – Boolean mask per branch, each (B, T) or None.

Returns:

  • Fused tensor of shape (B, [T,] output_dim). The time axis is present when the TAM path is used or when TRM is not applied.

  • Joint boolean mask of shape (B, T), or None when all input masks are None or when temporal reduction has removed the time axis.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
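
An end-to-end sketch on dummy inputs (shapes are illustrative; the boolean mask convention follows whatever the library defines for padding):

    import torch
    from linmult.core.fusion import FusionModule

    B, T, D = 2, 50, 40
    fusion = FusionModule(input_dim=D, n_branches=3, d_model=D)  # concat path, no TRM

    x_list = [torch.randn(B, T, D) for _ in range(3)]     # one (B, T, input_dim) per branch
    mask_list = [torch.ones(B, T, dtype=torch.bool)] * 3  # per-branch masks; None entries are allowed

    fused, joint_mask = fusion(x_list, mask_list)
    print(fused.shape)  # (B, T, output_dim) on the concat path without TRM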