linmult.core.branch

Branch modules for LinMulT.

  • MultimodalSignal: creates a shared multimodal token from all modality sequences

  • CrossModalBranch: processes one target modality through cross-modal attention from all sources

  • CrossModalModule: orchestrates all branches with optional MMS

Classes

MultimodalSignal

Creates a shared multimodal signal from all modality sequences.

CrossModalBranch

Processes one target modality through cross-modal and self-attention.

CrossModalModule

Orchestrates cross-modal attention across all modalities.

Module Contents

class linmult.core.branch.MultimodalSignal(tam: linmult.core.temporal.TAM)[source]

Bases: torch.nn.Module

Creates a shared multimodal signal from all modality sequences.

Wraps a TAM that fuses all modality sequences into a single aligned representation. The output is appended to x_list / mask_list so that each branch can attend to it via its final cross-modal transformer.

Parameters:

tam (TAM) – Temporal alignment module that receives all modality sequences.

forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → tuple[list[torch.Tensor], list[torch.Tensor | None]][source]

Compute multimodal signal and append to input lists.

Parameters:
  • x_list – One tensor per modality, each (B, T_i, d_model).

  • mask_list – Boolean mask per modality, each (B, T_i) or None.

Returns:

Extended (x_list, mask_list) with the multimodal signal appended as the last element. The signal has shape (B, time_dim, tgt_dim).
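The append semantics can be sketched with plain torch, using adaptive average pooling as a stand-in for the wrapped TAM (the real module delegates fusion to its `tam`; the pooling here only illustrates the `'aap'`-style alignment and the list bookkeeping):

```python
import torch

B, d_model, time_dim = 2, 16, 8
x_list = [torch.randn(B, t, d_model) for t in (5, 7, 9)]  # three modalities
mask_list = [None, None, None]

# Stand-in for the TAM: align each modality to time_dim steps, then average.
aligned = [
    torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), time_dim).transpose(1, 2)
    for x in x_list
]
mms = torch.stack(aligned).mean(dim=0)  # (B, time_dim, d_model)

# The multimodal signal is appended as the last element of both lists.
x_list = x_list + [mms]
mask_list = mask_list + [None]
print(len(x_list), mms.shape)  # 4 torch.Size([2, 8, 16])
```

Each branch can then attend to this extra entry exactly like any other source modality.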

class linmult.core.branch.CrossModalBranch(cross_transformers: torch.nn.ModuleList, branch_sat: linmult.core.transformer.TransformerEncoder, unimodal_sat: linmult.core.transformer.TransformerEncoder | None = None)[source]

Bases: torch.nn.Module

Processes one target modality through cross-modal and self-attention.

Forward pass:
  1. Run one cross-modal TransformerEncoder per source (query=target, key/value=source).

  2. Concatenate all cross-modal outputs → (B, T, branch_dim).

  3. Apply branch self-attention (SAT).

  4. Optionally concatenate with unimodal SAT output.
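The four steps above can be sketched with stock torch modules, using `nn.MultiheadAttention` as a stand-in for the cross-modal and self-attention TransformerEncoders (names and layer choices here are illustrative, not the LinMulT API):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_q, T_s, d_model, num_sources = 2, 5, 7, 16, 3

x_query = torch.randn(B, T_q, d_model)
x_sources = [torch.randn(B, T_s, d_model) for _ in range(num_sources)]

# Step 1: one cross-modal attention per source (query=target, key/value=source).
cross = nn.ModuleList(
    nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
    for _ in range(num_sources)
)
cross_outs = [attn(x_query, x_s, x_s)[0] for attn, x_s in zip(cross, x_sources)]

# Step 2: concatenate along the feature axis -> (B, T_q, num_sources * d_model).
branch = torch.cat(cross_outs, dim=-1)

# Step 3: branch self-attention (SAT) over the concatenated representation.
sat = nn.MultiheadAttention(num_sources * d_model, num_heads=4, batch_first=True)
branch = sat(branch, branch, branch)[0]

# Step 4 (optional): concatenate a unimodal SAT output over the query itself.
uni = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
out = torch.cat([branch, uni(x_query, x_query, x_query)[0]], dim=-1)
print(out.shape)  # (B, T_q, (num_sources + 1) * d_model) = (2, 5, 64)
```

Note that the output always keeps the query's time axis `T_q`; the sources only contribute keys and values.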

Parameters:
  • cross_transformers – One cross-modal TransformerEncoder per source modality (including the MMS token if enabled).

  • branch_sat – Self-attention TransformerEncoder applied over the concatenated cross-modal representation.

  • unimodal_sat – Optional self-attention TransformerEncoder applied to the original (projected) query sequence. Its output is concatenated with the SAT output before returning.

forward(x_query: torch.Tensor, x_sources: list[torch.Tensor], mask_query: torch.Tensor | None = None, mask_sources: list[torch.Tensor | None] | None = None) → torch.Tensor[source]

Run cross-modal + self-attention for one target modality.

Parameters:
  • x_query – Target modality tensor (B, T_q, d_model).

  • x_sources – Source modality tensors, each (B, T_s, d_model). Must be in the same order as cross_transformers.

  • mask_query – Boolean mask (B, T_q) for the query. None = no mask.

  • mask_sources – Boolean masks for each source, same length as x_sources. None entries are treated as no mask; omit the list entirely to use no masks.

Returns:

Branch representation (B, T_q, full_branch_dim) where full_branch_dim = len(x_sources) * d_model [+ d_model if unimodal_sat].
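As a quick check of the output width, the formula above is pure arithmetic:

```python
d_model = 40
num_sources = 3  # length of x_sources (including the MMS entry if enabled)

full_branch_dim = num_sources * d_model          # without unimodal_sat
full_branch_dim_uni = full_branch_dim + d_model  # with unimodal_sat
print(full_branch_dim, full_branch_dim_uni)  # 120 160
```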

class linmult.core.branch.CrossModalModule(num_modalities: int, d_model: int, num_heads: int = 8, branch_cmt_num_layers: int = 6, branch_sat_num_layers: int = 6, attention_config: linmult.core.attention.AttentionConfig | None = None, *, add_mms: bool = False, mms_num_layers: int = 6, tam_aligner: str = 'aap', tam_time_dim: int = 0, dropout_tam: float = 0.1, add_unimodal_sat: bool = False, unimodal_sat_num_layers: int = 6, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)[source]

Bases: torch.nn.Module

Orchestrates cross-modal attention across all modalities.

Builds and manages all cross-modal branches, including optional multimodal signal (MMS) generation. Each target modality gets its own CrossModalBranch with cross-modal transformers from every other source (and from MMS if enabled).

Parameters:
  • num_modalities – Number of input modalities (>= 2).

  • d_model – Model dimension (shared across all transformers).

  • num_heads – Number of attention heads.

  • branch_cmt_num_layers – Depth of each cross-modal transformer.

  • branch_sat_num_layers – Depth of each branch self-attention transformer.

  • attention_config – Attention configuration for all internal transformers.

  • add_mms – Whether to create a multimodal signal via TAM.

  • mms_num_layers – Depth of the MMS transformer (only when add_mms=True).

  • tam_aligner – Temporal alignment method for MMS.

  • tam_time_dim – Target time dimension for MMS alignment.

  • dropout_tam – Dropout for MMS TAM.

  • add_unimodal_sat – Whether to add a unimodal self-attention per branch.

  • unimodal_sat_num_layers – Depth of unimodal self-attention.

  • dropout_pe – Positional encoding dropout.

  • dropout_ffn – FFN dropout (within transformers).

property output_dim: int

Per-branch output dimension (including optional unimodal SAT).

forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → list[torch.Tensor][source]

Run cross-modal attention for all target modalities.

Parameters:
  • x_list – One projected tensor per modality, each (B, T_i, d_model).

  • mask_list – Boolean mask per modality, each (B, T_i) or None.

Returns:

List of num_modalities branch representations, each (B, T_tgt, output_dim).
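Putting the constructor flags together, each branch sees one cross-modal transformer per other modality, plus one for the MMS signal when add_mms=True, plus an optional unimodal SAT. A sketch of that bookkeeping in plain Python (a hypothetical helper, assuming the full_branch_dim formula documented for CrossModalBranch):

```python
def branch_output_dim(num_modalities: int, d_model: int, *,
                      add_mms: bool = False,
                      add_unimodal_sat: bool = False) -> int:
    """Per-branch output width: one cross-modal block per other modality
    (plus one for the MMS signal if enabled), plus an optional unimodal SAT."""
    num_sources = (num_modalities - 1) + (1 if add_mms else 0)
    return num_sources * d_model + (d_model if add_unimodal_sat else 0)

print(branch_output_dim(3, 40))                                       # 80
print(branch_output_dim(3, 40, add_mms=True))                         # 120
print(branch_output_dim(3, 40, add_mms=True, add_unimodal_sat=True))  # 160
```

This should mirror the module's output_dim property for the corresponding flag combination.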