Source code for linmult.core.branch

"""Branch modules for LinMulT.

- MultimodalSignal: creates a shared multimodal signal (MMS) from all modality sequences
- CrossModalBranch: processes one target modality through cross-modal attention from all sources
- CrossModalModule: orchestrates all branches with optional MMS
"""

from typing import cast

import torch
from torch import Tensor, nn

from linmult.core.attention import AttentionConfig
from linmult.core.temporal import TAM
from linmult.core.transformer import TransformerEncoder


class MultimodalSignal(nn.Module):
    """Creates a shared multimodal signal from all modality sequences.

    Wraps a TAM that fuses all modality sequences into a single aligned
    representation. The output is appended to ``x_list`` / ``mask_list`` so
    that each branch can attend to it via its final cross-modal transformer.

    Args:
        tam (TAM): Temporal alignment module that receives all modality
            sequences.
    """

    def __init__(self, tam: TAM) -> None:
        super().__init__()
        self.tam = tam

    def forward(
        self,
        x_list: list[Tensor],
        mask_list: list[Tensor | None],
    ) -> tuple[list[Tensor], list[Tensor | None]]:
        """Compute multimodal signal and append to input lists.

        Args:
            x_list: One tensor per modality, each ``(B, T_i, d_model)``.
            mask_list: Boolean mask per modality, each ``(B, T_i)`` or ``None``.

        Returns:
            Extended ``(x_list, mask_list)`` with the multimodal signal
            appended as the last element. The signal has shape
            ``(B, time_dim, tgt_dim)``.
        """
        mms_x, mms_mask = self.tam(x_list, mask_list)
        return list(x_list) + [mms_x], list(mask_list) + [mms_mask]
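

# Usage sketch (illustrative, not part of the module): builds a
# MultimodalSignal for three modalities. The TAM arguments mirror the call in
# CrossModalModule below; the concrete values (d_model=40, time_dim=50, layer
# counts, sequence lengths) are assumptions for the example only.
def _example_multimodal_signal() -> None:
    d_model = 40
    mms = MultimodalSignal(
        tam=TAM(
            input_dim=3 * d_model,  # all three modality sequences
            output_dim=d_model,
            aligner="aap",          # same default as CrossModalModule
            time_dim=50,            # assumed target length for alignment
            dropout_out=0.1,
            num_layers=2,
            num_heads=8,
            attention_config=None,
            dropout_pe=0.0,
            dropout_ffn=0.1,
            name="TAM MMS",
        )
    )
    x_list = [torch.randn(4, t, d_model) for t in (300, 500, 50)]
    mask_list: list[Tensor | None] = [None, None, None]
    x_list, mask_list = mms(x_list, mask_list)
    assert len(x_list) == 4  # three original modalities + appended MMS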


class CrossModalBranch(nn.Module):
    """Processes one target modality through cross-modal and self-attention.

    Forward pass:

    1. Run one cross-modal TransformerEncoder per source
       (query=target, key/value=source).
    2. Concatenate all cross-modal outputs → ``(B, T, branch_dim)``.
    3. Apply branch self-attention (SAT).
    4. Optionally concatenate with unimodal SAT output.

    Args:
        cross_transformers: One cross-modal TransformerEncoder per source
            modality (including the MMS token if enabled).
        branch_sat: Self-attention TransformerEncoder applied over the
            concatenated cross-modal representation.
        unimodal_sat: Optional self-attention TransformerEncoder applied to
            the original (projected) query sequence. Its output is
            concatenated with the SAT output before returning.
    """

    def __init__(
        self,
        cross_transformers: nn.ModuleList,
        branch_sat: TransformerEncoder,
        unimodal_sat: TransformerEncoder | None = None,
    ) -> None:
        super().__init__()
        self.cross_transformers = cross_transformers
        self.branch_sat = branch_sat
        self.unimodal_sat = unimodal_sat

    def forward(
        self,
        x_query: Tensor,
        x_sources: list[Tensor],
        mask_query: Tensor | None = None,
        mask_sources: list[Tensor | None] | None = None,
    ) -> Tensor:
        """Run cross-modal + self-attention for one target modality.

        Args:
            x_query: Target modality tensor ``(B, T_q, d_model)``.
            x_sources: Source modality tensors, each ``(B, T_s, d_model)``.
                Must be in the same order as ``cross_transformers``.
            mask_query: Boolean mask ``(B, T_q)`` for the query.
                ``None`` = no mask.
            mask_sources: Boolean masks for each source, same length as
                ``x_sources``. ``None`` entries are treated as no mask; omit
                the list to use no masks.

        Returns:
            Branch representation ``(B, T_q, full_branch_dim)`` where
            ``full_branch_dim = len(x_sources) * d_model [+ d_model if
            unimodal_sat]``.
        """
        if mask_sources is None:
            mask_sources = cast("list[torch.Tensor | None]", [None] * len(x_sources))
        cross_outputs = [
            cmt(x_query, src, src, query_mask=mask_query, key_mask=mask_src)
            for cmt, src, mask_src in zip(
                self.cross_transformers, x_sources, mask_sources
            )
        ]
        hidden = torch.cat(cross_outputs, dim=-1)  # (B, T_q, branch_dim)
        out = self.branch_sat(hidden, query_mask=mask_query)  # (B, T_q, branch_dim)
        if self.unimodal_sat is not None:
            uni = self.unimodal_sat(x_query, query_mask=mask_query)  # (B, T_q, d_model)
            out = torch.cat([out, uni], dim=-1)
        return out
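

# Usage sketch (illustrative, not part of the module): wires one branch for
# target modality 0 with two sources. The TransformerEncoder arguments mirror
# _make_encoder in CrossModalModule below; the concrete values (d_model=40,
# layer counts, sequence lengths) are assumptions for the example only.
def _example_cross_modal_branch() -> None:
    d_model = 40

    def enc(name: str, dim: int, *, cross: bool = False) -> TransformerEncoder:
        return TransformerEncoder(
            name=name,
            d_model=dim,
            num_heads=8,
            num_layers=2,
            attention_config=None,
            is_cross_modal=cross,
            dropout_pe=0.0,
            dropout_ffn=0.1,
        )

    branch = CrossModalBranch(
        cross_transformers=nn.ModuleList(
            [enc("CMT 1->0", d_model, cross=True), enc("CMT 2->0", d_model, cross=True)]
        ),
        branch_sat=enc("SAT 0", 2 * d_model),  # width of the concatenated outputs
    )
    x_query = torch.randn(4, 300, d_model)
    x_sources = [torch.randn(4, 500, d_model), torch.randn(4, 50, d_model)]
    out = branch(x_query, x_sources)  # no masks: all timesteps valid
    assert out.shape == (4, 300, 2 * d_model)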


class CrossModalModule(nn.Module):
    """Orchestrates cross-modal attention across all modalities.

    Builds and manages all cross-modal branches, including optional
    multimodal signal (MMS) generation. Each target modality gets its own
    :class:`CrossModalBranch` with cross-modal transformers from every other
    source (and from MMS if enabled).

    Args:
        num_modalities: Number of input modalities (>= 2).
        d_model: Model dimension (shared across all transformers).
        num_heads: Number of attention heads.
        branch_cmt_num_layers: Depth of each cross-modal transformer.
        branch_sat_num_layers: Depth of each branch self-attention transformer.
        attention_config: Attention configuration for all internal transformers.
        add_mms: Whether to create a multimodal signal via TAM.
        mms_num_layers: Depth of the MMS transformer (only when ``add_mms=True``).
        tam_aligner: Temporal alignment method for MMS.
        tam_time_dim: Target time dimension for MMS alignment.
        dropout_tam: Dropout for MMS TAM.
        add_unimodal_sat: Whether to add a unimodal self-attention per branch.
        unimodal_sat_num_layers: Depth of unimodal self-attention.
        dropout_pe: Positional encoding dropout.
        dropout_ffn: FFN dropout (within transformers).
    """

    def __init__(
        self,
        num_modalities: int,
        d_model: int,
        num_heads: int = 8,
        branch_cmt_num_layers: int = 6,
        branch_sat_num_layers: int = 6,
        attention_config: AttentionConfig | None = None,
        *,
        add_mms: bool = False,
        mms_num_layers: int = 6,
        tam_aligner: str = "aap",
        tam_time_dim: int = 0,
        dropout_tam: float = 0.1,
        add_unimodal_sat: bool = False,
        unimodal_sat_num_layers: int = 6,
        dropout_pe: float = 0.0,
        dropout_ffn: float = 0.1,
    ) -> None:
        super().__init__()
        self.num_modalities = num_modalities
        self.d_model = d_model
        self._add_mms = add_mms
        self._add_unimodal_sat = add_unimodal_sat

        # Shared encoder params
        self._num_heads = num_heads
        self._attention_config = attention_config
        self._dropout_pe = dropout_pe
        self._dropout_ffn = dropout_ffn

        # MMS (optional)
        self.mms: MultimodalSignal | None = None
        if add_mms:
            self.mms = MultimodalSignal(
                tam=TAM(
                    input_dim=num_modalities * d_model,
                    output_dim=d_model,
                    aligner=tam_aligner,
                    time_dim=tam_time_dim,
                    dropout_out=dropout_tam,
                    num_layers=mms_num_layers,
                    num_heads=num_heads,
                    attention_config=attention_config,
                    dropout_pe=dropout_pe,
                    dropout_ffn=dropout_ffn,
                    name="TAM MMS",
                )
            )

        # Branches: one per target modality
        n_sources = self.num_modalities if self._add_mms else self.num_modalities - 1
        branches = []
        for tgt in range(self.num_modalities):
            sources = [i for i in range(self.num_modalities) if i != tgt]
            cross_transformers = nn.ModuleList(
                [
                    self._make_encoder(
                        name=f"CMT {src}->{tgt}",
                        num_layers=branch_cmt_num_layers,
                        is_cross_modal=True,
                    )
                    for src in sources
                ]
            )
            if add_mms:
                cross_transformers.append(
                    self._make_encoder(
                        name=f"CMT mms->{tgt}",
                        num_layers=branch_cmt_num_layers,
                        is_cross_modal=True,
                    )
                )
            sat = self._make_encoder(
                name=f"SAT {tgt}",
                num_layers=branch_sat_num_layers,
                d_model=n_sources * d_model,
            )
            unimodal_sat = None
            if add_unimodal_sat:
                unimodal_sat = self._make_encoder(
                    name=f"Unimodal SAT {tgt}",
                    num_layers=unimodal_sat_num_layers,
                )
            branches.append(
                CrossModalBranch(
                    cross_transformers=cross_transformers,
                    branch_sat=sat,
                    unimodal_sat=unimodal_sat,
                )
            )
        self.branches = nn.ModuleList(branches)

    def _make_encoder(
        self,
        *,
        name: str,
        num_layers: int,
        d_model: int | None = None,
        is_cross_modal: bool = False,
    ) -> TransformerEncoder:
        """Create a TransformerEncoder with shared settings."""
        return TransformerEncoder(
            name=name,
            d_model=d_model if d_model is not None else self.d_model,
            num_heads=self._num_heads,
            num_layers=num_layers,
            attention_config=self._attention_config,
            is_cross_modal=is_cross_modal,
            dropout_pe=self._dropout_pe,
            dropout_ffn=self._dropout_ffn,
        )

    @property
    def output_dim(self) -> int:
        """Per-branch output dimension (including optional unimodal SAT)."""
        n_sources = self.num_modalities if self._add_mms else self.num_modalities - 1
        dim = n_sources * self.d_model
        if self._add_unimodal_sat:
            dim += self.d_model
        return dim

    def forward(
        self,
        x_list: list[Tensor],
        mask_list: list[Tensor | None],
    ) -> list[Tensor]:
        """Run cross-modal attention for all target modalities.

        Args:
            x_list: One projected tensor per modality, each ``(B, T_i, d_model)``.
            mask_list: Boolean mask per modality, each ``(B, T_i)`` or ``None``.

        Returns:
            List of ``num_modalities`` branch representations, each
            ``(B, T_tgt, output_dim)``.
        """
        if self.mms is not None:
            x_list, mask_list = self.mms(x_list, mask_list)
        branch_reps = []
        for tgt in range(self.num_modalities):
            src_indices = [j for j in range(self.num_modalities) if j != tgt]
            if self.mms is not None:
                src_indices.append(self.num_modalities)
            branch_reps.append(
                self.branches[tgt](
                    x_query=x_list[tgt],
                    x_sources=[x_list[j] for j in src_indices],
                    mask_query=mask_list[tgt],
                    mask_sources=[mask_list[j] for j in src_indices],
                )
            )
        return branch_reps
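

# Usage sketch (illustrative, not part of the module): a three-modality
# CrossModalModule without MMS or unimodal SAT. Batch size, sequence lengths,
# and layer counts are assumptions for the example only; the constructor
# arguments match ``__init__`` above.
def _example_cross_modal_module() -> None:
    module = CrossModalModule(
        num_modalities=3,
        d_model=40,
        num_heads=8,
        branch_cmt_num_layers=2,
        branch_sat_num_layers=2,
    )
    x_list = [torch.randn(4, t, 40) for t in (300, 500, 50)]  # (B, T_i, d_model)
    mask_list: list[Tensor | None] = [None, None, None]       # no padding masks
    branch_reps = module(x_list, mask_list)
    # One representation per target modality, each (B, T_tgt, output_dim);
    # here output_dim = (num_modalities - 1) * d_model = 80.
    assert len(branch_reps) == 3
    assert all(rep.shape[-1] == module.output_dim for rep in branch_reps)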