linmult.core.projection

Projection module for LinMulT and LinT.

ProjectionModule handles per-modality feature projection with optional special handling (e.g. weighted-sum over transformer layers).

Classes

ProjectionModule

Projects each modality's features to a shared dimension.

Module Contents

class linmult.core.projection.ProjectionModule(input_feature_dims: list[int], d_model: int, dropout: float = 0.0, special_handling: dict[str, Any] | None = None)

Bases: torch.nn.Module

Projects each modality’s features to a shared dimension.

For each modality, applies a 1-D convolution with kernel size 1 that maps input_feature_dims[i] to d_model. Before projection, optional special handling (e.g. weighted-sum aggregation over N stacked transformer layers) can reduce a 4-D input (B, N, T, F) to a 3-D input (B, T, F).
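A kernel-size-1 convolution acts as the same linear map applied independently at every time step. The following is a minimal sketch of that idea in plain PyTorch, not the library's implementation. The sizes are arbitrary, and the transposes assume a (B, T, F) input that has to be moved to the channels-first layout nn.Conv1d expects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, F_in, d_model = 2, 5, 8, 4  # arbitrary illustrative sizes

x = torch.randn(B, T, F_in)
conv = nn.Conv1d(F_in, d_model, kernel_size=1)

# nn.Conv1d expects (B, channels, T), so swap the last two axes around the call.
y = conv(x.transpose(1, 2)).transpose(1, 2)  # (B, T, d_model)

# The same result from an explicit per-time-step linear map.
weight = conv.weight.squeeze(-1)  # (d_model, F_in)
y_linear = x @ weight.T + conv.bias
same = torch.allclose(y, y_linear, atol=1e-5)
```

Because the kernel covers a single time step, the projection changes only the feature dimension and leaves the sequence length T untouched.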

Parameters:
  • input_feature_dims (list[int]) – Feature dimension per modality.

  • d_model (int) – Target projection dimension.

  • dropout (float) – Dropout probability applied to inputs before projection.

  • special_handling (dict[str, Any], optional) – Dict mapping modality names to handling specs. Currently supports {"type": "weighted_sum", "start_layer": int, "end_layer": int}.
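To make the weighted-sum spec concrete, here is a hedged sketch of one common way to aggregate stacked layers: one learnable scalar per layer, normalized with softmax. The class name LayerWeightedSum, the modality name "wavlm", the softmax normalization, and the treatment of end_layer as exclusive are all assumptions for illustration. The library may implement any of these differently.

```python
import torch
import torch.nn as nn

# Spec in the documented format; the modality name "wavlm" is hypothetical.
special_handling = {
    "wavlm": {"type": "weighted_sum", "start_layer": 4, "end_layer": 12},
}


class LayerWeightedSum(nn.Module):
    """Hypothetical helper: reduce (B, N, T, F) to (B, T, F)."""

    def __init__(self, start_layer: int, end_layer: int):
        super().__init__()
        self.start_layer = start_layer
        self.end_layer = end_layer  # assumed exclusive in this sketch
        # One learnable scalar per selected layer, initialized to equal weights.
        self.weights = nn.Parameter(torch.zeros(end_layer - start_layer))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        layers = x[:, self.start_layer:self.end_layer]  # (B, n, T, F)
        w = torch.softmax(self.weights, dim=0).view(1, -1, 1, 1)
        return (layers * w).sum(dim=1)  # (B, T, F)


spec = special_handling["wavlm"]
agg = LayerWeightedSum(spec["start_layer"], spec["end_layer"])
x = torch.randn(2, 12, 5, 8)  # (B, N, T, F)
out = agg(x)
```

With the zero initialization used here, the softmax weights start out uniform, so the initial output is the plain mean of the selected layers.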


forward(x_list: list[torch.Tensor], names: list[str] | None = None) → list[torch.Tensor]

Project each modality to d_model.

Parameters:
  • x_list – One tensor per modality, each of shape (B, T, F), or (B, N, T, F) for weighted-sum inputs.

  • names – Optional modality names for special handling lookup.

Returns:

List of projected tensors, each of shape (B, T, d_model).
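The shape contract of forward can be illustrated with a small stand-in. TinyProjection below is a hypothetical class that only mirrors the documented inputs and outputs for 3-D tensors. It is not ProjectionModule itself and it omits special handling. The feature sizes and sequence lengths are made up.

```python
import torch
import torch.nn as nn


class TinyProjection(nn.Module):
    """Hypothetical stand-in mirroring the documented forward() contract."""

    def __init__(self, input_feature_dims: list[int], d_model: int, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # One kernel-size-1 convolution per modality.
        self.projs = nn.ModuleList(
            [nn.Conv1d(f, d_model, kernel_size=1) for f in input_feature_dims]
        )

    def forward(self, x_list: list[torch.Tensor]) -> list[torch.Tensor]:
        outputs = []
        for x, proj in zip(x_list, self.projs):
            x = self.dropout(x)  # dropout on inputs, before projection
            # (B, T, F) -> (B, F, T) for Conv1d, then back to (B, T, d_model).
            outputs.append(proj(x.transpose(1, 2)).transpose(1, 2))
        return outputs


model = TinyProjection([40, 768], d_model=32)
x_list = [torch.randn(2, 10, 40), torch.randn(2, 25, 768)]
outputs = model(x_list)
shapes = [tuple(o.shape) for o in outputs]
```

Each modality keeps its own sequence length T. Only the feature dimension is mapped to the shared d_model.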