linmult.core.utils
Utility functions: config loading and logit aggregation.
Functions
load_config(config_file) | Load a YAML configuration file.
apply_logit_aggregation(x[, mask, method]) | Aggregate logits across the time dimension.
Module Contents
- linmult.core.utils.load_config(config_file: str | pathlib.Path) → dict
Load a YAML configuration file.
- Parameters:
config_file (str | Path) – Path to the YAML file.
- Returns:
Parsed configuration dictionary.
- Return type:
dict
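For illustration, a loader with the documented signature could look like the minimal sketch below. It assumes the file is parsed with PyYAML's yaml.safe_load; this page does not say which parser linmult uses. The name load_config_sketch and the config keys in the usage lines are hypothetical, and mapping an empty file to {} is a choice made by this sketch, not documented behaviour.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import yaml  # PyYAML: an assumption, this page does not name the parser


def load_config_sketch(config_file: str | Path) -> dict:
    """Hypothetical stand-in for load_config; not linmult's source."""
    # Path() accepts both str and Path, matching the documented parameter type.
    with Path(config_file).open("r", encoding="utf-8") as f:
        # safe_load only constructs plain Python types (dict, list, str, ...).
        config = yaml.safe_load(f)
    # Choice made by this sketch: an empty file (parsed as None) becomes {}.
    return config if config is not None else {}


# Usage with a throwaway file; "model" and "d_model" are hypothetical keys.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.yaml"
    path.write_text("model:\n  d_model: 64\n", encoding="utf-8")
    print(load_config_sketch(path))  # {'model': {'d_model': 64}}
```

Using safe_load rather than yaml.load avoids constructing arbitrary Python objects from tags in an untrusted config file.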
- linmult.core.utils.apply_logit_aggregation(x: torch.Tensor, mask: torch.Tensor | None = None, method: str = 'meanpooling') → torch.Tensor
Aggregate logits across the time dimension.
Only timesteps where the mask is True (valid) contribute to the result. Fully-masked samples (all False) return a zero vector.
- Parameters:
x (torch.Tensor) – Logit tensor of shape (B, T, F).
mask (torch.Tensor, optional) – Boolean validity mask of shape (B, T). True = valid timestep. If None, all timesteps are treated as valid.
method (str) – Aggregation method. One of:
  "meanpooling": Masked mean over the time dimension.
  "maxpooling": Masked max over the time dimension.
- Returns:
Aggregated output of shape (B, F).
- Return type:
torch.Tensor
- Raises:
ValueError – If method is not one of the supported values.
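The masking rules described above can be sketched in PyTorch as follows. This is a hypothetical re-implementation written only from this page's description, not linmult's source; the name masked_logit_aggregation is illustrative. Mean pooling divides the masked sum by the number of valid steps, max pooling fills invalid steps with -inf before taking the maximum, and a final torch.where zeroes out fully-masked samples.

```python
from __future__ import annotations

import torch


def masked_logit_aggregation(
    x: torch.Tensor,
    mask: torch.Tensor | None = None,
    method: str = "meanpooling",
) -> torch.Tensor:
    """Hypothetical sketch of the documented behaviour, not linmult's code.

    x: (B, T, F) logits; mask: (B, T) booleans, True = valid timestep.
    Returns a (B, F) tensor.
    """
    if method not in ("meanpooling", "maxpooling"):
        raise ValueError(f"Unsupported aggregation method: {method!r}")
    if mask is None:
        # No mask given: treat every timestep as valid.
        mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
    mask = mask.bool()
    valid = mask.unsqueeze(-1)                 # (B, T, 1), broadcasts over F
    has_valid = mask.any(dim=1, keepdim=True)  # (B, 1); False = fully masked

    if method == "meanpooling":
        # Zero out invalid steps, sum over T, divide by the count of valid steps.
        summed = x.masked_fill(~valid, 0.0).sum(dim=1)  # (B, F)
        counts = valid.sum(dim=1).clamp(min=1)          # (B, 1); avoids 0/0
        out = summed / counts
    else:
        # Invalid steps become -inf so they can never be the maximum.
        out = x.masked_fill(~valid, float("-inf")).max(dim=1).values  # (B, F)

    # Fully-masked samples get a zero vector (max pooling would yield -inf).
    return torch.where(has_valid, out, torch.zeros_like(out))


# Usage: B=2, T=3, F=2; the second sample is fully masked.
x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                  [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]])
mask = torch.tensor([[True, True, False],
                     [False, False, False]])
print(masked_logit_aggregation(x, mask, "meanpooling"))  # rows [2., 3.] and [0., 0.]
print(masked_logit_aggregation(x, mask, "maxpooling"))   # rows [3., 4.] and [0., 0.]
```

Clamping the valid-step count to at least 1 keeps fully-masked rows from producing 0/0 = NaN during mean pooling; those rows are then overwritten with zeros, so both methods honour the "zero vector" rule.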