linmult.core.projection
Projection module for LinMulT and LinT.
ProjectionModule handles per-modality feature projection with optional special handling (e.g. weighted-sum over transformer layers) and an optional per-modality TCN (temporal convolutional network) for temporal smoothing.
Classes
ProjectionModule | Projects each modality's features to a shared dimension.
Module Contents
- class linmult.core.projection.ProjectionModule(input_feature_dims: list[int], d_model: int, dropout: float = 0.0, special_handling: dict[str, Any] | None = None, add_tcn: bool = True, tcn_num_layers: int = 3, tcn_kernel_size: int = 3, tcn_dropout: float = 0.1)
Bases: torch.nn.Module

Projects each modality's features to a shared dimension.
For each modality, applies a 1-D convolution (kernel size 1) that maps input_feature_dims[i] to d_model. Before projection, optional special handling (e.g. weighted-sum aggregation over stacked transformer layers) can reduce a 4-D input (B, N, T, F) to (B, T, F), where B is the batch size, N the number of stacked layers, T the number of time steps, and F the feature dimension. When TCN is enabled, a per-modality TCN is applied after projection to smooth frame-level features temporally.
- Parameters:
input_feature_dims (list[int]) – Feature dimension per modality.
d_model (int) – Target projection dimension.
dropout (float) – Dropout probability applied to inputs before projection. Defaults to 0.0.
special_handling (dict[str, Any], optional) – Dict mapping modality names to handling specs. Currently supports {"type": "weighted_sum", "start_layer": int, "end_layer": int}. Defaults to None.
add_tcn (bool) – Enable per-modality TCN after projection. Defaults to True.
tcn_num_layers (int) – Number of dilated causal convolution layers. Defaults to 3.
tcn_kernel_size (int) – Kernel size for each TCN layer. Defaults to 3.
tcn_dropout (float) – Dropout in each TCN layer. Defaults to 0.1.
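A kernel-size-1 Conv1d applies the same linear map to every time step independently, so the projection step is equivalent to a per-frame linear layer. The sketch below uses arbitrary illustrative sizes and only standard PyTorch calls (it does not import linmult); it checks that equivalence on a single (B, T, F) input. The transposes are needed because nn.Conv1d expects channels on axis 1.

```python
import torch
import torch.nn as nn

# Arbitrary sizes, chosen only for illustration.
B, T, F, d_model = 2, 5, 7, 16

x = torch.randn(B, T, F)  # one modality: (batch, time, features)

# Kernel-size-1 convolution from F input channels to d_model output channels.
proj = nn.Conv1d(in_channels=F, out_channels=d_model, kernel_size=1)

# Conv1d expects (B, C, T): move features to the channel axis, project, move back.
y = proj(x.transpose(1, 2)).transpose(1, 2)  # (B, T, d_model)

# Same computation as a per-frame linear map using the conv's own parameters.
# proj.weight has shape (d_model, F, 1); dropping the kernel axis gives (d_model, F).
y_linear = x @ proj.weight.squeeze(-1).T + proj.bias

print(y.shape)  # torch.Size([2, 5, 16])
print(torch.allclose(y, y_linear, atol=1e-5))  # True
```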
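The TCN is described here only as a stack of tcn_num_layers dilated causal convolutions with kernel size tcn_kernel_size and dropout tcn_dropout. A minimal sketch of such a block follows. CausalTCNSketch is a hypothetical name, and the dilation schedule (2**i for layer i), the ReLU activation, and the residual connection are assumptions, not details confirmed by this page. Causality comes from padding only the left side of the time axis.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalTCNSketch(nn.Module):
    """Hypothetical dilated causal Conv1d stack operating on (B, T, d_model).

    Assumptions (not taken from linmult): dilation 2**i for layer i,
    ReLU activation, dropout after each layer, and a residual connection.
    """

    def __init__(self, d_model: int, num_layers: int = 3, kernel_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.kernel_size = kernel_size
        self.convs = nn.ModuleList(
            [nn.Conv1d(d_model, d_model, kernel_size, dilation=2**i) for i in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x.transpose(1, 2)  # (B, d_model, T) for Conv1d
        for conv in self.convs:
            # Left-only padding: output frame t sees input frames <= t, and length stays T.
            pad = (self.kernel_size - 1) * conv.dilation[0]
            out = conv(F.pad(h, (pad, 0)))
            h = h + self.dropout(F.relu(out))  # residual keeps shape (B, d_model, T)
        return h.transpose(1, 2)  # back to (B, T, d_model)


tcn = CausalTCNSketch(d_model=8).eval()  # eval() disables dropout
x = torch.randn(2, 10, 8)
y = tcn(x)

x_future = x.clone()
x_future[:, 6:] += 1.0  # perturb frames 6..9 only
y_future = tcn(x_future)

print(y.shape)  # torch.Size([2, 10, 8])
print(torch.allclose(y[:, :6], y_future[:, :6]))  # True: frames 0..5 are unaffected
```

Left padding of (kernel_size - 1) * dilation keeps the sequence length at T, so a block like this can follow the projection without changing the (B, T, d_model) shape.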
- forward(x_list: list[torch.Tensor], names: list[str] | None = None) → list[torch.Tensor]
Project each modality to d_model.
- Parameters:
x_list – One tensor per modality, each (B, T, F), or (B, N, T, F) for modalities configured for weighted-sum handling.
names – Optional modality names, used to look up entries in special_handling.
- Returns:
List of projected tensors, each (B, T, d_model).
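The forward contract can be illustrated with a simplified, self-contained stand-in. ProjectionSketch is hypothetical and omits the TCN. It assumes that weighted-sum weights are learnable, zero-initialised and softmax-normalised, that end_layer is exclusive (Python slice semantics), and that dropout is applied after the weighted sum. Only the shapes (B, N, T, F) → (B, T, F) → (B, T, d_model) and the name-based lookup in special_handling come from the documentation above.

```python
import torch
import torch.nn as nn


class ProjectionSketch(nn.Module):
    """Hypothetical, simplified stand-in for ProjectionModule (TCN omitted)."""

    def __init__(self, input_feature_dims, d_model, dropout=0.0, special_handling=None):
        super().__init__()
        self.special = special_handling or {}
        self.dropout = nn.Dropout(dropout)
        # One kernel-size-1 projection per modality: F_i -> d_model.
        self.projs = nn.ModuleList(
            [nn.Conv1d(f, d_model, kernel_size=1) for f in input_feature_dims]
        )
        # Assumption: one learnable logit per selected layer, zero-initialised,
        # so the initial weighted sum is a plain mean over the selected layers.
        self.layer_logits = nn.ParameterDict({
            name: nn.Parameter(torch.zeros(spec["end_layer"] - spec["start_layer"]))
            for name, spec in self.special.items()
            if spec.get("type") == "weighted_sum"
        })

    def forward(self, x_list, names=None):
        if names is None:
            names = [None] * len(x_list)
        outputs = []
        for x, name, proj in zip(x_list, names, self.projs):
            spec = self.special.get(name, {})  # no entry -> plain projection
            if spec.get("type") == "weighted_sum":
                # (B, N, T, F) -> layers [start_layer, end_layer) -> (B, T, F)
                layers = x[:, spec["start_layer"]:spec["end_layer"]]
                weights = torch.softmax(self.layer_logits[name], dim=0)
                x = (layers * weights.view(1, -1, 1, 1)).sum(dim=1)
            x = self.dropout(x)
            # Conv1d expects (B, F, T); transpose in and out to return (B, T, d_model).
            outputs.append(proj(x.transpose(1, 2)).transpose(1, 2))
        return outputs


# Illustrative sizes: 5 stacked layers for "audio", plain 3-D input for "video".
audio = torch.randn(2, 5, 10, 32)  # (B, N, T, F)
video = torch.randn(2, 10, 24)     # (B, T, F)
module = ProjectionSketch(
    input_feature_dims=[32, 24],
    d_model=16,
    special_handling={"audio": {"type": "weighted_sum", "start_layer": 1, "end_layer": 5}},
)
out = module([audio, video], names=["audio", "video"])
print([tuple(o.shape) for o in out])  # [(2, 10, 16), (2, 10, 16)]
```

In this sketch, modalities without a special_handling entry (or calls with names=None) go straight to the kernel-size-1 projection, so names only matters when 4-D inputs are present.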