File size: 6,105 Bytes
881e6cc
 
 
 
9f5c6cd
 
881e6cc
 
 
 
 
9f5c6cd
881e6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3237a5
 
881e6cc
 
 
 
 
fbbff43
881e6cc
 
 
 
 
 
 
 
 
9f5c6cd
 
 
881e6cc
 
 
 
 
9f5c6cd
 
 
881e6cc
 
 
 
 
 
 
 
 
 
fbbff43
9f5c6cd
 
 
 
 
 
 
 
 
 
881e6cc
 
 
 
 
 
 
9f5c6cd
 
 
 
 
d3237a5
9f5c6cd
 
 
 
 
d3237a5
 
 
9f5c6cd
 
 
 
 
 
d3237a5
 
 
 
 
 
 
9f5c6cd
 
 
 
 
 
7abf6f6
9f5c6cd
 
 
 
 
 
 
 
 
 
 
d3237a5
 
 
 
 
9f5c6cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7abf6f6
d3237a5
9f5c6cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum

import pprint

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

BATCH_SIZE = 16


class ModelURI(StrEnum):
    """Hugging Face Hub checkpoint identifiers for the ModernBERT base models."""
    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"


@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One sense-inventory entry: a concept identifier paired with its gloss text."""
    concept: str  # identifier of the word sense; used as the lookup key for its gloss
    definition: str  # natural-language gloss for the concept (tokenized by CandidateLabeller)


@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A padded, tokenized batch: token IDs plus their attention mask.

    Not referenced elsewhere in this file — presumably consumed by external callers.
    """
    input_ids: torch.Tensor  # assumed (batch, seq_len) token IDs — TODO confirm at call site
    attention_mask: torch.Tensor  # assumed HF convention: 1 = real token, 0 = padding — verify


class DisamBertSingleSense(PreTrainedModel):
    """Word-sense disambiguation head over a ModernBERT encoder.

    An entity span in the input is delimited by two special tokens
    (``config.start_token`` / ``config.end_token``). The span's token vectors
    are sum-pooled into one entity vector, which is scored against the
    first-position ([CLS]) vector of each candidate gloss by dot product.
    """

    def __init__(self, config: PreTrainedConfig):
        """Build the model either from a pretrained checkpoint or from config.

        ``config.init_basemodel`` selects the path: True downloads the
        checkpoint named by ``config.name_or_path`` (first construction);
        False rebuilds the architecture from config alone (reload path).
        """
        super().__init__(config)
        if config.init_basemodel:
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path,
                                                       device_map="auto")
            # Reserve two extra vocabulary slots for the span-marker tokens
            # registered later via add_special_tokens().
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            self.BaseModel = ModernBertModel(config)
        # A saved config must not trigger a fresh checkpoint download on reload.
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: initialise from a ModernBERT Hub checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int):
        """Record the token IDs that delimit the entity span in the input."""
        self.config.start_token = start
        self.config.end_token = end

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        candidate_tokens: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
        labels: Iterable[int] | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score each sentence's candidate glosses against its entity span.

        Args:
            input_ids: sentence token IDs, assumed (batch, seq_len) — confirm.
            attention_mask: padding mask matching ``input_ids``.
            candidate_tokens / candidate_attention_masks: flattened gloss
                encodings for all candidates of all sentences in the batch.
            candidate_mapping: 1-D tensor mapping each gloss row to its
                sentence index within the batch.
            labels: optional per-sentence index of the correct candidate;
                when given, a cross-entropy loss is returned.

        Returns:
            TokenClassifierOutput with ``logits`` of shape
            (batch, max_candidates); padded candidate slots score against a
            zero gloss vector.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Binary mask over the entity span of each sentence. Pairing starts
        # with ends via strict zip assumes exactly one marker pair per
        # sentence, in document order — upstream data must guarantee this.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum-pool the masked token vectors into one entity vector per sentence.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            candidate_tokens, candidate_attention_masks, candidate_mapping
        )
        # Dot product of each entity vector with each of its gloss vectors.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self, candidates, candidate_attention_masks, candidate_mapping):
        """Encode all candidate glosses and regroup them per sentence.

        Returns a tensor of shape (num_sentences, max_candidates,
        hidden_size), zero-padded along the candidate dimension. Assumes
        ``candidate_mapping`` contains every sentence index 0..batch-1
        (torch.unique sorts ascending) — confirm against the collator.
        """
        with self.device:
            # First-position ([CLS]) vector of each gloss encoding.
            vectors = self.BaseModel(candidates, candidate_attention_masks).last_hidden_state[:, 0]
            chunks = [
                torch.squeeze(vectors[(candidate_mapping == sentence_index).nonzero()], dim=1)
                for sentence_index in torch.unique(candidate_mapping)
            ]
            # pad_sequence inherits dtype and device from ``vectors``. The
            # previous torch.zeros(...) padding used the default float32
            # dtype, which made torch.cat fail under half-precision
            # inference.
            return torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True)


class CandidateLabeller:
    """Collate callable: pads a batch of tokenized sentences and attaches the
    pre-tokenized glosses of every candidate sense.

    Every gloss in the ontology is tokenized exactly once, in ``__init__``;
    ``__call__`` only performs dictionary lookups and padding.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        # Fixed annotation: the parameter is only iterated, and single-arg
        # Generator[...] is valid only on Python 3.13+. Iterable is the
        # correct, more general contract.
        ontology: Iterable[LexicalExample],
        device: torch.device,
        retain_candidates: bool = False,
    ):
        self.tokenizer = tokenizer
        self.device = device
        # Pre-tokenized gloss per concept ID, for O(1) lookup per batch.
        self.gloss_tokens = {
            example.concept: self.tokenizer(example.definition, padding=True)
            for example in ontology
        }
        self.retain_candidates = retain_candidates

    # Fixed annotation: ``batch`` is a list of example dicts (the body uses
    # batch[0] and enumerate(batch)), not a single dict.
    def __call__(self, batch: list[dict]) -> dict:
        """Collate a list of examples into padded tensors on ``self.device``.

        Each example dict must carry ``input_ids``, ``attention_mask`` and
        ``candidates`` (list of concept IDs); ``label`` is optional and, when
        present on the first example, is converted to a per-example index
        into that example's candidate list (raises ValueError if absent from
        the candidates).
        """
        with self.device:
            encoded = [
                BatchEncoding(
                    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
                )
                for example in batch
            ]
            tokens = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")
            # Flatten all candidates of all examples into one padded batch.
            candidate_tokens = self.tokenizer.pad(
                [
                    self.gloss_tokens[concept]
                    for example in batch
                    for concept in example["candidates"]
                ],
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
                "candidate_tokens": candidate_tokens.input_ids,
                "candidate_attention_masks": candidate_tokens.attention_mask,
                # Maps each flattened candidate row back to its example index.
                "candidate_mapping": torch.cat(
                    [
                        torch.tensor([i] * len(example["candidates"]))
                        for (i, example) in enumerate(batch)
                    ]
                ),
            }
            if "label" in batch[0]:
                result["labels"] = torch.tensor(
                    [example["candidates"].index(example["label"]) for example in batch]
                )
            if self.retain_candidates:
                result["candidates"] = [example["candidates"] for example in batch]
            return result