linmult.core.norm

Batch normalization and instance normalization wrappers.

Classes

BN

Batch normalization for both sequence (B, T, F) and vector (B, F) inputs.

IN

Instance/layer normalization for sequence (B, T, F) and vector (B, F) inputs.

Module Contents

class linmult.core.norm.BN(feature_dim: int, time_aware: bool)[source]

Bases: torch.nn.Module

Batch normalization for both sequence (B, T, F) and vector (B, F) inputs.

Parameters:
  • feature_dim (int) – Number of features to normalize.

  • time_aware (bool) – If True, expects sequence inputs (B, T, F) and normalizes over the batch and time axes via BatchNorm1d. If False, expects vector inputs (B, F).

Initialize BN.

forward(x: torch.Tensor) → torch.Tensor[source]

Apply batch normalization.

Parameters:

x (torch.Tensor) – Shape (B, T, F) when time_aware=True, or (B, F) otherwise.

Returns:

Normalized tensor of the same shape.

Return type:

torch.Tensor
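The source is not reproduced here, but the documented behaviour can be sketched as follows. The key detail is that torch.nn.BatchNorm1d expects the feature axis in the channel position (B, F, T), so sequence inputs of shape (B, T, F) must be transposed in and out; everything beyond the documented signature is an assumption.

```python
import torch
from torch import nn


class BN(nn.Module):
    """Minimal sketch of the documented BN wrapper (not the actual source)."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        # BatchNorm1d normalizes each feature over the batch (and, for
        # sequences, the time) axis.
        self.bn = nn.BatchNorm1d(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            # (B, T, F) -> (B, F, T) for BatchNorm1d, then back.
            return self.bn(x.transpose(1, 2)).transpose(1, 2)
        return self.bn(x)
```

In training mode this normalizes each feature to zero mean and unit variance over the batch (and time) axes, so the output shape always matches the input shape.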

class linmult.core.norm.IN(feature_dim: int, time_aware: bool)[source]

Bases: torch.nn.Module

Instance/layer normalization for sequence (B, T, F) and vector (B, F) inputs.

For sequences (time_aware=True): applies InstanceNorm1d, normalizing each sample and channel independently over the time axis.

For vectors (time_aware=False): applies LayerNorm over the feature axis per sample. InstanceNorm1d is degenerate for single-element sequences (normalizes a scalar to zero), so LayerNorm is the correct choice here.

Parameters:
  • feature_dim (int) – Number of features to normalize.

  • time_aware (bool) – Determines the normalization strategy (see above).

Initialize IN.

forward(x: torch.Tensor) → torch.Tensor[source]

Apply instance or layer normalization.

Parameters:

x (torch.Tensor) – Shape (B, T, F) when time_aware=True, or (B, F) otherwise.

Returns:

Normalized tensor of the same shape.

Return type:

torch.Tensor
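Again the source is not shown, but the documented strategy selection can be sketched. As with BN, torch.nn.InstanceNorm1d expects (B, F, T), so sequence inputs need a transpose; torch.nn.LayerNorm operates on the trailing feature axis and handles (B, F) directly. Details beyond the documented signature are assumptions.

```python
import torch
from torch import nn


class IN(nn.Module):
    """Minimal sketch of the documented IN wrapper (not the actual source)."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        if time_aware:
            # Normalizes each (sample, channel) pair over the time axis.
            self.norm = nn.InstanceNorm1d(feature_dim)
        else:
            # Normalizes over the feature axis per sample; avoids the
            # length-1 degenerate case of InstanceNorm1d.
            self.norm = nn.LayerNorm(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            # (B, T, F) -> (B, F, T) for InstanceNorm1d, then back.
            return self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.norm(x)
```

For sequence inputs, each feature of each sample ends up with zero mean over the time axis; for vector inputs, each sample has zero mean over its features.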