linmult.core.branch
Branch modules for LinMulT.
MultimodalSignal: creates a shared multimodal token from all modality sequences
CrossModalBranch: processes one target modality through cross-modal attention from all sources
CrossModalModule: orchestrates all branches with optional MMS
Classes

| Class | Description |
| --- | --- |
| MultimodalSignal | Creates a shared multimodal signal from all modality sequences. |
| CrossModalBranch | Processes one target modality through cross-modal and self-attention. |
| CrossModalModule | Orchestrates cross-modal attention across all modalities. |
Module Contents
- class linmult.core.branch.MultimodalSignal(tam: linmult.core.temporal.TAM)[source]
Bases: torch.nn.Module
Creates a shared multimodal signal from all modality sequences.
Wraps a TAM that fuses all modality sequences into a single aligned representation. The output is appended to x_list / mask_list so that each branch can attend to it via its final cross-modal transformer.
- Parameters:
tam (TAM) – Temporal alignment module that receives all modality sequences.
- forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → tuple[list[torch.Tensor], list[torch.Tensor | None]][source]
Compute multimodal signal and append to input lists.
- Parameters:
x_list – One tensor per modality, each (B, T_i, d_model).
mask_list – Boolean mask per modality, each (B, T_i) or None.
- Returns:
Extended (x_list, mask_list) with the multimodal signal appended as the last element. The signal has shape (B, time_dim, tgt_dim).
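For example, with two modalities of shapes (B, 50, d_model) and (B, 75, d_model), the returned x_list holds three tensors: the two inputs followed by the (B, time_dim, tgt_dim) multimodal signal produced by the wrapped TAM, and mask_list is extended accordingly, so every CrossModalBranch can treat the signal as one additional source.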
- class linmult.core.branch.CrossModalBranch(cross_transformers: torch.nn.ModuleList, branch_sat: linmult.core.transformer.TransformerEncoder, unimodal_sat: linmult.core.transformer.TransformerEncoder | None = None)[source]
Bases: torch.nn.Module
Processes one target modality through cross-modal and self-attention.
- Forward pass:
1. Run one cross-modal TransformerEncoder per source (query=target, key/value=source).
2. Concatenate all cross-modal outputs → (B, T, branch_dim).
3. Apply branch self-attention (SAT).
4. Optionally concatenate with the unimodal SAT output.
- Parameters:
cross_transformers – One cross-modal TransformerEncoder per source modality (including the MMS token if enabled).
branch_sat – Self-attention TransformerEncoder applied over the concatenated cross-modal representation.
unimodal_sat – Optional self-attention TransformerEncoder applied to the original (projected) query sequence. Its output is concatenated with the SAT output before returning.
- forward(x_query: torch.Tensor, x_sources: list[torch.Tensor], mask_query: torch.Tensor | None = None, mask_sources: list[torch.Tensor | None] | None = None) → torch.Tensor[source]
Run cross-modal + self-attention for one target modality.
- Parameters:
x_query – Target modality tensor (B, T_q, d_model).
x_sources – Source modality tensors, each (B, T_s, d_model). Must be in the same order as cross_transformers.
mask_query – Boolean mask (B, T_q) for the query. None = no mask.
mask_sources – Boolean masks for each source, same length as x_sources. None entries are treated as no mask; omit the list to use no masks.
- Returns:
Branch representation (B, T_q, full_branch_dim), where full_branch_dim = len(x_sources) * d_model [+ d_model if unimodal_sat].
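To make the shape bookkeeping concrete, the following is a minimal sketch of the documented forward steps with identity stand-ins in place of the real TransformerEncoder modules; it only illustrates the concatenation arithmetic, not the actual attention computation.

```python
import torch

B, T_q, d_model = 2, 50, 64
x_query = torch.randn(B, T_q, d_model)
x_sources = [torch.randn(B, T_s, d_model) for T_s in (30, 80, 100)]

# Identity stand-ins for the cross-modal transformers, the branch SAT and the unimodal SAT;
# each cross-modal output keeps the query's time axis and the d_model feature size.
cross = lambda q, s: q
branch_sat = lambda h: h
unimodal_sat = lambda q: q

h = torch.cat([cross(x_query, s) for s in x_sources], dim=-1)  # (B, T_q, len(x_sources) * d_model)
h = branch_sat(h)                                              # branch self-attention keeps this shape
h = torch.cat([h, unimodal_sat(x_query)], dim=-1)              # unimodal SAT output adds another d_model
print(h.shape)  # torch.Size([2, 50, 256]) == (B, T_q, 3 * d_model + d_model)
```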
- class linmult.core.branch.CrossModalModule(num_modalities: int, d_model: int, num_heads: int = 8, branch_cmt_num_layers: int = 6, branch_sat_num_layers: int = 6, attention_config: linmult.core.attention.AttentionConfig | None = None, *, add_mms: bool = False, mms_num_layers: int = 6, tam_aligner: str = 'aap', tam_time_dim: int = 0, dropout_tam: float = 0.1, add_unimodal_sat: bool = False, unimodal_sat_num_layers: int = 6, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)[source]
Bases: torch.nn.Module
Orchestrates cross-modal attention across all modalities.
Builds and manages all cross-modal branches, including optional multimodal signal (MMS) generation. Each target modality gets its own CrossModalBranch with cross-modal transformers from every other source (and from MMS if enabled).
- Parameters:
num_modalities – Number of input modalities (>= 2).
d_model – Model dimension (shared across all transformers).
num_heads – Number of attention heads.
branch_cmt_num_layers – Depth of each cross-modal transformer.
branch_sat_num_layers – Depth of each branch self-attention transformer.
attention_config – Attention configuration for all internal transformers.
add_mms – Whether to create a multimodal signal via TAM.
mms_num_layers – Depth of the MMS transformer (only when add_mms=True).
tam_aligner – Temporal alignment method for MMS.
tam_time_dim – Target time dimension for MMS alignment.
dropout_tam – Dropout for MMS TAM.
add_unimodal_sat – Whether to add a unimodal self-attention per branch.
unimodal_sat_num_layers – Depth of unimodal self-attention.
dropout_pe – Positional encoding dropout.
dropout_ffn – FFN dropout (within transformers).
- property output_dim: int
Per-branch output dimension (including optional unimodal SAT).
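For instance, with num_modalities=3, d_model=64, add_mms=True and add_unimodal_sat=True, each branch runs a cross-modal transformer against the two other modalities plus the MMS and keeps its unimodal SAT output, so output_dim works out to (2 + 1) * 64 + 64 = 256 by the full_branch_dim formula documented for CrossModalBranch above; with both flags off it would be 2 * 64 = 128. (These numbers illustrate that formula rather than a separately documented guarantee.)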
- forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) → list[torch.Tensor][source]
Run cross-modal attention for all target modalities.
- Parameters:
x_list – One projected tensor per modality, each (B, T_i, d_model).
mask_list – Boolean mask per modality, each (B, T_i) or None.
- Returns:
List of num_modalities branch representations, each (B, T_tgt, output_dim).
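As a usage sketch (the hyperparameters and tensor sizes below are illustrative, not recommended settings), the module can be driven end to end like this:

```python
import torch
from linmult.core.branch import CrossModalModule

# Illustrative configuration: three modalities, a small model, MMS and unimodal SAT disabled.
module = CrossModalModule(
    num_modalities=3,
    d_model=64,
    num_heads=8,
    branch_cmt_num_layers=2,
    branch_sat_num_layers=2,
)

B = 4
x_list = [torch.randn(B, T, 64) for T in (50, 75, 100)]  # one already-projected sequence per modality
mask_list = [None, None, None]                           # no padding masks in this sketch

branches = module(x_list, mask_list)
print(len(branches))      # 3: one branch representation per target modality
print(branches[0].shape)  # (4, 50, module.output_dim), per the Returns description above
```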