linmult.core.norm¶
Batch normalization and instance normalization wrappers.
Classes¶
Module Contents¶
- class linmult.core.norm.BN(feature_dim: int, time_aware: bool)[source]¶
Bases: torch.nn.Module

Batch normalization for both sequence (B, T, F) and vector (B, F) inputs.

- Parameters:
feature_dim (int) – Number of features to normalize.
time_aware (bool) – If True, expects sequence inputs (B, T, F) and normalizes over the batch and time axes via BatchNorm1d. If False, expects vector inputs (B, F).
Initialize BN.
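The behavior described above can be sketched as follows. This is a minimal illustration, not the library's actual implementation: it assumes the wrapper transposes sequence inputs to the (B, F, T) layout that BatchNorm1d expects, then transposes back.

```python
import torch
import torch.nn as nn

class BN(nn.Module):
    """Sketch: batch norm over (B, T, F) sequences or (B, F) vectors."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        # BatchNorm1d normalizes each feature channel over the remaining axes.
        self.norm = nn.BatchNorm1d(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            # (B, T, F) -> (B, F, T): BatchNorm1d then normalizes each
            # feature over the batch and time axes; transpose back afterwards.
            return self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.norm(x)  # (B, F): normalize each feature over the batch
```

In training mode, each feature channel of the sequence output has (approximately) zero mean and unit variance computed over the batch and time axes jointly.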
- class linmult.core.norm.IN(feature_dim: int, time_aware: bool)[source]¶
Bases: torch.nn.Module

Instance/layer normalization for sequence (B, T, F) and vector (B, F) inputs.

For sequences (time_aware=True): applies InstanceNorm1d, normalizing each sample and channel independently over the time axis.

For vectors (time_aware=False): applies LayerNorm over the feature axis per sample. InstanceNorm1d is degenerate for single-element sequences (it normalizes a scalar to zero), so LayerNorm is the correct choice here.

- Parameters:
feature_dim (int) – Number of features to normalize.
time_aware (bool) – Determines the normalization strategy (see above).
Initialize IN.
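A sketch of the strategy described above, under the same layout assumption as before (InstanceNorm1d consumes (B, F, T), so sequence inputs are transposed around the call); this illustrates the documented behavior rather than reproducing the library's code.

```python
import torch
import torch.nn as nn

class IN(nn.Module):
    """Sketch: instance norm over time for sequences, layer norm for vectors."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        if time_aware:
            # Normalizes each (sample, channel) pair over the time axis.
            self.norm = nn.InstanceNorm1d(feature_dim)
        else:
            # Normalizes each sample over the feature axis; well-defined
            # even where a length-1 "sequence" would degenerate to zero.
            self.norm = nn.LayerNorm(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            # (B, T, F) -> (B, F, T) for InstanceNorm1d, then back.
            return self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.norm(x)  # (B, F)
```

Note the contrast with BN above: here the statistics never mix samples, so each sequence is normalized independently of the rest of the batch.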