Source code for linmult.core.ffn

"""FFN residual block: two linear layers with GELU activation and residual connection."""

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs] class FFNResidual(nn.Module): """Two-layer FFN with GELU activation, dropout, and residual connection. Computes ``x + fc2(dropout(gelu(fc1(x))))``. Args: dim (int): Input and output feature dimension. dropout (float): Dropout probability applied after the first linear layer. Defaults to ``0.0``. """ def __init__(self, dim: int, dropout: float = 0.0): super().__init__() self.fc1 = nn.Linear(dim, dim) self.fc2 = nn.Linear(dim, dim) self.dropout = dropout
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply FFN + residual. Args: x (torch.Tensor): Input tensor of any shape with last dim ``dim``. Returns: torch.Tensor: Same shape as ``x``. """ return self.fc2(F.dropout(F.gelu(self.fc1(x)), p=self.dropout, training=self.training)) + x