Source code for linmult.core.pe
"""Sinusoidal positional encoding with optional dropout."""
import math
import torch
import torch.nn as nn
[docs]
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence inputs.

    Adds fixed sinusoidal position encodings to the input tensor, following
    Vaswani et al. (2017). The encoding matrix is computed lazily and cached;
    it is recomputed only when the sequence is longer than the cache, the
    feature dimension changes, or the input lives on a different device.
    The table is built in float32 and cast to the input's floating dtype when
    added, so the output keeps the dtype of ``x``.

    Args:
        dropout (float): Dropout probability applied after adding the encoding.
            Defaults to ``0.1``.
    """

    def __init__(self, dropout: float = 0.1):
        """Initialize PositionalEncoding."""
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Non-persistent buffer: the table is deterministic, so it is kept out
        # of state_dict and (re)built on demand in ``forward``.
        self.register_buffer("pe", None, persistent=False)

    @staticmethod
    def _build_pe(time_dim: int, feature_dim: int, device: torch.device) -> torch.Tensor:
        """Return a float32 sinusoidal table of shape ``(1, time_dim, feature_dim)``.

        For odd feature dimensions, the cosine slot count is ``floor(F/2)``
        while the sine slot count is ``ceil(F/2)``; the division term is
        sliced accordingly so no index is out of range.
        """
        pe = torch.zeros(time_dim, feature_dim, device=device)
        if feature_dim == 0:
            # Nothing to encode; also avoids the division by zero below.
            return pe.unsqueeze(0)
        position = torch.arange(0, time_dim, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, feature_dim, 2, dtype=torch.float, device=device)
            * (-math.log(10000.0) / feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd feature_dim, 1::2 has floor(F/2) slots but div_term has ceil(F/2) — slice.
        pe[:, 1::2] = torch.cos(position * div_term[: feature_dim // 2])
        return pe.unsqueeze(0)  # (1, T, F)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add sinusoidal positional encoding to the input.

        The encoding matrix is rebuilt only when the cached tensor is shorter
        than the current sequence, the feature dimensionality has changed, or
        the cache is on a different device than ``x``. A longer cache is
        reused by slicing.

        Args:
            x (torch.Tensor): Input tensor of shape ``(B, T, F)``.

        Returns:
            torch.Tensor: Encoded tensor of shape ``(B, T, F)`` with dropout
            applied. For floating-point inputs the dtype matches ``x``.

        Raises:
            ValueError: If ``x`` is not a 3-dimensional tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"PositionalEncoding expects a 3D tensor (B, T, F), got shape {tuple(x.shape)}"
            )
        _, time_dim, feature_dim = x.shape

        cached = self.pe
        if (
            cached is None
            or cached.size(1) < time_dim
            or cached.size(2) != feature_dim
            or cached.device != x.device
        ):
            self.pe = self._build_pe(time_dim, feature_dim, x.device)

        pe = self.pe[:, :time_dim, :]
        # Cast to the input's floating dtype so e.g. fp16/bf16 inputs are not
        # silently promoted to float32 by the addition.
        if x.is_floating_point():
            pe = pe.to(dtype=x.dtype)
        return self.dropout(x + pe)