Source code for linmult.core.utils
"""Utility functions: config loading and logit aggregation."""
from pathlib import Path
import torch
import yaml
[docs]
def load_config(config_file: str | Path) -> dict:
    """Load a YAML configuration file.

    Args:
        config_file (str | Path): Path to the YAML file.

    Returns:
        dict: Parsed configuration dictionary. An empty file yields ``{}``.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(config_file, encoding="utf-8") as file:
        # safe_load only builds plain Python types (no arbitrary object construction).
        config = yaml.safe_load(file)
    # safe_load returns None for an empty document; honor the declared dict return type.
    return {} if config is None else config
[docs]
def apply_logit_aggregation(
x: torch.Tensor, mask: torch.Tensor | None = None, method: str = "meanpooling"
) -> torch.Tensor:
"""Aggregate logits across the time dimension.
Only timesteps where the mask is ``True`` (valid) contribute to the result.
Fully-masked samples (all ``False``) return a zero vector.
Args:
x (torch.Tensor): Logit tensor of shape ``(B, T, F)``.
mask (torch.Tensor, optional): Boolean validity mask of shape ``(B, T)``.
``True`` = valid timestep. If ``None``, all timesteps are treated as valid.
method (str): Aggregation method. One of:
- ``"meanpooling"``: Masked mean over the time dimension.
- ``"maxpooling"``: Masked max over the time dimension.
Returns:
torch.Tensor: Aggregated output of shape ``(B, F)``.
Raises:
ValueError: If ``method`` is not one of the supported values.
"""
m: torch.Tensor = (
mask
if mask is not None
else torch.ones(size=x.shape[:2], dtype=torch.bool, device=x.device) # (B, T)
)
if method == "maxpooling":
x_masked = x.masked_fill(~m.unsqueeze(-1), float("-inf")) # (B, T, F)
result = torch.max(x_masked, dim=1)[0] # (B, F)
fully_masked = ~m.any(dim=1) # (B,)
if fully_masked.any():
result = result.masked_fill(fully_masked.unsqueeze(-1), 0.0)
return result
elif method == "meanpooling":
x_masked = x.masked_fill(~m.unsqueeze(-1), 0.0) # (B, T, F)
valid_counts = m.sum(dim=1, keepdim=True).clamp(min=1) # (B, 1)
return x_masked.sum(dim=1) / valid_counts # (B, F)
else:
raise ValueError(f"Method {method} for logit aggregation is not supported.")