linmult.core.fusion
Fusion-stage module for LinMulT.
- FusionModule groups all post-branch fusion steps into a single nn.Module:
  - Optional per-branch temporal reduction (TRM)
  - Optional TAM-based cross-branch alignment and fusion
  - Optional self-attention over the fused representation (SAT)
  - Optional feed-forward layer (FFN)
Classes

| FusionModule | Fuses multi-branch representations into a single representation. |
Module Contents
- class linmult.core.fusion.FusionModule(input_dim: int, n_branches: int, d_model: int, num_heads: int = 8, attention_config: linmult.core.attention.AttentionConfig | None = None, *, time_dim_reducer: str | None = None, add_tam_fusion: bool = False, tam_aligner: str = 'aap', tam_time_dim: int = 0, fusion_num_layers: int = 6, dropout_tam: float = 0.1, add_sat_fusion: bool = False, fusion_sat_num_layers: int = 6, add_ffn_fusion: bool = False, dropout_output: float = 0.0, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)
Bases: torch.nn.Module

Fuses multi-branch representations into a single representation.
Builds all internal sub-modules (TRM, TAM, SAT, FFN) from primitive parameters. Two primary fusion paths (mutually exclusive):

- TAM path (when add_tam_fusion=True): the TAM receives all branch tensors, aligns them temporally, and projects to a common dimension.
- Concat path: each branch is optionally reduced along the time axis with TRM (when time_dim_reducer is set), then all branches are concatenated along the feature dimension.

Either path is followed by an optional self-attention transformer (SAT) and an optional feed-forward residual layer (FFN); a plain-PyTorch sketch of the concat path follows.
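For intuition, here is a minimal stand-in sketch of the concat path in plain PyTorch. It assumes masked mean pooling as a substitute for the library's TRM reducers ("gap"-style); the actual reducers are internal to LinMulT:

```python
import torch

# Hypothetical stand-in for the concat path: reduce each branch over time
# with masked mean pooling, then concatenate along the feature dimension.
def concat_path(x_list, mask_list):
    pooled = []
    for x, mask in zip(x_list, mask_list):          # x: (B, T, input_dim)
        if mask is None:
            pooled.append(x.mean(dim=1))            # (B, input_dim)
        else:
            w = mask.unsqueeze(-1).float()          # (B, T, 1)
            pooled.append((x * w).sum(dim=1) / w.sum(dim=1).clamp(min=1e-6))
    return torch.cat(pooled, dim=-1)                # (B, n_branches * input_dim)

x_list = [torch.randn(4, 50, 40), torch.randn(4, 50, 40)]
fused = concat_path(x_list, [None, None])
print(fused.shape)  # torch.Size([4, 80])
```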
- Parameters:
  - input_dim – Per-branch feature dimension (from CrossModalModule.output_dim).
  - n_branches – Number of branches to fuse.
  - d_model – Base model dimension (used for the TAM output target).
  - num_heads – Number of attention heads for internal transformers.
  - attention_config – Attention configuration for internal transformers.
  - time_dim_reducer – Temporal reducer type ("attentionpool", "gap", "gmp", "last"). None to skip.
  - add_tam_fusion – Whether to use TAM-based fusion.
  - tam_aligner – Temporal alignment method for TAM fusion.
  - tam_time_dim – Target time dimension for TAM fusion.
  - fusion_num_layers – Depth of the TAM fusion transformer.
  - add_sat_fusion – Whether to add a self-attention transformer after fusion.
  - fusion_sat_num_layers – Depth of the fusion SAT.
  - add_ffn_fusion – Whether to add a feed-forward residual layer.
  - dropout_tam – Dropout for TAM fusion.
  - dropout_output – Dropout for the FFN.
  - dropout_pe – Positional encoding dropout for internal transformers.
  - dropout_ffn – FFN dropout within transformers.
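As a construction example, a minimal sketch instantiating the module on the concat path, using only parameters from the signature above (the dimension values are illustrative; input_dim would normally come from CrossModalModule.output_dim):

```python
from linmult.core.fusion import FusionModule

fusion = FusionModule(
    input_dim=40,               # per-branch feature dimension
    n_branches=3,
    d_model=64,
    time_dim_reducer="gap",     # reduce each branch over time, then concat
    add_sat_fusion=False,
    add_ffn_fusion=True,
)
print(fusion.output_dim)  # final fused dimension after all stages
```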
- property output_dim: int
Final fused dimension after all fusion stages.
- forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → tuple[torch.Tensor, torch.Tensor | None]
Fuse branch representations into one tensor.
- Parameters:
  - x_list – One tensor per branch, each (B, T, input_dim).
  - mask_list – Boolean mask per branch, each (B, T) or None.
- Returns:
  - Fused tensor (B, [T,] output_dim). The time axis is present when the TAM path is used or when TRM is not applied.
  - Joint boolean mask (B, T), or None when all masks are None (or after temporal reduction, which removes the time axis).
- Return type:
  tuple[torch.Tensor, torch.Tensor | None]
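As a usage example, a minimal sketch of a forward call, reusing the fusion instance constructed above (shapes are illustrative and follow the documented contract):

```python
import torch

B, T, input_dim = 4, 50, 40
x_list = [torch.randn(B, T, input_dim) for _ in range(3)]
mask_list = [torch.ones(B, T, dtype=torch.bool), None, None]

fused, joint_mask = fusion(x_list, mask_list)
# With time_dim_reducer set and no TAM path, temporal reduction removes
# the time axis: fused is (B, fusion.output_dim) and joint_mask is None.
print(fused.shape, joint_mask)
```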