Source code for AIModels.ClimFormerAttn2
import math
from typing import Optional
import torch
import torch.nn as nn
import transformers as tr
class FeatureSelfAttention(nn.Module):
"""Feature‑wise self‑attention **with residual gating**.
The layer treats *features* as tokens and performs scaled dot‑product
attention **independently at every time step**. A *residual* connection
with a learnable scalar gate `gamma` (initialised to *0*) preserves the
original scale of each feature series while letting the model gradually
incorporate cross‑feature information:
.. code-block::
out = x + gamma * A(x) # A(x) = attention‑mixed features
Because `gamma` starts at *0*, the network behaves as an **identity mapping**
at initialisation, eliminating the risk of early‑training scale collapse
observed with a plain weighted average.
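
    Example
    -------
    A minimal sketch of the identity-at-initialisation behaviour (batch,
    time, and feature sizes are illustrative assumptions):

    .. code-block:: python

        layer = FeatureSelfAttention(feature_dim=4)
        x = torch.randn(2, 8, 4)            # (B, T, F)
        out, attn = layer(x)
        assert torch.allclose(out, x)       # gamma == 0 -> identity mapping
        assert attn.shape == (2, 8, 4, 4)   # (B, T, F, F)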
"""
    def __init__(self, feature_dim: int, projection_dim: Optional[int] = None):
        super().__init__()
        projection_dim = projection_dim or feature_dim
        if projection_dim != feature_dim:
            # The attention weights below are applied directly to the raw
            # features, so the score matrix must be (F, F); a differing
            # projection size would make the value mixing ill-defined.
            raise ValueError("projection_dim must equal feature_dim")
        # Linear projections for Q and K (values are the raw features)
        self.W_q = nn.Linear(feature_dim, projection_dim, bias=False)
        self.W_k = nn.Linear(feature_dim, projection_dim, bias=False)
        self.scale = math.sqrt(projection_dim)
        # Per-feature residual gate γ ∈ ℝ^F (broadcast over batch/time)
        self.gamma = nn.Parameter(torch.zeros(1, 1, feature_dim))  # shape (1, 1, F)
    def forward(self, x: torch.Tensor):
        """Return the residual-gated output and the feature-attention weights.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape *(B, T, F)*.

        Returns
        -------
        out : torch.Tensor
            Residual-gated output, shape *(B, T, F)*.
        attn_weights : torch.Tensor
            Feature-to-feature attention weights, shape *(B, T, F, F)*.
"""
        Q = self.W_q(x)  # (B, T, F)
        K = self.W_k(x)  # (B, T, F)
        # Outer product of the per-time-step query and key vectors yields a
        # feature-to-feature score matrix:
        # scores[..., i, j] = Q[..., i] * K[..., j] / scale
        scores = torch.matmul(Q.unsqueeze(-1), K.unsqueeze(-2)) / self.scale  # (B, T, F, F)
        attn_weights = torch.softmax(scores, dim=-1)
        # Weighted sum of the raw feature values -> (B, T, F)
        mixed = torch.matmul(attn_weights, x.unsqueeze(-1)).squeeze(-1)
        # Residual gating preserves the original scale of each series
        out = x + self.gamma * mixed  # per-feature gate gamma, zero at init
        return out, attn_weights
class ClimFormer(tr.InformerForPrediction):
"""InformerForPrediction with feature‑wise residual self‑attention."""
def __init__(self, config):
super().__init__(config)
self.feature_attention = FeatureSelfAttention(config.input_size)
    def _apply_feature_attention(self, values: torch.Tensor) -> torch.Tensor:
        # Mix the series across the feature dimension; the attention
        # weights are discarded here.
        weighted_values, _ = self.feature_attention(values)
        return weighted_values
# ------------------------------------------------------------
# Forward – identical signature to InformerForPrediction
# ------------------------------------------------------------
def forward(
self,
past_values: torch.Tensor,
past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor,
static_categorical_features: Optional[torch.Tensor] = None,
static_real_features: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
future_observed_mask: Optional[torch.Tensor] = None,
**kwargs,
):
        # 1) Apply feature self-attention with residual gating
past_values = self._apply_feature_attention(past_values)
if future_values is not None:
future_values = self._apply_feature_attention(future_values)
# 2) Continue through Informer
return super().forward(
past_values=past_values,
past_time_features=past_time_features,
past_observed_mask=past_observed_mask,
static_categorical_features=static_categorical_features,
static_real_features=static_real_features,
future_values=future_values,
future_time_features=future_time_features,
future_observed_mask=future_observed_mask,
**kwargs,
)
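

if __name__ == "__main__":
    # Hedged smoke test (not part of the original module): the hyper-parameter
    # values below are illustrative assumptions, not tuned project settings.
    config = tr.InformerConfig(
        input_size=7,          # number of feature series
        context_length=48,
        prediction_length=24,
        num_time_features=2,
    )
    model = ClimFormer(config)
    # InformerForPrediction expects context_length + max(lags_sequence)
    # past steps (the default lags reach back 7 steps).
    batch = 4
    length = config.context_length + max(config.lags_sequence)
    outputs = model(
        past_values=torch.randn(batch, length, config.input_size),
        past_time_features=torch.randn(batch, length, config.num_time_features),
        past_observed_mask=torch.ones(batch, length, config.input_size),
        future_values=torch.randn(batch, config.prediction_length, config.input_size),
        future_time_features=torch.randn(batch, config.prediction_length, config.num_time_features),
    )
    print(outputs.loss)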