"""Word-sense disambiguation on top of ModernBERT.

``DisamBertSingleSense`` scores a marked span of text against a set of
candidate glosses that follow it in the same input sequence, and
``CandidateLabeller`` turns raw batches into the tensors the model consumes.
"""

from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
import pprint

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    """Hub identifiers of the supported ModernBERT checkpoints."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    """A concept identifier paired with its textual definition (gloss)."""

    concept: str
    definition: str


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A pair of ``input_ids`` / ``attention_mask`` tensors."""

    input_ids: torch.Tensor
    attention_mask: torch.Tensor


class DisamBertSingleSense(PreTrainedModel):
    """ModernBERT encoder that ranks candidate glosses for one marked span.

    The input sequence is expected to contain a start-marker token, an
    end-marker token and one gloss-marker token per candidate; their ids are
    read from ``config.start_token``, ``config.end_token`` and
    ``config.gloss_token`` (see :meth:`add_special_tokens`).
    """

    def __init__(self, config: PreTrainedConfig):
        """Build the encoder, either from pretrained weights or from config.

        Args:
            config: Model configuration. When ``config.init_basemodel`` is
                truthy the encoder is loaded from ``config.name_or_path`` and
                its embedding table is grown by three rows; otherwise a
                ``ModernBertModel`` is constructed from ``config`` alone.
        """
        super().__init__(config)
        # A config that never had the flag set is treated as "do not load
        # pretrained weights" instead of raising AttributeError.
        if getattr(config, "init_basemodel", False):
            self.BaseModel = AutoModel.from_pretrained(
                config.name_or_path,
                attn_implementation="flash_attention_2",
                dtype=torch.bfloat16,
                device_map="auto",
            )
            # Three extra embedding rows; presumably for the start, end and
            # gloss marker tokens -- TODO confirm against the tokenizer setup.
            self.config.vocab_size += 3
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            self.BaseModel = ModernBertModel(config)
        # Clear the flag so that a model rebuilt from this config takes the
        # config-only branch above.
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBertSingleSense":
        """Create a model whose encoder is loaded from the checkpoint ``base_id``."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int):
        """Record the ids of the start, end and gloss marker tokens on the config."""
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score every candidate gloss against the marked span of each row.

        Args:
            input_ids: Token ids, indexed as (row, position).
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional targets passed to ``nn.CrossEntropyLoss``
                together with the logits.
            output_hidden_states: Forwarded to the encoder; also controls
                whether hidden states are returned.
            output_attentions: Forwarded to the encoder; also controls
                whether attentions are returned.

        Returns:
            A ``TokenClassifierOutput`` whose ``logits`` hold one score per
            (row, candidate) and whose ``loss`` is set only when ``labels``
            is given.

        Raises:
            ValueError: If the number of start markers differs from the
                number of end markers (``zip(..., strict=True)``).
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state

        # Mask with 1.0 on every position from a start marker to its end
        # marker (both inclusive), 0.0 elsewhere.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        # The i-th start is paired with the i-th end; the row index is taken
        # from the start marker only.
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0

        # Sum (not mean) of the token vectors inside the marked span.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(token_vectors, input_ids)
        # Dot product between the span vector and each candidate vector.
        # NOTE(review): padded candidate slots are zero vectors, so their
        # logit is exactly 0 and they still take part in the cross-entropy
        # -- confirm this is intended.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states
            if output_hidden_states
            else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(
        self, token_vectors: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Collect the encoder vectors at the gloss-marker positions.

        Args:
            token_vectors: Encoder output, indexed as (row, position, feature).
            input_ids: Token ids, indexed as (row, position).

        Returns:
            A tensor indexed as (row, candidate, feature). Rows with fewer
            gloss markers than the largest row are padded with zero vectors.
        """
        selection = input_ids == self.config.gloss_token
        candidates_per_row = selection.sum(dim=1)
        max_candidates = int(candidates_per_row.max())
        # Boolean-mask indexing returns the selected vectors in row-major
        # order, so consecutive runs belong to consecutive rows.
        selected = token_vectors[selection]
        chunks = torch.split(selected, candidates_per_row.tolist())
        # Padding inherits dtype and device from the vectors it extends
        # rather than assuming bfloat16 on ``self.device``.
        return torch.stack(
            [
                torch.cat(
                    [
                        chunk,
                        chunk.new_zeros(
                            (max_candidates - chunk.shape[0], chunk.shape[1])
                        ),
                    ]
                )
                for chunk in chunks
            ]
        )


class CandidateLabeller:
    """Batch transform that tokenizes texts together with candidate glosses."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        ontology: Iterable[LexicalExample],
        device: torch.device,
        retain_candidates: bool = False,
    ):
        """Store the tokenizer and index the ontology by concept.

        Args:
            tokenizer: Tokenizer applied to (text, glosses) pairs.
            ontology: Examples providing the definition of each concept; it
                is consumed once to build the lookup table.
            device: Device used as the default for tensors created in
                :meth:`__call__`.
            retain_candidates: When true, the ``candidates`` column is copied
                into the output.
        """
        self.tokenizer = tokenizer
        self.device = device
        self.glosses = {example.concept: example.definition for example in ontology}
        self.retain_candidates = retain_candidates

    def __call__(self, batch: dict[str, list]) -> dict:
        """Convert a batch of columns into model inputs.

        Args:
            batch: Mapping with a ``"text"`` column, a ``"candidates"``
                column (one sequence of concept ids per example) and
                optionally a ``"label"`` column.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"``; plus
            ``"labels"`` (index of the label within its candidates) when the
            batch has a ``"label"`` column, and ``"candidates"`` when
            ``retain_candidates`` is set.

        Raises:
            KeyError: If a candidate is missing from the ontology.
            ValueError: If a label is not among its example's candidates, or
                the ``candidates`` and ``label`` columns differ in length.
        """
        with self.device:
            # One string per example: the candidate definitions, one per line.
            glosses = [
                "\n".join(self.glosses[candidate] for candidate in example)
                for example in batch["candidates"]
            ]
            tokens = self.tokenizer(
                batch["text"], glosses, padding=True, return_tensors="pt"
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
            }
            if "label" in batch:
                result["labels"] = torch.tensor(
                    [
                        candidates.index(label)
                        for (candidates, label) in zip(
                            batch["candidates"], batch["label"], strict=True
                        )
                    ]
                )
            if self.retain_candidates:
                result["candidates"] = batch["candidates"]
            return result