linmult.core.branch
===================

.. py:module:: linmult.core.branch

.. autoapi-nested-parse::

   Branch modules for LinMulT.

   - MultimodalSignal: creates a shared multimodal token from all modality sequences
   - CrossModalBranch: processes one target modality through cross-modal attention from all sources
   - CrossModalModule: orchestrates all branches with optional MMS


Classes
-------

.. autoapisummary::

   linmult.core.branch.MultimodalSignal
   linmult.core.branch.CrossModalBranch
   linmult.core.branch.CrossModalModule


Module Contents
---------------

.. py:class:: MultimodalSignal(tam: linmult.core.temporal.TAM)

   Bases: :py:obj:`torch.nn.Module`

   Creates a shared multimodal signal from all modality sequences.

   Wraps a TAM that fuses all modality sequences into a single aligned
   representation. The output is appended to ``x_list`` / ``mask_list`` so that
   each branch can attend to it via its final cross-modal transformer.

   :param tam: Temporal alignment module that receives all modality sequences.
   :type tam: TAM

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:method:: forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) -> tuple[list[torch.Tensor], list[torch.Tensor | None]]

      Compute multimodal signal and append to input lists.

      :param x_list: One tensor per modality, each ``(B, T_i, d_model)``.
      :param mask_list: Boolean mask per modality, each ``(B, T_i)`` or ``None``.
      :returns: Extended ``(x_list, mask_list)`` with the multimodal signal appended
          as the last element. The signal has shape ``(B, time_dim, tgt_dim)``.


.. py:class:: CrossModalBranch(cross_transformers: torch.nn.ModuleList, branch_sat: linmult.core.transformer.TransformerEncoder, unimodal_sat: linmult.core.transformer.TransformerEncoder | None = None)

   Bases: :py:obj:`torch.nn.Module`

   Processes one target modality through cross-modal and self-attention.

   Forward pass:

   1. Run one cross-modal TransformerEncoder per source (query=target, key/value=source).
   2. Concatenate all cross-modal outputs → ``(B, T, branch_dim)``.
   3. Apply branch self-attention (SAT).
   4. Optionally concatenate with unimodal SAT output.

   :param cross_transformers: One cross-modal TransformerEncoder per source modality
       (including the MMS token if enabled).
   :param branch_sat: Self-attention TransformerEncoder applied over the concatenated
       cross-modal representation.
   :param unimodal_sat: Optional self-attention TransformerEncoder applied to the
       original (projected) query sequence. Its output is concatenated with the SAT
       output before returning.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:method:: forward(x_query: torch.Tensor, x_sources: list[torch.Tensor], mask_query: torch.Tensor | None = None, mask_sources: list[torch.Tensor | None] | None = None) -> torch.Tensor

      Run cross-modal + self-attention for one target modality.

      :param x_query: Target modality tensor ``(B, T_q, d_model)``.
      :param x_sources: Source modality tensors, each ``(B, T_s, d_model)``. Must be
          in the same order as ``cross_transformers``.
      :param mask_query: Boolean mask ``(B, T_q)`` for the query. ``None`` = no mask.
      :param mask_sources: Boolean masks for each source, same length as ``x_sources``.
          ``None`` entries treated as no mask; omit the list to use no masks.
      :returns: Branch representation ``(B, T_q, full_branch_dim)`` where
          ``full_branch_dim = len(x_sources) * d_model [+ d_model if unimodal_sat]``.
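
   The snippet below is a rough shape-flow sketch of the four forward-pass steps
   listed above. It uses plain ``torch.nn.MultiheadAttention`` layers as hypothetical
   stand-ins for the library's ``TransformerEncoder`` blocks, so it only illustrates
   the documented dataflow and output shapes, not the actual implementation.

   .. code-block:: python

      import torch
      import torch.nn as nn

      B, T_q, T_s, d_model, num_sources = 2, 10, 15, 64, 2

      x_query = torch.randn(B, T_q, d_model)  # target modality
      x_sources = [torch.randn(B, T_s, d_model) for _ in range(num_sources)]

      # Stand-ins for cross_transformers and branch_sat (for shape illustration only).
      cross_attn = nn.ModuleList(
          nn.MultiheadAttention(d_model, 8, batch_first=True) for _ in range(num_sources)
      )
      branch_sat = nn.MultiheadAttention(num_sources * d_model, 8, batch_first=True)

      # 1. One cross-modal block per source: query=target, key/value=source.
      cross_outs = [attn(x_query, src, src)[0] for attn, src in zip(cross_attn, x_sources)]

      # 2. Concatenate along the feature dimension -> (B, T_q, num_sources * d_model).
      branch = torch.cat(cross_outs, dim=-1)

      # 3. Branch self-attention over the concatenated representation.
      branch, _ = branch_sat(branch, branch, branch)

      print(branch.shape)  # torch.Size([2, 10, 128])

   Step 4, when ``unimodal_sat`` is given, would additionally run self-attention over
   ``x_query`` and concatenate the result, adding another ``d_model`` features.
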
.. py:class:: CrossModalModule(num_modalities: int, d_model: int, num_heads: int = 8, branch_cmt_num_layers: int = 6, branch_sat_num_layers: int = 6, attention_config: linmult.core.attention.AttentionConfig | None = None, *, add_mms: bool = False, mms_num_layers: int = 6, tam_aligner: str = 'aap', tam_time_dim: int = 0, dropout_tam: float = 0.1, add_unimodal_sat: bool = False, unimodal_sat_num_layers: int = 6, dropout_pe: float = 0.0, dropout_ffn: float = 0.1)

   Bases: :py:obj:`torch.nn.Module`

   Orchestrates cross-modal attention across all modalities.

   Builds and manages all cross-modal branches, including optional multimodal
   signal (MMS) generation. Each target modality gets its own
   :class:`CrossModalBranch` with cross-modal transformers from every other
   source (and from MMS if enabled).

   :param num_modalities: Number of input modalities (>= 2).
   :param d_model: Model dimension (shared across all transformers).
   :param num_heads: Number of attention heads.
   :param branch_cmt_num_layers: Depth of each cross-modal transformer.
   :param branch_sat_num_layers: Depth of each branch self-attention transformer.
   :param attention_config: Attention configuration for all internal transformers.
   :param add_mms: Whether to create a multimodal signal via TAM.
   :param mms_num_layers: Depth of the MMS transformer (only when ``add_mms=True``).
   :param tam_aligner: Temporal alignment method for MMS.
   :param tam_time_dim: Target time dimension for MMS alignment.
   :param dropout_tam: Dropout for MMS TAM.
   :param add_unimodal_sat: Whether to add a unimodal self-attention per branch.
   :param unimodal_sat_num_layers: Depth of unimodal self-attention.
   :param dropout_pe: Positional encoding dropout.
   :param dropout_ffn: FFN dropout (within transformers).

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:property:: output_dim
      :type: int

      Per-branch output dimension (including optional unimodal SAT).

   .. py:method:: forward(x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]) -> list[torch.Tensor]

      Run cross-modal attention for all target modalities.

      :param x_list: One projected tensor per modality, each ``(B, T_i, d_model)``.
      :param mask_list: Boolean mask per modality, each ``(B, T_i)`` or ``None``.
      :returns: List of ``num_modalities`` branch representations, each
          ``(B, T_tgt, output_dim)``.
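
   A minimal usage sketch based only on the constructor and ``forward`` signatures
   documented on this page; the batch size, sequence lengths, and use of ``None``
   masks are illustrative assumptions.

   .. code-block:: python

      import torch
      from linmult.core.branch import CrossModalModule

      # Three modalities already projected to a shared d_model; lengths may differ.
      B, d_model = 4, 64
      lengths = [50, 30, 20]

      module = CrossModalModule(num_modalities=3, d_model=d_model, num_heads=8)

      x_list = [torch.randn(B, T, d_model) for T in lengths]
      mask_list = [None, None, None]  # no padding masks in this sketch

      branch_outputs = module(x_list, mask_list)

      # One branch representation per target modality, last dim = module.output_dim.
      for out in branch_outputs:
          print(out.shape)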