| from collections.abc import Generator, Iterable |
| from dataclasses import dataclass |
| from enum import StrEnum |
|
|
| import pprint |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| AutoConfig, |
| AutoModel, |
| BatchEncoding, |
| ModernBertModel, |
| PreTrainedConfig, |
| PreTrainedModel, |
| PreTrainedTokenizer, |
| ) |
| from transformers.modeling_outputs import TokenClassifierOutput |
|
|
# Default batch size. NOTE(review): not referenced anywhere in this chunk —
# confirm it is consumed by code outside this view.
BATCH_SIZE = 16
|
|
|
|
| class ModelURI(StrEnum): |
| BASE = "answerdotai/ModernBERT-base" |
| LARGE = "answerdotai/ModernBERT-large" |
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class LexicalExample: |
| concept: str |
| definition: str |
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class PaddedBatch: |
| input_ids: torch.Tensor |
| attention_mask: torch.Tensor |
|
|
|
|
class DisamBertSingleSense(PreTrainedModel):
    """ModernBERT-based word-sense disambiguation model.

    Scores each candidate sense of a marked entity span by dotting a
    sum-pooled span representation against position-0 ([CLS]-style) pooled
    gloss representations produced by the same base encoder.
    """

    def __init__(self, config: PreTrainedConfig):
        """Build the model from *config*.

        When ``config.init_basemodel`` is truthy, the base encoder is loaded
        from pretrained weights and its embedding matrix is grown by two rows
        for the entity start/end marker tokens. Otherwise (e.g. when loading a
        saved checkpoint whose vocab already includes the markers) a fresh
        ``ModernBertModel`` is built from the config alone.
        """
        super().__init__(config)
        if config.init_basemodel:
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path,
                                                       device_map="auto")
            # Two extra embedding rows for the span-marker special tokens.
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            self.BaseModel = ModernBertModel(config)
        # Saved configs must not re-trigger the pretrained load + resize path.
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: initialize from a pretrained base checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int):
        """Record the token ids that delimit the target entity span."""
        self.config.start_token = start
        self.config.end_token = end

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        candidate_tokens: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
        # Fixed annotation: CrossEntropyLoss requires a tensor target, and the
        # CandidateLabeller collator supplies torch.tensor(...) here.
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score each sentence's candidate senses.

        ``candidate_mapping`` assigns each row of ``candidate_tokens`` to its
        sentence index in ``input_ids``. Returns logits of shape
        (sentences, max_candidates); padded candidate slots score via
        zero gloss vectors. Raises ValueError (via ``zip(strict=True)``) if
        start/end markers are unbalanced in the batch.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Binary mask over tokens: 1 inside each marked span (markers included).
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum-pool the span token vectors into one entity vector per sentence.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            candidate_tokens, candidate_attention_masks, candidate_mapping
        )
        # Dot product of each entity vector against its candidate gloss vectors.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self, candidates, candidate_attention_masks, candidate_mapping):
        """Encode candidate glosses and group them per sentence, zero-padded.

        Returns a tensor of shape (sentences, max_candidates, hidden_size);
        slots beyond a sentence's candidate count are zero vectors. Raises
        ValueError (from ``max`` on an empty sequence) if no candidates exist.
        """
        # torch.device as a context manager: new tensors default to self.device.
        with self.device:
            # Take the position-0 ([CLS]-style) hidden state of each gloss.
            vectors = self.BaseModel(candidates, candidate_attention_masks).last_hidden_state[:, 0]
            chunks = [
                torch.squeeze(vectors[(candidate_mapping == sentence_index).nonzero()], dim=1)
                for sentence_index in torch.unique(candidate_mapping)
            ]
            maxlen = max(chunk.shape[0] for chunk in chunks)
            return torch.stack(
                [
                    torch.cat(
                        [
                            chunk,
                            # Fixed: match the encoder dtype so padding also
                            # works under reduced precision (default fp32
                            # behavior is unchanged).
                            torch.zeros(
                                (maxlen - chunk.shape[0], self.config.hidden_size),
                                dtype=vectors.dtype,
                            ),
                        ]
                    )
                    for chunk in chunks
                ]
            )
|
|
|
|
class CandidateLabeller:
    """Collator that turns tokenised examples into model-ready padded batches.

    Pre-tokenises every gloss in the ontology once, then on each call pads the
    example token ids, pads the glosses of each example's candidate concepts,
    and builds the candidate-to-sentence mapping whose keys match the
    parameters of ``DisamBertSingleSense.forward``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        # Only iterated once here, so any iterable of examples works.
        ontology: Iterable[LexicalExample],
        device: torch.device,
        retain_candidates: bool = False,
    ):
        """Cache the tokenizer, target device, and per-concept gloss tokens.

        ``retain_candidates`` keeps the raw candidate lists in the collated
        output (useful when the caller needs to map logits back to concepts).
        """
        self.tokenizer = tokenizer
        self.device = device
        # Tokenise each gloss once up front, keyed by its concept id.
        self.gloss_tokens = {
            example.concept: self.tokenizer(example.definition, padding=True)
            for example in ontology
        }
        self.retain_candidates = retain_candidates

    def __call__(self, batch: list[dict]) -> dict:
        """Collate a list of example dicts into padded tensors.

        Each example must provide ``input_ids``, ``attention_mask`` and
        ``candidates``; when ``label`` is present (checked on the first
        example only), a ``labels`` tensor of per-example candidate indices is
        added. Raises ValueError (from ``list.index``) if a label is not among
        its example's candidates.
        """
        # torch.device as a context manager: tensors below land on self.device.
        with self.device:
            encoded = [
                BatchEncoding(
                    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
                )
                for example in batch
            ]
            tokens = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")
            # Flatten every example's candidate glosses into one padded batch.
            candidate_tokens = self.tokenizer.pad(
                [
                    self.gloss_tokens[concept]
                    for example in batch
                    for concept in example["candidates"]
                ],
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
                "candidate_tokens": candidate_tokens.input_ids,
                "candidate_attention_masks": candidate_tokens.attention_mask,
                # Sentence index for each flattened candidate row, in the same
                # order as the flattening above.
                "candidate_mapping": torch.cat(
                    [
                        torch.tensor([i] * len(example["candidates"]))
                        for (i, example) in enumerate(batch)
                    ]
                ),
            }
            if "label" in batch[0]:
                # Gold label as an index into each example's candidate list.
                result["labels"] = torch.tensor(
                    [example["candidates"].index(example["label"]) for example in batch]
                )
            if self.retain_candidates:
                result["candidates"] = [example["candidates"] for example in batch]
            return result
|
|