# Source code for linmult.core.tcn

"""Temporal Convolutional Network (TCN) for local temporal smoothing.

Provides dilated causal 1-D convolution layers that capture short-range
temporal dynamics (e.g. micro-expressions, motion patterns) without leaking
future information.  Designed to sit after the projection Conv1d and before
cross-modal attention in the LinMulT pipeline.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class TCNLayer(nn.Module):
    """One dilated causal Conv1d block with a residual shortcut.

    The transformation is::

        x + dropout(relu(bn(causal_conv1d(x))))

    All padding is placed on the left, so the value at time step *t* is
    computed only from inputs at time steps ``<= t`` (no future leakage
    through the convolution itself).

    Args:
        d_model (int): Number of input and output channels.
        kernel_size (int): Convolution kernel size. Defaults to ``3``.
        dilation (int): Dilation factor. Defaults to ``1``.
        dropout (float): Dropout probability after activation. Defaults to
            ``0.1``.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int = 3,
        dilation: int = 1,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Left-only padding length that makes the convolution causal.
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size,
            dilation=dilation,
            padding=0,
            bias=False,
        )
        # NOTE(review): in training mode BatchNorm1d computes statistics over
        # batch *and* time, so normalization stats include future frames —
        # confirm this is acceptable for strictly-causal training.
        self.norm = nn.BatchNorm1d(d_model)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the causal convolution and add the residual.

        Args:
            x (torch.Tensor): Input of shape ``(B, T, d_model)``.

        Returns:
            torch.Tensor: Output of shape ``(B, T, d_model)``.
        """
        shortcut = x
        # Conv1d expects channels-first: (B, T, C) -> (B, C, T).
        h = x.transpose(1, 2)
        # Pad only on the left so output[t] never sees input[t+1:].
        h = F.pad(h, (self.left_pad, 0))
        h = F.relu(self.norm(self.conv(h)))
        h = F.dropout(h, p=self.dropout, training=self.training)
        # Back to channels-last before the residual add.
        return h.transpose(1, 2) + shortcut
class TCN(nn.Module):
    """Stack of :class:`TCNLayer` modules with exponentially growing dilation.

    Layer *i* uses dilation ``2**i``, so the dilations are
    ``[1, 2, 4, ..., 2^(num_layers-1)]`` and the receptive field is
    ``1 + sum((kernel_size - 1) * 2^i for i in range(num_layers))`` frames.
    With the defaults (``num_layers=3, kernel_size=3``) the receptive field
    is 15 frames (~0.5 s at 30 fps).

    Args:
        d_model (int): Channel dimension (preserved through all layers).
        num_layers (int): Number of dilated convolution layers. Defaults to
            ``3``.
        kernel_size (int): Kernel size shared by every layer. Defaults to
            ``3``.
        dropout (float): Dropout probability in each layer. Defaults to
            ``0.1``.
    """

    def __init__(
        self,
        d_model: int,
        num_layers: int = 3,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Dilation doubles at each level: 1, 2, 4, ...
        stack = [
            TCNLayer(d_model, kernel_size, dilation=2**level, dropout=dropout)
            for level in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through every layer in order.

        Args:
            x (torch.Tensor): Input of shape ``(B, T, d_model)``.

        Returns:
            torch.Tensor: Temporally smoothed output, same shape as ``x``.
        """
        out = x
        for block in self.layers:
            out = block(out)
        return out