linmult.core.utils

Utility functions: config loading and logit aggregation.

Functions

load_config(config_file) → dict

Load a YAML configuration file.

apply_logit_aggregation(x, mask=None, method='meanpooling') → torch.Tensor

Aggregate logits across the time dimension.

Module Contents

linmult.core.utils.load_config(config_file: str | pathlib.Path) → dict

Load a YAML configuration file.

Parameters:

config_file (str | Path) – Path to the YAML file.

Returns:

Parsed configuration dictionary.

Return type:

dict
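The reference above does not say which YAML parser load_config uses. The following is a minimal sketch that assumes PyYAML's `yaml.safe_load` as the backend. `load_config_sketch` is a hypothetical name, and mapping an empty file to `{}` is a choice made by this sketch, not documented behaviour of linmult.

```python
from __future__ import annotations

from pathlib import Path

import yaml  # PyYAML; assumed backend, not confirmed by the linmult docs


def load_config_sketch(config_file: str | Path) -> dict:
    """Hypothetical stand-in for load_config: parse a YAML file into a dict."""
    path = Path(config_file)
    with path.open("r", encoding="utf-8") as f:
        # safe_load only constructs plain Python objects (dict, list, str, ...)
        data = yaml.safe_load(f)
    # safe_load returns None for an empty file; this sketch maps that to {}.
    return data if data is not None else {}
```

For a file containing the single line `heads: 8`, the sketch returns `{"heads": 8}`.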

linmult.core.utils.apply_logit_aggregation(x: torch.Tensor, mask: torch.Tensor | None = None, method: str = 'meanpooling') → torch.Tensor

Aggregate logits across the time dimension.

Only timesteps where the mask is True (valid) contribute to the result. Fully masked samples (mask all False) produce a zero vector of shape (F,).

Parameters:
  • x (torch.Tensor) – Logit tensor of shape (B, T, F).

  • mask (torch.Tensor, optional) – Boolean validity mask of shape (B, T). True = valid timestep. If None, all timesteps are treated as valid.

  • method (str) –

    Aggregation method (default: "meanpooling"). One of:

    • "meanpooling": Masked mean over the time dimension.

    • "maxpooling": Masked max over the time dimension.

Returns:

Aggregated output of shape (B, F).

Return type:

torch.Tensor

Raises:

ValueError – If method is not one of the supported values.
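The masking rules above can be expressed in a few tensor operations. The following is a minimal sketch of the documented behaviour, not linmult's actual implementation. `masked_aggregate` is a hypothetical name, and the sketch assumes floating-point logits.

```python
from __future__ import annotations

import torch


def masked_aggregate(x: torch.Tensor, mask: torch.Tensor | None = None,
                     method: str = "meanpooling") -> torch.Tensor:
    """Hypothetical sketch of masked time aggregation: (B, T, F) -> (B, F)."""
    if mask is None:
        # No mask given: treat every timestep as valid.
        mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
    m = mask.unsqueeze(-1)      # (B, T, 1), broadcasts over the feature dim
    valid = m.any(dim=1)        # (B, 1): True if the sample has any valid step

    if method == "meanpooling":
        summed = (x * m).sum(dim=1)         # masked-out steps contribute 0
        count = m.sum(dim=1).clamp(min=1)   # avoid division by zero
        out = summed / count
    elif method == "maxpooling":
        # Masked-out steps become -inf so they never win the max.
        out = x.masked_fill(~m, float("-inf")).max(dim=1).values
    else:
        raise ValueError(f"Unsupported method: {method!r}")

    # Fully masked samples get a zero vector instead of 0/1 or -inf.
    return torch.where(valid, out, torch.zeros_like(out))
```

The final `torch.where` enforces the zero-vector rule for both methods. Without it, mean pooling would already yield zeros through the clamped count, but max pooling would leave `-inf` in fully masked rows.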