linmult.core.projection
Projection module for LinMulT and LinT.
ProjectionModule handles per-modality feature projection with optional special handling (e.g. weighted-sum over transformer layers).
Classes
ProjectionModule: Projects each modality's features to a shared dimension.
Module Contents
- class linmult.core.projection.ProjectionModule(input_feature_dims: list[int], d_model: int, dropout: float = 0.0, special_handling: dict[str, Any] | None = None)
Bases: torch.nn.Module
Projects each modality’s features to a shared dimension.
For each modality, applies a 1-D convolution (kernel=1) that maps input_feature_dims[i] to d_model. Before projection, optional special handling (e.g. weighted-sum aggregation over stacked transformer layers) can reduce a 4-D input (B, N, T, F) to (B, T, F).
- Parameters:
input_feature_dims (list[int]) – Feature dimension per modality.
d_model (int) – Target projection dimension.
dropout (float) – Dropout applied to inputs before projection.
special_handling (dict[str, Any], optional) – Dict mapping modality names to handling specs. Currently supports {"type": "weighted_sum", "start_layer": int, "end_layer": int}.
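The weighted-sum handling is only described here by its input and output shapes, so the following is a minimal sketch of one way such an aggregation could work, not the library's implementation. The class name WeightedSum, the softmax normalisation of the layer weights, and the half-open interpretation of start_layer and end_layer are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class WeightedSum(nn.Module):
    """Hypothetical helper: learnable weighted sum over a slice of stacked layers."""

    def __init__(self, start_layer: int, end_layer: int):
        super().__init__()
        # Assumption: the selected slice is half-open, [start_layer, end_layer).
        self.start_layer = start_layer
        self.end_layer = end_layer
        # One scalar per selected layer; zeros give uniform weights after softmax.
        self.weights = nn.Parameter(torch.zeros(end_layer - start_layer))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, T, F), where N indexes the stacked transformer layers.
        selected = x[:, self.start_layer:self.end_layer]  # (B, n, T, F)
        w = torch.softmax(self.weights, dim=0)            # (n,), sums to 1
        # Contract the layer axis: (n,) x (B, n, T, F) -> (B, T, F).
        return torch.einsum("n,bntf->btf", w, selected)


x = torch.randn(2, 13, 5, 8)  # 13 stacked layers of (T=5, F=8) features
agg = WeightedSum(start_layer=1, end_layer=13)
out = agg(x)
print(out.shape)  # torch.Size([2, 5, 8])
```

With the weights at their initial value of zero, the result equals the plain mean over the selected layers; training would then shift weight toward the more useful layers.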
- forward(x_list: list[torch.Tensor], names: list[str] | None = None) → list[torch.Tensor]
Project each modality to d_model.
- Parameters:
x_list – One tensor per modality, each (B, T, F) or (B, N, T, F) for weighted-sum inputs.
names – Optional modality names for special handling lookup.
- Returns:
List of projected tensors, each (B, T, d_model).
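To make the shape contract concrete, here is a small stand-in that mirrors the documented behaviour using only PyTorch: dropout on the inputs, then a kernel-size-1 Conv1d per modality. ProjectionSketch is a hypothetical name, special handling is omitted, and the real module's internals may differ.

```python
import torch
import torch.nn as nn


class ProjectionSketch(nn.Module):
    """Hypothetical stand-in for the per-modality projection described above."""

    def __init__(self, input_feature_dims: list[int], d_model: int, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # One kernel-size-1 convolution per modality: F_i -> d_model.
        self.projections = nn.ModuleList(
            [nn.Conv1d(f, d_model, kernel_size=1) for f in input_feature_dims]
        )

    def forward(self, x_list: list[torch.Tensor]) -> list[torch.Tensor]:
        outputs = []
        for x, conv in zip(x_list, self.projections):
            x = self.dropout(x)                # dropout on inputs, before projection
            x = conv(x.transpose(1, 2))        # (B, F, T) -> (B, d_model, T)
            outputs.append(x.transpose(1, 2))  # back to (B, T, d_model)
        return outputs


# Two modalities with different feature sizes and sequence lengths.
x_list = [torch.randn(4, 20, 35), torch.randn(4, 30, 74)]
proj = ProjectionSketch(input_feature_dims=[35, 74], d_model=40)
outs = proj(x_list)
print([tuple(o.shape) for o in outs])  # [(4, 20, 40), (4, 30, 40)]
```

A convolution with kernel size 1 applies the same linear map at every time step, so each modality keeps its own sequence length T and only the feature axis changes.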