linmult.core.fusion
===================

.. py:module:: linmult.core.fusion

.. autoapi-nested-parse::

   Fusion-stage module for LinMulT.

   FusionModule groups all post-branch fusion steps into a single nn.Module:

   - Optional per-branch temporal reduction (TRM)
   - Optional TAM-based cross-branch alignment and fusion
   - Optional self-attention over the fused representation (SAT)
   - Optional feed-forward layer (FFN)


Classes
-------

.. autoapisummary::

   linmult.core.fusion.FusionModule


Module Contents
---------------

.. py:class:: FusionModule(input_dim: int, n_branches: int, d_model: int, num_heads: int = 8, attention_config: linmult.core.attention.AttentionConfig | None = None, *, time_dim_reducer: str | None = None, add_tam_fusion: bool = False, tam_aligner: str = 'aap', tam_time_dim: int = 0, fusion_num_layers: int = 6, dropout_tam: float = 0.1, add_sat_fusion: bool = False, fusion_sat_num_layers: int = 6, add_ffn_fusion: bool = False, dropout_output: float = 0.0, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)

   Bases: :py:obj:`torch.nn.Module`

   Fuses multi-branch representations into a single representation.

   Builds all internal sub-modules (TRM, TAM, SAT, FFN) from primitive
   parameters.

   Two primary fusion paths (mutually exclusive):

   - **TAM path** (when ``add_tam_fusion=True``): the TAM receives all branch
     tensors, aligns them temporally, and projects them to a common dimension.
   - **Concat path**: each branch is optionally reduced along the time axis
     with TRM (when ``time_dim_reducer`` is set), then all branches are
     concatenated along the feature dimension.

   Either path is followed by an optional self-attention transformer (SAT)
   and an optional feed-forward residual layer (FFN). A usage sketch is given
   at the end of this page.

   :param input_dim: Per-branch feature dimension (from
       ``CrossModalModule.output_dim``).
   :param n_branches: Number of branches to fuse.
   :param d_model: Base model dimension (used as the TAM output target).
   :param num_heads: Number of attention heads for internal transformers.
   :param attention_config: Attention configuration for internal transformers.
   :param time_dim_reducer: Temporal reducer type (``"attentionpool"``,
       ``"gap"``, ``"gmp"``, ``"last"``). ``None`` to skip.
   :param add_tam_fusion: Whether to use TAM-based fusion.
   :param tam_aligner: Temporal alignment method for TAM fusion.
   :param tam_time_dim: Target time dimension for TAM fusion.
   :param fusion_num_layers: Depth of the TAM fusion transformer.
   :param add_sat_fusion: Whether to add a self-attention transformer after
       fusion.
   :param fusion_sat_num_layers: Depth of the fusion SAT.
   :param add_ffn_fusion: Whether to add a feed-forward residual layer.
   :param dropout_tam: Dropout for TAM fusion.
   :param dropout_output: Dropout for the FFN.
   :param dropout_pe: Positional-encoding dropout for internal transformers.
   :param dropout_ffn: FFN dropout within transformers.

   .. py:property:: output_dim
      :type: int

      Final fused dimension after all fusion stages.

   .. py:method:: forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) -> tuple[torch.Tensor, torch.Tensor | None]

      Fuse branch representations into one tensor.

      :param x_list: One tensor per branch, each ``(B, T, input_dim)``.
      :param mask_list: Boolean mask per branch, each ``(B, T)`` or ``None``.
      :returns: - Fused tensor ``(B, [T,] output_dim)``. The time axis is
                  present when the TAM path is used or when TRM is not applied.
                - Joint boolean mask ``(B, T)``, or ``None`` when all input
                  masks are ``None`` or after temporal reduction (which removes
                  the time axis).
      :rtype: tuple
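
Usage example
-------------

A minimal sketch of the two fusion paths, based only on the signature and the
``forward`` contract documented above. The batch size, sequence length,
feature dimension, branch count, and TAM target length are illustrative
assumptions, not values taken from LinMulT.

.. code-block:: python

   import torch

   from linmult.core.fusion import FusionModule

   B, T, input_dim, n_branches = 4, 50, 64, 3  # illustrative sizes

   x_list = [torch.randn(B, T, input_dim) for _ in range(n_branches)]
   mask_list = [torch.ones(B, T, dtype=torch.bool) for _ in range(n_branches)]

   # Concat path: no TAM; each branch is reduced over time (TRM), then the
   # resulting branch vectors are concatenated along the feature dimension.
   concat_fusion = FusionModule(
       input_dim=input_dim,
       n_branches=n_branches,
       d_model=64,
       time_dim_reducer="attentionpool",
   )
   fused, joint_mask = concat_fusion(x_list, mask_list)
   # TRM removed the time axis, so fused is (B, output_dim) and the joint
   # mask is None, per the forward() contract above.
   assert fused.shape == (B, concat_fusion.output_dim)
   assert joint_mask is None

   # TAM path: branches are aligned to a common time axis and projected to
   # d_model. tam_time_dim=25 is an assumed target length; the meaning of
   # the default value 0 is not documented on this page.
   tam_fusion = FusionModule(
       input_dim=input_dim,
       n_branches=n_branches,
       d_model=64,
       add_tam_fusion=True,
       tam_time_dim=25,
   )
   fused_t, joint_mask_t = tam_fusion(x_list, mask_list)
   # The TAM path keeps a time axis: fused_t is (B, T', output_dim).

Since this page does not state how the fused dimension is computed on either
path, the sketch reads it from the ``output_dim`` property rather than
assuming a value.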