Source code for linmult.core.temporal

"""Temporal reducers, aligners, and composite modules (TRM, TAM)."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from linmult.core.attention import AttentionConfig
from linmult.core.transformer import TransformerEncoder


class TemporalFactory:
    """Build temporal aligner and reducer modules from their string identifiers."""

    @staticmethod
    def create_aligner(method: str = "aap") -> nn.Module:
        """Instantiate the temporal aligner registered under ``method``.

        Args:
            method (str): Aligner identifier. One of:

                - ``"aap"``: :class:`AdaptiveAvgPooling`.
                - ``"amp"``: :class:`AdaptiveMaxPooling`.
                - ``"padding"``: :class:`TemporalPadding` (zero-pad / truncate).

        Returns:
            nn.Module: A freshly constructed aligner.

        Raises:
            ValueError: If ``method`` is not a registered identifier.
        """
        aligners = {
            "aap": AdaptiveAvgPooling,
            "amp": AdaptiveMaxPooling,
            "padding": TemporalPadding,
        }
        aligner_cls = aligners.get(method)
        if aligner_cls is None:  # pragma: no cover
            raise ValueError(f"Unknown aligner: {method!r}")
        return aligner_cls()

    @staticmethod
    def create_reducer(d_model: int, reducer: str) -> nn.Module:
        """Instantiate the temporal reducer registered under ``reducer``.

        Args:
            d_model (int): Feature dimensionality of the input. Only
                :class:`AttentionPooling` consumes it; the parameter-free
                reducers ignore it.
            reducer (str): Reducer identifier. One of ``"attentionpool"``,
                ``"gmp"``, ``"gap"``, ``"last"``.

        Returns:
            nn.Module: A freshly constructed reducer.

        Raises:
            ValueError: If ``reducer`` is not a registered identifier.
        """
        # AttentionPooling is the only reducer whose constructor takes an
        # argument, so it is handled before the zero-argument lookup table.
        if reducer == "attentionpool":
            return AttentionPooling(d_model)

        parameter_free = {
            "gmp": GlobalMaxPooling,
            "gap": GlobalAvgPooling,
            "last": LastTimestamp,
        }
        reducer_cls = parameter_free.get(reducer)
        if reducer_cls is None:  # pragma: no cover
            raise ValueError(f"Unknown reducer: {reducer!r}")
        return reducer_cls()
[docs] class TemporalPadding(nn.Module): """Temporal aligner via truncation or zero-padding. Adjusts the time dimension of a tensor to exactly ``time_dim`` by truncating if the sequence is too long, or zero-padding if too short. The mask is updated accordingly (padded positions are ``False``). """
[docs] def forward( self, x: torch.Tensor, time_dim: int, mask: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Truncate or pad the time dimension. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. time_dim (int): Target time dimension. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. If ``None``, all input positions are treated as valid. Returns: tuple[torch.Tensor, torch.Tensor]: Output tensor of shape ``(B, time_dim, F)`` and updated mask of shape ``(B, time_dim)``. """ current_time_dim = x.size(1) if current_time_dim > time_dim: x_new = x[:, :time_dim, :] if mask is not None: mask_new = mask[:, :time_dim] else: mask_new = torch.ones((x.size(0), time_dim), dtype=torch.bool, device=x.device) elif current_time_dim < time_dim: pad_size = time_dim - current_time_dim x_new = F.pad(x, (0, 0, 0, pad_size)) # pad along the time dimension if mask is not None: mask_new = F.pad(mask, (0, pad_size), value=False) else: mask_new = torch.ones( (x.size(0), current_time_dim), dtype=torch.bool, device=x.device ) mask_new = F.pad(mask_new, (0, pad_size), value=False) else: x_new = x mask_new = ( mask if mask is not None else torch.ones((x.size(0), current_time_dim), dtype=torch.bool, device=x.device) ) return x_new, mask_new
[docs] class AdaptiveMaxPooling(nn.Module): """Temporal aligner via adaptive max pooling. Resizes the time dimension of a tensor to ``time_dim`` using ``F.adaptive_max_pool1d``. Masked (padded) positions are filled with ``-inf`` before pooling so they never win the max, and the output mask is derived from the result. """
[docs] def forward( self, x: torch.Tensor, time_dim: int, mask: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Apply adaptive max pooling along the time dimension. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. time_dim (int): Target time dimension. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. Returns: tuple[torch.Tensor, torch.Tensor]: Pooled tensor of shape ``(B, time_dim, F)`` and updated mask of shape ``(B, time_dim)``. """ if mask is not None: expanded_mask = mask.unsqueeze(-1).expand_as(x) # (B, T, F) x = x.masked_fill(~expanded_mask, float("-inf")) x_new = F.adaptive_max_pool1d(x.transpose(1, 2), time_dim).transpose(1, 2) if mask is not None: mask_new = (x_new != float("-inf")).any(dim=-1) # (B, time_dim) # masked_fill (non-in-place) avoids breaking the autograd graph x_new = x_new.masked_fill(x_new == float("-inf"), 0.0) else: mask_new = torch.ones( x_new.size(0), x_new.size(1), dtype=torch.bool, device=x_new.device ) return x_new, mask_new
[docs] class AdaptiveAvgPooling(nn.Module): """Temporal aligner via adaptive average pooling. Resizes the time dimension of a tensor to ``time_dim`` using ``F.adaptive_avg_pool1d``. Masked positions contribute zero to the average and the output is renormalized by the fraction of valid input positions in each output bin. """
[docs] def forward( self, x: torch.Tensor, time_dim: int, mask: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Apply adaptive average pooling along the time dimension. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. time_dim (int): Target time dimension. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. Returns: tuple[torch.Tensor, torch.Tensor]: Pooled tensor of shape ``(B, time_dim, F)`` and updated mask of shape ``(B, time_dim)``. """ if mask is not None: expanded_mask = mask.unsqueeze(-1).expand_as(x) # (B, T, F) x = x.masked_fill(~expanded_mask, 0.0) pooled_mask = F.adaptive_avg_pool1d(mask.float().unsqueeze(1), time_dim).squeeze( 1 ) # (B, time_dim) mask_new = pooled_mask > 1e-8 x_pooled = F.adaptive_avg_pool1d(x.transpose(1, 2), time_dim).transpose( 1, 2 ) # (B, time_dim, F) # Renormalize to account for masked regions. # clamp(min=1e-8) avoids division by zero for fully-masked bins. x_new = x_pooled / pooled_mask.unsqueeze(-1).clamp(min=1e-8) # (B, time_dim, F) else: x_new = F.adaptive_avg_pool1d(x.transpose(1, 2), time_dim).transpose( 1, 2 ) # (B, time_dim, F) mask_new = torch.ones(x.size(0), time_dim, dtype=torch.bool, device=x.device) return x_new, mask_new
[docs] class LastTimestamp(nn.Module): """Temporal reducer that extracts the last valid timestamp. If a mask is provided, selects the feature vector at the last ``True`` position for each sample. Fully-masked samples (all ``False``) return a zero vector. """
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Extract the last valid timestamp. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. If ``None``, the final timestep is selected for all samples. Returns: torch.Tensor: Output tensor of shape ``(B, F)``. """ if mask is not None: valid_counts = mask.sum(dim=1).long() # (B,) last_timestamp_index = (valid_counts - 1).clamp(min=0) else: last_timestamp_index = torch.full( (x.size(0),), x.size(1) - 1, dtype=torch.long, device=x.device ) batch_indices = torch.arange(x.size(0), device=x.device) result = x[batch_indices, last_timestamp_index] # (B, F) if mask is not None: fully_masked = valid_counts == 0 # (B,) if fully_masked.any(): result = result * (~fully_masked).float().unsqueeze(-1) return result # (B, F)
[docs] class GlobalAvgPooling(nn.Module): """Temporal reducer via masked global average pooling. Computes the mean over valid (unmasked) timesteps. If no mask is provided, averages over all timesteps. """
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Apply global average pooling. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. If ``None``, all positions are treated as valid. Returns: torch.Tensor: Pooled output of shape ``(B, F)``. """ if mask is not None: expanded_mask = mask.unsqueeze(-1) # (B, T, 1) x = x * expanded_mask sum_x = x.sum(dim=1) # (B, F) count_x = expanded_mask.sum(dim=1).clamp(min=1) # (B, 1) return sum_x / count_x # (B, F) else: return x.mean(dim=1) # (B, F)
[docs] class GlobalMaxPooling(nn.Module): """Temporal reducer via masked global max pooling. Computes the max over valid (unmasked) timesteps. Masked positions are filled with ``-inf`` before the max, and fully-masked samples return zero. """
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Apply global max pooling. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. If ``None``, all positions are treated as valid. Returns: torch.Tensor: Pooled output of shape ``(B, F)``. """ fully_masked = None if mask is not None: fully_masked = ~mask.any(dim=1) # (B,) — fully masked samples x = x.masked_fill(~mask.unsqueeze(-1), float("-inf")) result = x.max(dim=1)[0] # (B, F) if fully_masked is not None and fully_masked.any(): result = result.masked_fill(fully_masked.unsqueeze(-1), 0.0) return result # (B, F)
[docs] class AttentionPooling(nn.Module): """Temporal reducer via learned attention-weighted pooling. Learns a scalar attention score per timestep and computes a weighted sum of the input features. Masked positions receive ``-inf`` before the softmax so their weight is zero. Fully-masked samples return a zero vector. Args: d_model (int): Input feature dimensionality. """ def __init__(self, d_model: int): """Initialize AttentionPooling.""" super().__init__() self.attention = nn.Linear(d_model, 1)
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Apply attention-weighted pooling. Args: x (torch.Tensor): Input tensor of shape ``(B, T, F)``. mask (torch.BoolTensor, optional): Validity mask of shape ``(B, T)``. True = valid. Returns: torch.Tensor: Pooled output of shape ``(B, F)``. """ attn_weights = self.attention(x).squeeze(-1) # (B, T) all_masked = None if mask is not None: all_masked = (~mask).all(dim=1) # (B,) — fully masked samples attn_weights = attn_weights.masked_fill(~mask, float("-inf")) if all_masked.any(): # Replace -inf with 0 so softmax gives uniform weights instead of NaN safe = torch.zeros_like(attn_weights) attn_weights = torch.where(all_masked.unsqueeze(1), safe, attn_weights) attn_weights = torch.softmax(attn_weights, dim=-1) # (B, T) weighted_avg = torch.sum(x * attn_weights.unsqueeze(-1), dim=1) # (B, F) # Zero out fully-masked samples (uniform weights would give non-zero average) if all_masked is not None and all_masked.any(): weighted_avg = weighted_avg * (~all_masked).float().unsqueeze(-1) return weighted_avg # (B, F)
[docs] class TRM(nn.Module): """Time Reduce Module: aggregates the time dimension of a sequence tensor. Transforms ``(B, T, F)`` → ``(B, F)`` using a configurable pooling strategy. Args: d_model (int): Input feature dimensionality. Required for ``"attentionpool"``; ignored by ``"gap"``, ``"gmp"``, and ``"last"``. reducer (str): Pooling strategy. One of ``"attentionpool"``, ``"gmp"``, ``"gap"``, ``"last"``. """ def __init__(self, d_model: int, reducer: str): """Initialize TRM.""" super().__init__() if reducer not in {"attentionpool", "gmp", "gap", "last"}: raise ValueError(f"Invalid reducer: {reducer!r}") self.time_dim_reducer = TemporalFactory.create_reducer(d_model=d_model, reducer=reducer)
[docs] def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Reduce the time dimension. Args: x (torch.Tensor): Input of shape ``(B, T, F)``. mask (torch.Tensor, optional): Validity mask of shape ``(B, T)``. True = valid. Returns: torch.Tensor: Reduced output of shape ``(B, F)``. """ return self.time_dim_reducer(x, mask) # (B, T, F) -> (B, F)
[docs] def apply_to_list( self, x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None] ) -> list[torch.Tensor]: """Apply time reduction independently to each tensor in a list. Args: x_list (list[torch.Tensor]): List of tensors, each of shape ``(B, T, F)``. mask_list (list[torch.Tensor | None]): Corresponding masks, each of shape ``(B, T)``. Returns: list[torch.Tensor]: List of reduced tensors, each of shape ``(B, F)``. """ return [self(x, mask) for x, mask in zip(x_list, mask_list)]
class TAM(nn.Module):
    """Time Align Module: bring several sequences onto one shared time axis.

    Takes a list of ``(B, T_i, F_i)`` tensors, pools or pads each of them to a
    common ``time_dim``, concatenates the results along the feature axis, runs
    the concatenation through a transformer encoder and finally projects it to
    ``output_dim``. The result has shape ``(B, time_dim, output_dim)``.

    Args:
        input_dim (int): Feature size after concatenation (sum of the feature
            sizes of all inputs).
        output_dim (int): Feature size produced by the output projection.
        aligner (str): Temporal alignment strategy. One of ``"aap"``,
            ``"amp"``, ``"padding"``.
        time_dim (int): Target length of the time axis after alignment.
        num_layers (int): Depth of the internal transformer encoder.
            Defaults to ``6``.
        num_heads (int): Number of attention heads of the internal encoder.
            Defaults to ``8``.
        attention_config (AttentionConfig, optional): Attention settings for
            the internal encoder. The value, including ``None``, is forwarded
            unchanged to :class:`TransformerEncoder`, which is expected to
            pick its own default.
        dropout_pe (float): Positional-encoding dropout of the internal
            encoder. Defaults to ``0.0``.
        dropout_ffn (float): Feed-forward dropout of the internal encoder.
            Defaults to ``0.1``.
        dropout_out (float): Dropout inside the output projector.
            Defaults to ``0.1``.
        name (str): Module name shown in ``repr``. Defaults to ``""``.

    Raises:
        ValueError: If ``aligner`` is not one of the supported strategies.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        aligner: str,
        time_dim: int,
        num_layers: int = 6,
        num_heads: int = 8,
        attention_config: AttentionConfig | None = None,
        dropout_pe: float = 0.0,
        dropout_ffn: float = 0.1,
        dropout_out: float = 0.1,
        name: str = "",
    ):
        """Validate ``aligner`` and build aligner, encoder and projector."""
        super().__init__()
        supported_aligners = {"aap", "amp", "padding"}
        if aligner not in supported_aligners:
            raise ValueError(f"Invalid aligner: {aligner!r}")

        self.name = name
        self.aligned_time_dim = time_dim
        self.time_dim_aligner = TemporalFactory.create_aligner(aligner)

        self.transformer = TransformerEncoder(
            d_model=input_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            attention_config=attention_config,
            dropout_pe=dropout_pe,
            dropout_ffn=dropout_ffn,
        )

        # Point-wise (kernel size 1) convolutions act as a per-timestep MLP on
        # channel-first input: input_dim -> output_dim -> output_dim.
        projection_layers = [
            nn.Conv1d(input_dim, output_dim, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm1d(output_dim),
            nn.GELU(),
            nn.Dropout(dropout_out),
            nn.Conv1d(output_dim, output_dim, kernel_size=1, padding=0, bias=True),
        ]
        self.projector = nn.Sequential(*projection_layers)

    def extra_repr(self) -> str:
        """Return the module name so it shows up in the ``repr`` output."""
        return f"name={self.name!r}"  # pragma: no cover

    def forward(
        self, x_list: list[torch.Tensor], mask_list: list[torch.Tensor | None]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Align, fuse and project several sequences.

        Args:
            x_list (list[torch.Tensor]): Input tensors, each of shape
                ``(B, T_i, F_i)``.
            mask_list (list[torch.BoolTensor | None]): Matching masks, each of
                shape ``(B, T_i)`` or ``None`` (treated as all-valid).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Fused tensor of shape
            ``(B, time_dim, output_dim)`` and joint validity mask of shape
            ``(B, time_dim)``.
        """
        # Step 1: resize every input (and its mask) to the shared time length.
        aligned_x = []
        aligned_masks = []
        for x, mask in zip(x_list, mask_list):
            x_aligned, mask_aligned = self.time_dim_aligner(x, self.aligned_time_dim, mask)
            aligned_x.append(x_aligned)
            aligned_masks.append(mask_aligned)

        # Step 2: a timestep counts as valid only if it is valid in every input.
        joint_mask = torch.stack(aligned_masks, dim=0).all(dim=0)  # (B, time_dim)

        # Step 3: fuse along the feature axis and encode.
        fused = torch.cat(aligned_x, dim=-1)  # (B, time_dim, input_dim)
        fused = self.transformer(fused, query_mask=joint_mask)  # (B, time_dim, input_dim)

        # Step 4: Conv1d/BatchNorm1d expect channel-first input, so swap the
        # axes around the projector and back again.
        projected = self.projector(fused.transpose(1, 2)).transpose(1, 2)
        return projected, joint_mask  # (B, time_dim, output_dim), (B, time_dim)