linmult.core.heads
==================

.. py:module:: linmult.core.heads

.. autoapi-nested-parse::

   Output head types, factory, and HeadModule for LinMulT/LinT models.

Classes
-------

.. autoapisummary::

   linmult.core.heads.BaseHead
   linmult.core.heads.SequenceAggregationHead
   linmult.core.heads.SequenceHead
   linmult.core.heads.VectorHead
   linmult.core.heads.SimpleHead
   linmult.core.heads.UpsampleHead
   linmult.core.heads.DownsampleHead
   linmult.core.heads.HeadFactory
   linmult.core.heads.HeadModule

Module Contents
---------------

.. py:class:: BaseHead(_input_dim: int, _output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`torch.nn.Module`

   Abstract base class for all output heads.

   Subclasses must implement :meth:`forward`. Use :meth:`from_config` as the
   standard factory entry point; it simply delegates to ``__init__``.

   :param _input_dim: Input feature dimensionality (stored for subclass use).
   :type _input_dim: int
   :param _output_dim: Output feature dimensionality.
   :type _output_dim: int
   :param config: Head configuration.
   :type config: HeadConfig

   Initialize BaseHead.

   .. py:method:: extra_repr() -> str

      Return the head name for identification in repr output.

   .. py:method:: from_config(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) -> BaseHead
      :classmethod:

      Construct a head from keyword arguments.

      :param input_dim: Input feature dimensionality.
      :type input_dim: int
      :param output_dim: Output feature dimensionality.
      :type output_dim: int
      :param config: Head configuration.
      :type config: HeadConfig
      :returns: A new instance of this head class.
      :rtype: BaseHead

.. py:class:: SequenceAggregationHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Output head that aggregates a sequence to a single vector.

   Maps ``(B, T, F)`` → ``(B, output_dim)`` by normalizing, projecting to a
   hidden dimension, pooling along the time axis, and projecting to the
   output dimension.
   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Relevant attributes:

      - ``norm`` (str): Normalization type, ``"bn"`` or ``"in"``. Default ``"bn"``.
      - ``pooling`` (str): Pooling type, ``"gap"``, ``"gmp"``, or ``"attentionpool"``. Default ``"gap"``.
      - ``hidden_dim`` (int): Hidden projection size. Default ``256``.
      - ``dropout`` (float): Dropout in the first projection. Default ``0.1``.
   :type config: HeadConfig

   Initialize SequenceAggregationHead.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor

      Aggregate and project.

      :param x: Input of shape ``(B, T, F)``.
      :type x: torch.Tensor
      :param mask: Bool mask of shape ``(B, T)``. True = valid.
      :type mask: torch.Tensor, optional
      :returns: Output of shape ``(B, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: SequenceHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Output head that preserves the time dimension.

   Maps ``(B, T, F)`` → ``(B, T, output_dim)`` by normalizing and projecting
   each timestep independently.

   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Relevant attributes:

      - ``norm`` (str): Normalization type, ``"bn"`` or ``"in"``. Default ``"bn"``.
      - ``hidden_dim`` (int): Hidden projection size. Default ``256``.
      - ``dropout`` (float): Dropout in the projection. Default ``0.1``.
   :type config: HeadConfig

   Initialize SequenceHead.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor

      Normalize and project each timestep.

      :param x: Input of shape ``(B, T, F)``.
      :type x: torch.Tensor
      :param mask: Bool mask of shape ``(B, T)``. True = valid.
      :type mask: torch.Tensor, optional
      :returns: Output of shape ``(B, T, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: VectorHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Output head for vector (already-aggregated) inputs.

   Maps ``(B, F)`` → ``(B, output_dim)`` by normalizing and projecting.

   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Relevant attributes:

      - ``norm`` (str): Normalization type, ``"bn"`` or ``"in"``. Default ``"bn"``.
      - ``hidden_dim`` (int): Hidden projection size. Default ``256``.
      - ``dropout`` (float): Dropout in the projection. Default ``0.1``.
   :type config: HeadConfig

   Initialize VectorHead.

   .. py:method:: forward(x: torch.Tensor, **_kwargs) -> torch.Tensor

      Normalize and project a vector.

      :param x: Input of shape ``(B, F)``.
      :type x: torch.Tensor
      :returns: Output of shape ``(B, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: SimpleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Lightweight linear head with optional time-dimension pooling.

   Applies an optional pooling step followed by a single linear projection.
   Depending on the ``pooling`` config attribute, the mapping is:

   - No pooling (``None``): ``(B, T, F)`` → ``(B, T, output_dim)``
   - With pooling (``"gap"`` / ``"gmp"`` / ``"attentionpool"``): ``(B, T, F)`` → ``(B, output_dim)``

   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Relevant attribute:

      - ``pooling`` (str, optional): One of ``"gap"``, ``"gmp"``, ``"attentionpool"``, or ``None`` (no pooling).
   :type config: HeadConfig

   Initialize SimpleHead.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None, **_kwargs) -> torch.Tensor

      Apply optional pooling then linear projection.

      :param x: Input of shape ``(B, T, F)`` or ``(B, F)``.
      :type x: torch.Tensor
      :param mask: Bool mask of shape ``(B, T)``. True = valid. Passed through
         to pooling layers when ``pooling`` is configured.
      :type mask: torch.Tensor, optional
      :returns: Output of shape ``(B, output_dim)`` if pooled, otherwise
         ``(B, T, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: UpsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Output head with learnable temporal upsampling.

   Maps ``(B, T_in, F)`` → ``(B, output_time_dim, output_dim)`` by projecting
   the feature dimension, applying a stack of transposed convolutions (each
   doubling the time axis), then a final adaptive pool to hit the exact
   target.

   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Required attributes:

      - ``output_time_dim`` (int): Target time dimension.
      - ``input_time_dim`` (int): Source time dimension.
      - ``dropout`` (float): Dropout probability. Default ``0.1``.
   :type config: HeadConfig

   Initialize UpsampleHead.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor

      Upsample and project.

      :param x: Input of shape ``(B, T_in, F)``.
      :type x: torch.Tensor
      :param mask: Bool mask of shape ``(B, T_in)``. True = valid. Masked
         positions are zeroed before processing.
      :type mask: torch.Tensor, optional
      :returns: Output of shape ``(B, output_time_dim, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: DownsampleHead(input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig)

   Bases: :py:obj:`BaseHead`

   Output head with learnable temporal downsampling.

   Maps ``(B, T_in, F)`` → ``(B, output_time_dim, output_dim)`` by projecting
   the feature dimension, applying strided convolutions (each halving the
   time axis), then a final adaptive average pool to hit the exact target.

   :param input_dim: Input feature dimensionality.
   :type input_dim: int
   :param output_dim: Output feature dimensionality.
   :type output_dim: int
   :param config: Head configuration. Required attributes:

      - ``output_time_dim`` (int): Target time dimension.
      - ``input_time_dim`` (int): Source time dimension.
      - ``dropout`` (float): Dropout probability. Default ``0.1``.
   :type config: HeadConfig

   Initialize DownsampleHead.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor

      Downsample and project.

      :param x: Input of shape ``(B, T_in, F)``.
      :type x: torch.Tensor
      :param mask: Bool mask of shape ``(B, T_in)``. True = valid. Masked
         positions are zeroed before processing.
      :type mask: torch.Tensor, optional
      :returns: Output of shape ``(B, output_time_dim, output_dim)``.
      :rtype: torch.Tensor

.. py:class:: HeadFactory

   Registry and factory for output head types.

   New head classes can be registered at runtime with :meth:`register_head`,
   then instantiated by name with :meth:`create_head`.

   Built-in types: ``"sequence_aggregation"``, ``"sequence"``, ``"vector"``,
   ``"simple"``, ``"upsample"``, ``"downsample"``.

   .. py:method:: register_head(name: str, head_cls: type[BaseHead]) -> None
      :classmethod:

      Register a custom head class under a given name.

      :param name: Registry key used in ``config["type"]``.
      :type name: str
      :param head_cls: Head class to register.
      :type head_cls: type[BaseHead]

   .. py:method:: create_head(type: str, input_dim: int, output_dim: int, config: linmult.core.config.HeadConfig) -> BaseHead
      :classmethod:

      Instantiate a registered head by type name.

      :param type: Registered head type name.
      :type type: str
      :param input_dim: Input feature dimensionality.
      :type input_dim: int
      :param output_dim: Output feature dimensionality.
      :type output_dim: int
      :param config: Head configuration.
      :type config: HeadConfig
      :returns: The constructed head module.
      :rtype: BaseHead
      :raises ValueError: If ``type`` is not registered.

.. py:class:: HeadModule(input_dim: int, head_configs: list[linmult.core.config.HeadConfig])

   Bases: :py:obj:`torch.nn.Module`

   Self-contained output head container.

   Builds all output heads from a list of :class:`HeadConfig` using
   :class:`HeadFactory`, and applies them in the forward pass.

   :param input_dim: Input feature dimension fed to each head.
   :param head_configs: List of head configurations.

   Initialize HeadModule.

   .. py:method:: forward(x: torch.Tensor, mask: torch.Tensor | None = None) -> dict[str, torch.Tensor]

      Apply all heads to the input.

      :param x: Input tensor ``(B, [T,] input_dim)``.
      :param mask: Optional boolean mask ``(B, T)``.
      :returns: Dict mapping head name to output tensor.
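The mask convention shared by the sequence heads (boolean, ``True`` = valid timestep) can be illustrated with a dependency-free sketch of masked global average pooling (``"gap"``). This is plain Python for clarity; the actual heads operate on ``torch`` tensors, and this function is not linmult's implementation:

```python
def masked_gap(x: list[list[float]], mask: list[bool] | None = None) -> list[float]:
    """Average a (T, F) sequence over time, skipping masked-out steps.

    Illustrates the True = valid convention: only timesteps where the
    mask is True contribute to the pooled vector.
    """
    T = len(x)
    if mask is None:
        mask = [True] * T  # no mask: every timestep counts
    valid = [t for t in range(T) if mask[t]]
    F = len(x[0])
    return [sum(x[t][f] for t in valid) / len(valid) for f in range(F)]


# (T=3, F=2); the last timestep is padding and should not affect the mean.
seq = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
print(masked_gap(seq, mask=[True, True, False]))  # [2.0, 3.0]
```

Without the mask, the padding step would pull the mean far off, which is why the heads accept ``mask`` alongside ``x``.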
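The resampling scheme described for :class:`UpsampleHead` and :class:`DownsampleHead` (repeated doubling or halving of the time axis, then an adaptive pool that lands exactly on ``output_time_dim``) implies roughly ``ceil(log2(ratio))`` convolution stages. The sketch below works out that arithmetic; the stage-count rule is an assumption for illustration, and the actual layer construction lives in linmult:

```python
import math


def resample_plan(input_time_dim: int, output_time_dim: int) -> list[int]:
    """Return the time length after each conv stage, then the final pool.

    Assumes each transposed conv doubles (UpsampleHead) and each strided
    conv halves (DownsampleHead) the time axis, with an adaptive pool
    hitting the exact target at the end -- an illustrative model of the
    documented behavior, not the library's code.
    """
    lengths = [input_time_dim]
    if output_time_dim > input_time_dim:  # upsampling: doubling stages
        stages = math.ceil(math.log2(output_time_dim / input_time_dim))
        for _ in range(stages):
            lengths.append(lengths[-1] * 2)
    elif output_time_dim < input_time_dim:  # downsampling: halving stages
        stages = math.ceil(math.log2(input_time_dim / output_time_dim))
        for _ in range(stages):
            lengths.append(lengths[-1] // 2)
    lengths.append(output_time_dim)  # adaptive pool hits the exact target
    return lengths


print(resample_plan(25, 100))  # [25, 50, 100, 100]
print(resample_plan(100, 30))  # [100, 50, 25, 30]
```

The second example shows why the final adaptive pool is needed: halving alone can only reach powers-of-two fractions of the input length.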
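The register/create pattern behind :class:`HeadFactory` can be shown as a standalone sketch. The names mirror the documented API (``register_head``, ``create_head``, ``from_config``), but the class bodies are minimal stand-ins, not LinMulT's actual implementation, and ``MyCustomHead`` is a hypothetical user-defined head:

```python
class BaseHead:
    """Minimal stand-in for linmult's BaseHead (really a torch.nn.Module)."""

    def __init__(self, input_dim: int, output_dim: int, config: dict):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.config = config

    @classmethod
    def from_config(cls, input_dim: int, output_dim: int, config: dict) -> "BaseHead":
        # The documented factory entry point simply delegates to __init__.
        return cls(input_dim, output_dim, config)


class HeadFactory:
    _registry: dict[str, type[BaseHead]] = {}

    @classmethod
    def register_head(cls, name: str, head_cls: type[BaseHead]) -> None:
        cls._registry[name] = head_cls

    @classmethod
    def create_head(cls, type: str, input_dim: int, output_dim: int, config: dict) -> BaseHead:
        # Unknown names raise ValueError, as documented.
        if type not in cls._registry:
            raise ValueError(f"Unknown head type: {type!r}")
        return cls._registry[type].from_config(input_dim, output_dim, config)


class MyCustomHead(BaseHead):  # hypothetical user-defined head
    pass


HeadFactory.register_head("my_custom", MyCustomHead)
head = HeadFactory.create_head("my_custom", input_dim=64, output_dim=8, config={})
print(type(head).__name__)  # MyCustomHead
```

Registering under a name is what lets a config file select the head via its ``type`` field without the model code importing the custom class directly.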
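Finally, :class:`HeadModule`'s documented forward contract (run every configured head on the same input and key each output by head name) can be sketched with plain callables standing in for the real torch heads; this is an illustration of the dict-returning behavior only:

```python
def make_head_module(heads: dict[str, object]):
    """Build a forward function applying every named head to one input.

    Stand-in for HeadModule.forward: real heads are nn.Modules built by
    HeadFactory from HeadConfig entries; here they are plain callables.
    """
    def forward(x, mask=None):
        return {name: head(x) for name, head in heads.items()}
    return forward


module = make_head_module({
    "cls": lambda x: sum(x) / len(x),        # stand-in for an aggregating head
    "seq": lambda x: [v * 2.0 for v in x],   # stand-in for a per-timestep head
})
print(module([1.0, 2.0, 3.0]))  # {'cls': 2.0, 'seq': [2.0, 4.0, 6.0]}
```

Keying outputs by head name lets one backbone drive several tasks (e.g. a pooled classification output next to a per-timestep regression output) from a single forward pass.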