""" bert_ordinal.py --------------- BERT-based ordinal regression model, fully integrated with the HuggingFace Transformers API: model.save_pretrained("my-checkpoint/") model = BertOrdinal.from_pretrained("my-checkpoint/") Architecture ------------ 1. A (optionally frozen) BERT backbone. 2. A projection head on the [CLS] token: Linear(hidden_size → hidden_dim) → ReLU → Dropout(p) → Linear(hidden_dim → 1) producing a single latent score s ∈ ℝ. 3. K-1 learnable raw_threshold parameters enforcing monotonicity via cumsum(softplus(·)). 4. Cumulative-link probabilities: P(Y ≤ j | x) = σ(θ_j − s) Usage ----- from bert_ordinal import BertOrdinalConfig, BertOrdinal # ── Create from scratch ────────────────────────────────────────────────── cfg = BertOrdinalConfig( bert_model_name="bert-base-uncased", num_classes=3, hidden_dim=128, dropout=0.1, freeze_bert=True, ) model = BertOrdinal(cfg) # ── Save ──────────────────────────────────────────────────────────────── model.save_pretrained("my-checkpoint/") tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside # ── Reload ────────────────────────────────────────────────────────────── model = BertOrdinal.from_pretrained("my-checkpoint/") tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/") """ from __future__ import annotations from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel, PreTrainedModel from transformers.modeling_outputs import ModelOutput from .configuration_bert_ordinal import BertOrdinalConfig # --------------------------------------------------------------------------- # 1. Output dataclass # --------------------------------------------------------------------------- @dataclass class BertOrdinalOutput(ModelOutput): """ Return type of :class:`BertOrdinal`. Attributes ---------- loss : torch.Tensor or None Ordinal cross-entropy loss (scalar). Present only when ``labels`` are supplied. logits : torch.Tensor (B,) Raw latent score from the projection head. predictions : torch.Tensor (B,) Predicted class index — argmax of ``class_probs``. cum_probs : torch.Tensor (B, K-1) Cumulative probabilities P(Y ≤ j | x). class_probs : torch.Tensor (B, K) Per-class probabilities P(Y = j | x). """ loss: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None predictions: Optional[torch.Tensor] = None cum_probs: Optional[torch.Tensor] = None class_probs: Optional[torch.Tensor] = None # --------------------------------------------------------------------------- # 3. Model — subclass PreTrainedModel for save / from_pretrained # --------------------------------------------------------------------------- class BertOrdinal(PreTrainedModel): """ BERT encoder with an ordinal-regression head. Fully compatible with the HuggingFace checkpoint API:: model.save_pretrained("my-checkpoint/") model = BertOrdinal.from_pretrained("my-checkpoint/") What gets saved ~~~~~~~~~~~~~~~ ``save_pretrained`` writes two files: * ``config.json`` — the full :class:`BertOrdinalConfig` (including ``bert_model_name``, ``hidden_size``, thresholds shape, …). * ``model.safetensors`` (or ``pytorch_model.bin``) — a **single flat state_dict** containing both the BERT backbone weights and the head/threshold parameters. ``from_pretrained`` reconstructs the model from the config (which already has ``hidden_size`` cached), loads the state_dict, and re-applies the ``freeze_bert`` setting — no internet access needed after the first save. 
""" config_class = BertOrdinalConfig def __init__(self, config: BertOrdinalConfig) -> None: super().__init__(config) K = config.num_classes # ── 1. BERT backbone ──────────────────────────────────────────────── # If hidden_size is already in the config (i.e. we are being called # from from_pretrained after a save), build the backbone from the # cached backbone config instead of re-downloading weights — # from_pretrained will overwrite with the saved state_dict anyway. self.bert = AutoModel.from_pretrained(config.bert_model_name) hidden_size: int = self.bert.config.hidden_size # Cache so the head can be rebuilt offline after save_pretrained. config.hidden_size = hidden_size if config.freeze_bert: for param in self.bert.parameters(): param.requires_grad = False # ── 2. Projection head ────────────────────────────────────────────── self.head = nn.Sequential( nn.Linear(hidden_size, config.hidden_dim), nn.ReLU(), nn.Dropout(config.dropout), nn.Linear(config.hidden_dim, 1), ) self._init_head() # ── 3. Ordinal thresholds ─────────────────────────────────────────── # K-1 raw values; monotonicity enforced via cumsum(softplus(·)). self.raw_thresholds = nn.Parameter(torch.zeros(K - 1)) with torch.no_grad(): targets = torch.linspace(-1.0, 1.0, K - 1) diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]]) self.raw_thresholds.copy_( torch.log(torch.expm1(diffs.clamp(min=1e-3))) ) # Finalises weight init bookkeeping required by PreTrainedModel. self.post_init() # ----------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------- def _init_head(self) -> None: for m in self.head.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, nonlinearity="relu") nn.init.zeros_(m.bias) @property def thresholds(self) -> torch.Tensor: """Monotone thresholds θ₁ ≤ … ≤ θ_{K-1} (shape: K-1).""" return torch.cumsum(F.softplus(self.raw_thresholds), dim=0) # ----------------------------------------------------------------------- # Forward # ----------------------------------------------------------------------- def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, **kwargs, ) -> BertOrdinalOutput: """ Parameters ---------- input_ids : (B, L) attention_mask : (B, L) token_type_ids : (B, L) optional labels : (B,) long — class indices in {0, …, K-1} Returns ------- BertOrdinalOutput """ # ── Encode ────────────────────────────────────────────────────────── bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) if token_type_ids is not None: bert_kwargs["token_type_ids"] = token_type_ids cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :] # (B, H) # ── Latent score ──────────────────────────────────────────────────── score = self.head(cls_repr).squeeze(-1) # (B,) # ── Cumulative probs P(Y ≤ j) = σ(θ_j − score) ──────────────────── cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1) # (B, K-1) cum_probs = torch.sigmoid(cum_logits) # (B, K-1) # ── Class probs P(Y = j) = P(Y ≤ j) − P(Y ≤ j-1) ───────────────── B, dev = cum_probs.size(0), cum_probs.device F_ = torch.cat( [torch.zeros(B, 1, device=dev), cum_probs, torch.ones(B, 1, device=dev)], dim=1, ) # (B, K+1) class_probs = (F_[:, 1:] - F_[:, :-1]).clamp(min=1e-9) # (B, K) # ── Predictions ────────────────────────────────────────────────────── predictions = class_probs.argmax(dim=-1) # (B,) # ── Loss 
        # ── Loss ────────────────────────────────────────────────────────────
        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = ordinal_cross_entropy(
                class_probs, labels, reduction=self.config.loss_reduction
            )

        return BertOrdinalOutput(
            loss=loss,
            logits=score,
            predictions=predictions,
            cum_probs=cum_probs,
            class_probs=class_probs,
        )

    # -----------------------------------------------------------------------
    # Convenience
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return predicted class indices (no loss computed)."""
        return self.forward(input_ids, attention_mask, token_type_ids).predictions


# ---------------------------------------------------------------------------
# Loss function
# ---------------------------------------------------------------------------
def ordinal_cross_entropy(
    class_probs: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Ordinal cross-entropy.

    Parameters
    ----------
    class_probs : (B, K) — P(Y=j|x), clamped > 0
    labels : (B,) — ground-truth indices in {0, …, K-1}
    reduction : 'mean' | 'sum' | 'none'
    """
    return F.nll_loss(torch.log(class_probs), labels, reduction=reduction)
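
# ---------------------------------------------------------------------------
# Smoke test: an illustrative sketch, not part of the library API.
# Assumptions: the "bert-base-uncased" checkpoint is reachable (hub or local
# HF cache), BertOrdinalConfig provides a default ``loss_reduction``, and the
# file is executed as a package module (e.g. ``python -m mypkg.bert_ordinal``,
# package name hypothetical) so the relative config import resolves.
# The example sentences and labels are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    from transformers import AutoTokenizer

    cfg = BertOrdinalConfig(
        bert_model_name="bert-base-uncased",
        num_classes=3,
        hidden_dim=128,
        dropout=0.1,
        freeze_bert=True,
    )
    model = BertOrdinal(cfg).eval()
    tokenizer = AutoTokenizer.from_pretrained(cfg.bert_model_name)

    batch = tokenizer(
        ["terrible service", "it was fine", "absolutely fantastic"],
        padding=True,
        return_tensors="pt",
    )
    labels = torch.tensor([0, 1, 2])

    with torch.no_grad():
        out = model(**batch, labels=labels)
    print("latent scores:", out.logits)
    print("class probs  :", out.class_probs)
    print("predictions  :", out.predictions)
    print("loss         :", out.loss)

    # Round-trip through save_pretrained / from_pretrained and verify that the
    # reloaded model reproduces the same class probabilities.
    with tempfile.TemporaryDirectory() as ckpt_dir:
        model.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        reloaded = BertOrdinal.from_pretrained(ckpt_dir).eval()
        with torch.no_grad():
            out2 = reloaded(**batch)
        assert torch.allclose(out.class_probs, out2.class_probs, atol=1e-6)
        print("save/reload round-trip OK")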