Source code for linmult.core.norm

"""Batch normalization and instance normalization wrappers."""

import torch
import torch.nn as nn


class BN(nn.Module):
    """Batch normalization supporting sequence ``(B, T, F)`` and vector ``(B, F)`` inputs.

    Args:
        feature_dim (int): Number of features to normalize.
        time_aware (bool): When ``True``, inputs are sequences ``(B, T, F)``
            normalized over the batch and time axes with ``BatchNorm1d``.
            When ``False``, inputs are vectors ``(B, F)``.
    """

    def __init__(self, feature_dim: int, time_aware: bool):
        """Initialize BN."""
        super().__init__()
        self.time_aware = time_aware
        self.bn = nn.BatchNorm1d(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply batch normalization.

        Args:
            x (torch.Tensor): ``(B, T, F)`` when ``time_aware=True``,
                ``(B, F)`` otherwise.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        # Vector input (B, F): BatchNorm1d consumes it directly.
        if not (self.time_aware and x.ndim == 3):
            return self.bn(x)
        # Sequence input: move features to the channel axis expected by
        # BatchNorm1d, normalize, then restore the (B, T, F) layout.
        channels_first = x.transpose(1, 2)  # (B, F, T)
        normalized = self.bn(channels_first)
        return normalized.transpose(1, 2)  # (B, T, F)
class IN(nn.Module):
    """Instance/layer normalization for sequence ``(B, T, F)`` and vector ``(B, F)`` inputs.

    Sequences (``time_aware=True``) go through ``InstanceNorm1d``, which
    normalizes each sample and channel independently along the time axis.
    Vectors (``time_aware=False``) go through ``LayerNorm`` over the feature
    axis per sample — ``InstanceNorm1d`` would be degenerate there, since a
    length-1 sequence normalizes every scalar to zero.

    Args:
        feature_dim (int): Number of features to normalize.
        time_aware (bool): Selects the normalization strategy (see above).
    """

    def __init__(self, feature_dim: int, time_aware: bool):
        """Initialize IN."""
        super().__init__()
        self.time_aware = time_aware
        if time_aware:
            self.in_norm = nn.InstanceNorm1d(feature_dim, affine=True)
        else:
            # Vector inputs (B, F): LayerNorm gives per-sample feature
            # normalization, the meaningful counterpart for this case.
            self.in_norm = nn.LayerNorm(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply instance or layer normalization.

        Args:
            x (torch.Tensor): ``(B, T, F)`` when ``time_aware=True``,
                ``(B, F)`` otherwise.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        # Vector input: LayerNorm acts on the trailing feature axis as-is.
        if not (self.time_aware and x.ndim == 3):
            return self.in_norm(x)
        # Sequence input: InstanceNorm1d wants (B, F, T); swap in, swap out.
        channels_first = x.transpose(1, 2)  # (B, F, T)
        normalized = self.in_norm(channels_first)
        return normalized.transpose(1, 2)  # (B, T, F)