linmult.core.transformer
========================

.. py:module:: linmult.core.transformer

.. autoapi-nested-parse::

   Transformer encoder: stacked pre-norm layers with multi-head attention and FFN.


Classes
-------

.. autoapisummary::

   linmult.core.transformer.TransformerEncoder
   linmult.core.transformer.TransformerEncoderLayer


Module Contents
---------------

.. py:class:: TransformerEncoder(d_model: int = 40, num_heads: int = 8, num_layers: int = 6, attention_config: linmult.core.attention.AttentionConfig | None = None, dropout_pe: float = 0.0, dropout_ffn: float = 0.1, is_cross_modal: bool = False, name: str = '')

   Bases: :py:obj:`torch.nn.Module`

   Transformer encoder with multiple stacked layers.

   Supports both self-attention (when ``x_k`` and ``x_v`` are omitted) and
   cross-modal attention (when ``x_k`` and ``x_v`` are provided).

   :param d_model: Input and output feature dimensionality. Defaults to ``40``.
   :type d_model: int
   :param num_heads: Number of attention heads. Defaults to ``8``.
   :type num_heads: int
   :param num_layers: Number of stacked encoder layers. Defaults to ``6``.
   :type num_layers: int
   :param attention_config: Attention type and parameters. Defaults to
      ``AttentionConfig()`` (linear attention).
   :type attention_config: AttentionConfig, optional
   :param dropout_pe: Dropout after positional encoding. Defaults to ``0.0``.
   :type dropout_pe: float
   :param dropout_ffn: Dropout in the FFN sub-layer. Defaults to ``0.1``.
   :type dropout_ffn: float
   :param is_cross_modal: Allocate a separate layer norm for the cross-modal
      key input. Set to ``True`` for cross-modal attention encoders. Defaults
      to ``False``.
   :type is_cross_modal: bool
   :param name: Module name shown in ``repr``. Defaults to ``""``.
   :type name: str

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:method:: extra_repr() -> str

      Return the module name for identification in repr output.

   .. py:method:: forward(x_q: torch.Tensor, x_k: torch.Tensor | None = None, x_v: torch.Tensor | None = None, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None) -> torch.Tensor

      Run the transformer encoder.

      When ``x_k`` and ``x_v`` are omitted, the encoder runs self-attention
      (``x_q == x_k == x_v``). When they are provided, it runs cross-modal
      attention.

      :param x_q: Query input of shape ``(B, T_1, F)``.
      :type x_q: torch.Tensor
      :param x_k: Key input of shape ``(B, T_2, F)``.
      :type x_k: torch.Tensor, optional
      :param x_v: Value input of shape ``(B, T_2, F)``.
      :type x_v: torch.Tensor, optional
      :param query_mask: Mask for queries, shape ``(B, T_1)``.
      :type query_mask: torch.BoolTensor, optional
      :param key_mask: Mask for keys, shape ``(B, T_2)``.
      :type key_mask: torch.BoolTensor, optional
      :returns: Encoded output of shape ``(B, T_1, F)``.
      :rtype: torch.Tensor


.. py:class:: TransformerEncoderLayer(d_model: int = 40, num_heads: int = 8, attention_config: linmult.core.attention.AttentionConfig | None = None, dropout: float = 0.1, is_cross_modal: bool = False)

   Bases: :py:obj:`torch.nn.Module`

   Single pre-norm transformer encoder layer with attention + FFN.

   Supports self-attention and cross-modal attention. The cross-modal layer
   norm (``layer_norm_cross``) is only allocated when ``is_cross_modal=True``,
   since pure self-attention layers never receive external keys.

   :param d_model: Feature dimensionality. Defaults to ``40``.
   :type d_model: int
   :param num_heads: Number of attention heads. Defaults to ``8``.
   :type num_heads: int
   :param attention_config: Attention type and parameters. Defaults to
      ``AttentionConfig()`` (linear attention).
   :type attention_config: AttentionConfig, optional
   :param dropout: Dropout on the FFN and residual connections. Defaults to ``0.1``.
   :type dropout: float
   :param is_cross_modal: Allocate a cross-modal layer norm. Defaults to ``False``.
   :type is_cross_modal: bool

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:method:: forward(x_q: torch.Tensor, x_k: torch.Tensor | None = None, x_v: torch.Tensor | None = None, query_mask: torch.Tensor | None = None, key_mask: torch.Tensor | None = None) -> torch.Tensor

      Run one transformer encoder layer.

      :param x_q: Query input of shape ``(B, T_1, F)``.
      :type x_q: torch.Tensor
      :param x_k: Key input of shape ``(B, T_2, F)``.
      :type x_k: torch.Tensor, optional
      :param x_v: Value input of shape ``(B, T_2, F)``.
      :type x_v: torch.Tensor, optional
      :param query_mask: Mask for queries, shape ``(B, T_1)``.
      :type query_mask: torch.BoolTensor, optional
      :param key_mask: Mask for keys, shape ``(B, T_2)``.
      :type key_mask: torch.BoolTensor, optional
      :returns: Layer output of shape ``(B, T_1, F)``.
      :rtype: torch.Tensor
      :raises ValueError: If mask shapes or dtypes are incorrect.
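To make the pre-norm dataflow and the self- vs. cross-modal calling convention concrete, below is a minimal NumPy sketch of a single encoder layer. It is an illustration, not linmult's implementation: single-head softmax attention stands in for the configurable (linear) attention, the FFN is a toy ``tanh``, learned weights and dropout are omitted, and the names (``pre_norm_layer``, ``layer_norm``) are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Per-position normalization over the feature axis (no learned scale/shift).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def pre_norm_layer(x_q, x_k=None, x_v=None, key_mask=None):
    """One pre-norm layer: x + Attn(LN(x)), then x + FFN(LN(x)).

    Single-head softmax attention stands in for the configurable attention;
    normalizing x_k separately mirrors the dedicated cross-modal key norm.
    """
    if x_k is None:          # self-attention: x_q == x_k == x_v
        x_k = x_v = x_q
    B, T1, F = x_q.shape
    q = layer_norm(x_q)
    k = layer_norm(x_k)      # separate norm for (possibly external) keys
    v = layer_norm(x_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(F)   # (B, T1, T2)
    if key_mask is not None:                          # True = valid position
        scores = np.where(key_mask[:, None, :], scores, -1e9)
    x = x_q + softmax(scores) @ v                     # attention residual
    return x + np.tanh(layer_norm(x))                 # toy FFN residual

rng = np.random.default_rng(0)
audio = rng.normal(size=(2, 50, 40))   # queries: (B, T_1, F)
video = rng.normal(size=(2, 30, 40))   # keys/values: (B, T_2, F)
mask = np.ones((2, 30), dtype=bool)
mask[:, 20:] = False                   # last 10 video frames are padding

out_self = pre_norm_layer(audio)                        # self-attention
out_cross = pre_norm_layer(audio, video, video, mask)   # cross-modal
print(out_self.shape, out_cross.shape)  # (2, 50, 40) (2, 50, 40)
```

Note that the output length always follows the queries: cross-modal attention re-expresses the 50 audio steps in terms of the 30 video steps, and masked (padding) keys receive zero attention weight, so they cannot influence the result.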