File size: 5,742 Bytes
0218456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum

import pprint

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

# Default batch size for data loading/collation.
# NOTE(review): not referenced within this chunk — presumably consumed by
# training/inference code elsewhere; confirm before removing.
BATCH_SIZE = 16


class ModelURI(StrEnum):
    """Hugging Face Hub identifiers of the supported ModernBERT checkpoints."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One sense-inventory entry: a concept paired with its gloss text.

    CandidateLabeller indexes these entries as ``concept -> definition``.
    """

    concept: str  # concept/sense identifier, used as the lookup key
    definition: str  # gloss text describing the concept


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A tokenized, padded batch of model inputs.

    NOTE(review): not referenced elsewhere in this chunk — presumably
    consumed by code outside this file; verify against callers.
    """

    input_ids: torch.Tensor  # token ids; presumably (batch, seq) — TODO confirm
    attention_mask: torch.Tensor  # padding mask matching input_ids


class DisamBertSingleSense(PreTrainedModel):
    """ModernBERT-based word-sense disambiguation model.

    Each input row encodes a sentence whose target entity is wrapped in
    start/end marker tokens, paired with candidate glosses each preceded by
    a gloss marker token.  A candidate's score is the dot product between
    the summed entity-span hidden vectors and the hidden vector at that
    candidate's gloss-marker position.
    """

    def __init__(self, config: PreTrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # Fresh build from a pretrained checkpoint: load the encoder,
            # then grow the vocabulary by the 3 marker tokens
            # (start / end / gloss) whose ids are recorded via
            # add_special_tokens().
            self.BaseModel = AutoModel.from_pretrained(
                config.name_or_path,
                attn_implementation="flash_attention_2",
                dtype=torch.bfloat16,
                device_map="auto",
            )
            self.config.vocab_size += 3
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Rebuild from a saved config (e.g. from_pretrained on this
            # class): construct the architecture only; weights are loaded
            # afterwards by the transformers machinery.
            self.BaseModel = ModernBertModel(config)
        # Any config serialized from now on must rebuild the bare
        # architecture instead of re-downloading/re-resizing the base model.
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: initialise from a ModernBERT checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int):
        """Record the ids of the entity-start, entity-end and gloss markers."""
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        # BUGFIX (annotation): was Iterable[int] | None, but the value is
        # handed straight to nn.CrossEntropyLoss, which requires a tensor.
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score every candidate gloss for each example in the batch.

        Args:
            input_ids: (batch, seq) token ids; one start/end marker pair per
                row and one gloss marker per candidate.
            attention_mask: (batch, seq) padding mask for the encoder.
            labels: optional (batch,) gold candidate indices; when given,
                a cross-entropy loss over the candidate logits is returned.
            output_hidden_states: forwarded to the base encoder.
            output_attentions: forwarded to the base encoder.

        Returns:
            TokenClassifierOutput with ``logits`` of shape
            (batch, max_candidates) and ``loss`` when labels are supplied.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # {0,1} mask over the inclusive start..end entity span of each row;
        # zip(strict=True) enforces exactly one end marker per start marker.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum hidden vectors over the entity span -> (batch, hidden).
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(token_vectors, input_ids)
        # Dot product of each entity vector with each of its candidates'
        # gloss vectors -> (batch, max_candidates).  Zero-padded candidate
        # slots score exactly 0 (not -inf) — unchanged from the original.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self, token_vectors: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Gather the hidden vector at every gloss-marker position.

        Returns a (batch, max_candidates, hidden) tensor; rows with fewer
        candidates than the batch maximum are right-padded with zero vectors.
        """
        with self.device:
            selection = input_ids == self.config.gloss_token
            candidates_per_row = selection.sum(dim=1)
            max_candidates = int(candidates_per_row.max())
            # Flatten (batch, seq, hidden) -> (batch*seq, hidden) so the
            # boolean mask gathers all gloss vectors in one indexing op.
            flat_vectors = token_vectors.reshape(-1, token_vectors.shape[2])
            gloss_vectors = flat_vectors[torch.flatten(selection)]
            # Re-split per row and zero-pad to a rectangular tensor.
            # BUGFIX: pad in the chunk's own dtype/device instead of a
            # hard-coded bfloat16 — the hard-coded dtype broke torch.cat
            # whenever the encoder was not loaded in bfloat16 (e.g. the
            # ModernBertModel(config) re-instantiation path).
            padded = [
                torch.cat([
                    chunk,
                    torch.zeros(
                        (max_candidates - chunk.shape[0], chunk.shape[1]),
                        dtype=chunk.dtype,
                        device=chunk.device,
                    ),
                ])
                for chunk in torch.split(gloss_vectors, candidates_per_row.tolist())
            ]
            return torch.stack(padded)


class CandidateLabeller:
    """Batch collator pairing input texts with their candidate glosses.

    Tokenizes each text against the newline-joined gloss texts of its
    candidate concepts and, for labelled data, emits the index of the gold
    concept within its candidate list.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer,
                 # Generalized from Generator[LexicalExample]: any iterable
                 # works, since the ontology is consumed exactly once below.
                 ontology: Iterable[LexicalExample],
                 device: torch.device,
                 retain_candidates: bool = False):
        """Index the ontology into a concept -> definition gloss table.

        Args:
            tokenizer: tokenizer matching the disambiguation model.
            ontology: sense-inventory entries; consumed once here.
            device: device context under which collated tensors are built.
            retain_candidates: keep raw candidate lists in the output
                (useful at inference time to map logits back to concepts).
        """
        self.tokenizer = tokenizer
        self.device = device
        self.glosses = {example.concept: example.definition for example in ontology}
        self.retain_candidates = retain_candidates

    # BUGFIX (annotation): was `batch: list[dict]`, but the body indexes
    # batch["text"] / batch["candidates"] — it is a column-oriented mapping
    # (datasets-style batched dict), not a list of rows.
    def __call__(self, batch: dict[str, list]) -> dict:
        """Collate one batch.

        Args:
            batch: mapping with keys "text" and "candidates", and
                optionally "label" for supervised data.

        Returns:
            dict with "input_ids" and "attention_mask"; plus "labels"
            (gold candidate index per row) when labels are present, and
            "candidates" when retain_candidates is set.

        Raises:
            ValueError: if a gold label is absent from its candidate list
                (via list.index).
            KeyError: if a candidate concept is missing from the ontology.
        """
        with self.device:
            # Second tokenizer segment: all candidate glosses, one per line.
            glosses = [
                "\n".join(self.glosses[candidate] for candidate in example)
                for example in batch["candidates"]
            ]
            tokens = self.tokenizer(batch["text"], glosses, padding=True, return_tensors="pt")
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
            }
            if "label" in batch:
                # Gold target = position of the label within its row's
                # candidate list; strict zip guards against ragged columns.
                result["labels"] = torch.tensor(
                    [candidates.index(label)
                     for candidates, label in zip(batch["candidates"],
                                                  batch["label"],
                                                  strict=True)]
                )
            if self.retain_candidates:
                result["candidates"] = batch["candidates"]
            return result