"""Transformer class.""" import logging import math from collections import OrderedDict from pathlib import Path from typing import Literal, Tuple import torch import torch.nn.functional as F import yaml from torch import nn import sys, os from .utils import blockwise_causal_norm logger = logging.getLogger(__name__) def _pos_embed_fourier1d_init( cutoff: float = 256, n: int = 32, cutoff_start: float = 1 ): return ( torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n)) .unsqueeze(0) .unsqueeze(0) ) def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32): # Maximum initial frequency is 1 return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0) def _rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate pairs of scalars as 2d vectors by pi/2.""" x = x.unflatten(-1, (-1, 2)) x1, x2 = x.unbind(dim=-1) return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) class RotaryPositionalEncoding(nn.Module): def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)): super().__init__() assert len(cutoffs) == len(n_pos) if not all(n % 2 == 0 for n in n_pos): raise ValueError("n_pos must be even") self._n_dim = len(cutoffs) self.freqs = nn.ParameterList([ nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2)) for cutoff, n in zip(cutoffs, n_pos) ]) def get_co_si(self, coords: torch.Tensor): _B, _N, D = coords.shape assert D == len(self.freqs) co = torch.cat( tuple( torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) ), axis=-1, ) si = torch.cat( tuple( torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) ), axis=-1, ) return co, si def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor): _B, _N, D = coords.shape _B, _H, _N, _C = q.shape if D != self._n_dim: raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}") co, si = self.get_co_si(coords) co = co.unsqueeze(1).repeat_interleave(2, dim=-1) si = si.unsqueeze(1).repeat_interleave(2, dim=-1) q2 = q * co + _rotate_half(q) * si k2 = k * co + _rotate_half(k) * si return q2, k2 class FeedForward(nn.Module): def __init__(self, d_model, expand: float = 2, bias: bool = True): super().__init__() self.fc1 = nn.Linear(d_model, int(d_model * expand)) self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias) self.act = nn.GELU() def forward(self, x): return self.fc2(self.act(self.fc1(x))) class PositionalEncoding(nn.Module): def __init__( self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,), cutoffs_start=None, ): super().__init__() if cutoffs_start is None: cutoffs_start = (1,) * len(cutoffs) assert len(cutoffs) == len(n_pos) self.freqs = nn.ParameterList([ nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2)) for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start) ]) def forward(self, coords: torch.Tensor): _B, _N, D = coords.shape assert D == len(self.freqs) embed = torch.cat( tuple( torch.cat( ( torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq), torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq), ), axis=-1, ) / math.sqrt(len(freq)) for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) ), axis=-1, ) return embed def _bin_init_exp(cutoff: float, n: int): return torch.exp(torch.linspace(0, math.log(cutoff + 1), n)) def _bin_init_linear(cutoff: float, n: int): return torch.linspace(-cutoff, cutoff, n) class RelativePositionalBias(nn.Module): def __init__( self, n_head: int, cutoff_spatial: float, cutoff_temporal: float, n_spatial: int = 32, n_temporal: int = 16, ): super().__init__() self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial) self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1) self.register_buffer("spatial_bins", self._spatial_bins) self.register_buffer("temporal_bins", self._temporal_bins) self.n_spatial = n_spatial self.n_head = n_head self.bias = nn.Parameter( -0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head) ) def forward(self, coords: torch.Tensor): _B, _N, _D = coords.shape t = coords[..., 0] yx = coords[..., 1:] temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2) spatial_dist = torch.cdist(yx, yx) spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins) torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1) temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins) torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1) idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head)) bias = bias.transpose(-1, 1) return bias class RelativePositionalAttention(nn.Module): def __init__( self, coord_dim: int, embed_dim: int, n_head: int, cutoff_spatial: float = 256, cutoff_temporal: float = 16, n_spatial: int = 32, n_temporal: int = 16, dropout: float = 0.0, mode: Literal["bias", "rope", "none"] = "bias", attn_dist_mode: str = "v0", ): super().__init__() if not embed_dim % (2 * n_head) == 0: raise ValueError( f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}" ) self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True) self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True) self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True) self.proj = nn.Linear(embed_dim, embed_dim) self.dropout = dropout self.n_head = n_head self.embed_dim = embed_dim self.cutoff_spatial = cutoff_spatial self.attn_dist_mode = attn_dist_mode if mode == "bias" or mode is True: self.pos_bias = RelativePositionalBias( n_head=n_head, cutoff_spatial=cutoff_spatial, cutoff_temporal=cutoff_temporal, n_spatial=n_spatial, n_temporal=n_temporal, ) elif mode == "rope": n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head)) self.rot_pos_enc = RotaryPositionalEncoding( cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim), n_pos=(embed_dim // n_head - coord_dim * n_split,) + (n_split,) * coord_dim, ) elif mode == "none": pass elif mode is None or mode is False: logger.warning( "attn_positional_bias is not set (None or False), no positional bias." ) else: raise ValueError(f"Unknown mode {mode}") self._mode = mode def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, coords: torch.Tensor, padding_mask: torch.Tensor = None, ): B, N, D = query.size() q = self.q_pro(query) k = self.k_pro(key) v = self.v_pro(value) k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) attn_mask = torch.zeros( (B, self.n_head, N, N), device=query.device, dtype=q.dtype ) attn_ignore_val = -1e3 yx = coords[..., 1:] spatial_dist = torch.cdist(yx, yx) spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) attn_mask.masked_fill_(spatial_mask, attn_ignore_val) if coords is not None: if self._mode == "bias": attn_mask = attn_mask + self.pos_bias(coords) elif self._mode == "rope": q, k = self.rot_pos_enc(q, k, coords) if self.attn_dist_mode == "v0": dist = torch.cdist(coords, coords, p=2) attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) elif self.attn_dist_mode == "v1": attn_mask += torch.exp( -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial ) else: raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}") if padding_mask is not None: ignore_mask = torch.logical_or( padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) ).unsqueeze(1) attn_mask.masked_fill_(ignore_mask, attn_ignore_val) y = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0 ) y = y.transpose(1, 2).contiguous().view(B, N, D) y = self.proj(y) return y class EncoderLayer(nn.Module): def __init__( self, coord_dim: int = 2, d_model=256, num_heads=4, dropout=0.1, cutoff_spatial: int = 256, window: int = 16, positional_bias: Literal["bias", "rope", "none"] = "bias", positional_bias_n_spatial: int = 32, attn_dist_mode: str = "v0", ): super().__init__() self.positional_bias = positional_bias self.attn = RelativePositionalAttention( coord_dim, d_model, num_heads, cutoff_spatial=cutoff_spatial, n_spatial=positional_bias_n_spatial, cutoff_temporal=window, n_temporal=window, dropout=dropout, mode=positional_bias, attn_dist_mode=attn_dist_mode, ) self.mlp = FeedForward(d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) def forward( self, x: torch.Tensor, coords: torch.Tensor, padding_mask: torch.Tensor = None, ): x = self.norm1(x) # setting coords to None disables positional bias a = self.attn( x, x, x, coords=coords if self.positional_bias else None, padding_mask=padding_mask, ) x = x + a x = x + self.mlp(self.norm2(x)) return x class DecoderLayer(nn.Module): def __init__( self, coord_dim: int = 2, d_model=256, num_heads=4, dropout=0.1, window: int = 16, cutoff_spatial: int = 256, positional_bias: Literal["bias", "rope", "none"] = "bias", positional_bias_n_spatial: int = 32, attn_dist_mode: str = "v0", ): super().__init__() self.positional_bias = positional_bias self.attn = RelativePositionalAttention( coord_dim, d_model, num_heads, cutoff_spatial=cutoff_spatial, n_spatial=positional_bias_n_spatial, cutoff_temporal=window, n_temporal=window, dropout=dropout, mode=positional_bias, attn_dist_mode=attn_dist_mode, ) self.mlp = FeedForward(d_model) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) def forward( self, x: torch.Tensor, y: torch.Tensor, coords: torch.Tensor, padding_mask: torch.Tensor = None, ): x = self.norm1(x) y = self.norm2(y) # cross attention # setting coords to None disables positional bias a = self.attn( x, y, y, coords=coords if self.positional_bias else None, padding_mask=padding_mask, ) x = x + a x = x + self.mlp(self.norm3(x)) return x class TrackingTransformer(torch.nn.Module): def __init__( self, coord_dim: int = 3, feat_dim: int = 0, d_model: int = 128, nhead: int = 4, num_encoder_layers: int = 4, num_decoder_layers: int = 4, dropout: float = 0.1, pos_embed_per_dim: int = 32, feat_embed_per_dim: int = 1, window: int = 6, spatial_pos_cutoff: int = 256, attn_positional_bias: Literal["bias", "rope", "none"] = "rope", attn_positional_bias_n_spatial: int = 16, causal_norm: Literal[ "none", "linear", "softmax", "quiet_softmax" ] = "quiet_softmax", attn_dist_mode: str = "v0", ): super().__init__() self.config = dict( coord_dim=coord_dim, feat_dim=feat_dim, pos_embed_per_dim=pos_embed_per_dim, d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, window=window, dropout=dropout, attn_positional_bias=attn_positional_bias, attn_positional_bias_n_spatial=attn_positional_bias_n_spatial, spatial_pos_cutoff=spatial_pos_cutoff, feat_embed_per_dim=feat_embed_per_dim, causal_norm=causal_norm, attn_dist_mode=attn_dist_mode, ) # TODO remove, alredy present in self.config # self.window = window # self.feat_dim = feat_dim # self.coord_dim = coord_dim self.proj = nn.Linear( (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model ) self.norm = nn.LayerNorm(d_model) self.encoder = nn.ModuleList([ EncoderLayer( coord_dim, d_model, nhead, dropout, window=window, cutoff_spatial=spatial_pos_cutoff, positional_bias=attn_positional_bias, positional_bias_n_spatial=attn_positional_bias_n_spatial, attn_dist_mode=attn_dist_mode, ) for _ in range(num_encoder_layers) ]) self.decoder = nn.ModuleList([ DecoderLayer( coord_dim, d_model, nhead, dropout, window=window, cutoff_spatial=spatial_pos_cutoff, positional_bias=attn_positional_bias, positional_bias_n_spatial=attn_positional_bias_n_spatial, attn_dist_mode=attn_dist_mode, ) for _ in range(num_decoder_layers) ]) self.head_x = FeedForward(d_model) self.head_y = FeedForward(d_model) if feat_embed_per_dim > 1: self.feat_embed = PositionalEncoding( cutoffs=(1000,) * feat_dim, n_pos=(feat_embed_per_dim,) * feat_dim, cutoffs_start=(0.01,) * feat_dim, ) else: self.feat_embed = nn.Identity() self.pos_embed = PositionalEncoding( cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim, n_pos=(pos_embed_per_dim,) * (1 + coord_dim), ) # self.pos_embed = NoPositionalEncoding(d=pos_embed_per_dim * (1 + coord_dim)) # @profile def forward(self, coords, features=None, padding_mask=None, attn_feat=None): assert coords.ndim == 3 and coords.shape[-1] in (3, 4) _B, _N, _D = coords.shape # disable padded coords (such that it doesnt affect minimum) if padding_mask is not None: coords = coords.clone() coords[padding_mask] = coords.max() # remove temporal offset min_time = coords[:, :, :1].min(dim=1, keepdims=True).values coords = coords - min_time pos = self.pos_embed(coords) if features is None or features.numel() == 0: features = pos else: features = self.feat_embed(features) features = torch.cat((pos, features), axis=-1) features = self.proj(features) if attn_feat is not None: # add attention embedding features = features + attn_feat features = self.norm(features) x = features # encoder for enc in self.encoder: x = enc(x, coords=coords, padding_mask=padding_mask) y = features # decoder w cross attention for dec in self.decoder: y = dec(y, x, coords=coords, padding_mask=padding_mask) # y = dec(y, y, coords=coords, padding_mask=padding_mask) x = self.head_x(x) y = self.head_y(y) # outer product is the association matrix (logits) A = torch.einsum("bnd,bmd->bnm", x, y) return A def normalize_output( self, A: torch.FloatTensor, timepoints: torch.LongTensor, coords: torch.FloatTensor, ) -> torch.FloatTensor: """Apply (parental) softmax, or elementwise sigmoid. Args: A: Tensor of shape B, N, N timepoints: Tensor of shape B, N coords: Tensor of shape B, N, (time + n_spatial) """ assert A.ndim == 3 assert timepoints.ndim == 2 assert coords.ndim == 3 assert coords.shape[2] == 1 + self.config["coord_dim"] # spatial distances dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:]) invalid = dist > self.config["spatial_pos_cutoff"] if self.config["causal_norm"] == "none": # Spatially distant entries are set to zero A = torch.sigmoid(A) A[invalid] = 0 else: return torch.stack([ blockwise_causal_norm( _A, _t, mode=self.config["causal_norm"], mask_invalid=_m ) for _A, _t, _m in zip(A, timepoints, invalid) ]) return A def save(self, folder): folder = Path(folder) folder.mkdir(parents=True, exist_ok=True) yaml.safe_dump(self.config, open(folder / "config.yaml", "w")) torch.save(self.state_dict(), folder / "model.pt") @classmethod def from_folder( cls, folder, map_location=None, checkpoint_path: str = "model.pt" ): folder = Path(folder) config = yaml.load(open(folder / "config.yaml"), Loader=yaml.FullLoader) model = cls(**config) fpath = folder / checkpoint_path logger.info(f"Loading model state from {fpath}") state = torch.load(fpath, map_location=map_location, weights_only=True) # if state is a checkpoint, we have to extract state_dict if "state_dict" in state: state = state["state_dict"] state = OrderedDict( (k[6:], v) for k, v in state.items() if k.startswith("model.") ) model.load_state_dict(state) return model @classmethod def from_cfg( cls, cfg_path ): cfg_path = Path(cfg_path) config = yaml.load(open(cfg_path), Loader=yaml.FullLoader) model = cls(**config) return model