linmult.core.norm
=================

.. py:module:: linmult.core.norm

.. autoapi-nested-parse::

   Batch normalization and instance normalization wrappers.

Classes
-------

.. autoapisummary::

   linmult.core.norm.BN
   linmult.core.norm.IN

Module Contents
---------------

.. py:class:: BN(feature_dim: int, time_aware: bool)

   Bases: :py:obj:`torch.nn.Module`

   Batch normalization for both sequence ``(B, T, F)`` and vector ``(B, F)`` inputs.

   :param feature_dim: Number of features to normalize.
   :type feature_dim: int
   :param time_aware: If ``True``, expects sequence inputs ``(B, T, F)`` and
      normalizes over the batch and time axes via ``BatchNorm1d``. If ``False``,
      expects vector inputs ``(B, F)``.
   :type time_aware: bool

   Initialize BN.

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Apply batch normalization.

      :param x: Shape ``(B, T, F)`` when ``time_aware=True``, or ``(B, F)`` otherwise.
      :type x: torch.Tensor
      :returns: Normalized tensor of the same shape.
      :rtype: torch.Tensor

.. py:class:: IN(feature_dim: int, time_aware: bool)

   Bases: :py:obj:`torch.nn.Module`

   Instance/layer normalization for sequence ``(B, T, F)`` and vector ``(B, F)`` inputs.

   For sequences (``time_aware=True``): applies ``InstanceNorm1d``, normalizing each
   sample and channel independently over the time axis. For vectors
   (``time_aware=False``): applies ``LayerNorm`` over the feature axis per sample.
   ``InstanceNorm1d`` is degenerate for single-element sequences (it would normalize
   each scalar to zero), so ``LayerNorm`` is the correct choice here.

   :param feature_dim: Number of features to normalize.
   :type feature_dim: int
   :param time_aware: Determines the normalization strategy (see above).
   :type time_aware: bool

   Initialize IN.

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Apply instance or layer normalization.

      :param x: Shape ``(B, T, F)`` when ``time_aware=True``, or ``(B, F)`` otherwise.
      :type x: torch.Tensor
      :returns: Normalized tensor of the same shape.
      :rtype: torch.Tensor
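The behavior documented above can be sketched with small stand-in modules. The
classes below (``BNSketch``, ``INSketch``) are illustrative reconstructions, not
the actual ``linmult`` implementation: they show one plausible way the documented
shape handling could map onto ``torch.nn`` primitives.

```python
import torch
import torch.nn as nn


class BNSketch(nn.Module):
    """Illustrative sketch of a BN-style wrapper (hypothetical, not linmult's code)."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        self.norm = nn.BatchNorm1d(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            # BatchNorm1d expects (B, C, L): move features to the channel axis,
            # normalize each feature over batch and time, then restore (B, T, F).
            return self.norm(x.transpose(1, 2)).transpose(1, 2)
        # (B, F): per-feature statistics over the batch axis.
        return self.norm(x)


class INSketch(nn.Module):
    """Illustrative sketch of an IN-style wrapper (hypothetical, not linmult's code)."""

    def __init__(self, feature_dim: int, time_aware: bool):
        super().__init__()
        self.time_aware = time_aware
        # Sequences: per-sample, per-channel statistics over the time axis.
        # Vectors: per-sample statistics over the feature axis (LayerNorm),
        # since InstanceNorm1d would normalize each lone scalar to zero.
        self.norm = nn.InstanceNorm1d(feature_dim) if time_aware else nn.LayerNorm(feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.time_aware:
            return self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.norm(x)


seq = torch.randn(4, 10, 8)  # (B, T, F)
vec = torch.randn(4, 8)      # (B, F)

# All four configurations preserve the input shape, as documented.
assert BNSketch(8, time_aware=True)(seq).shape == seq.shape
assert BNSketch(8, time_aware=False)(vec).shape == vec.shape
assert INSketch(8, time_aware=True)(seq).shape == seq.shape
assert INSketch(8, time_aware=False)(vec).shape == vec.shape
```

Note the axis choice: the time-aware paths normalize statistics that include the
time axis, so they assume the feature axis is last, as in the documented
``(B, T, F)`` layout.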