import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
from tqdm import tqdm
from .colbert_configuration import ColBERTConfig
from .tokenization_utils import QueryTokenizer, DocTokenizer
import os

class NullContextManager(object):
    """No-op stand-in for an autocast context when mixed precision is disabled."""

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *args):
        pass

class MixedPrecisionManager():
    """Thin wrapper around torch.amp autocast/GradScaler; runs in full precision when deactivated."""

    def __init__(self, activated):
        self.activated = activated

        if self.activated:
            self.scaler = torch.amp.GradScaler("cuda")

    def context(self):
        return torch.amp.autocast("cuda") if self.activated else NullContextManager()

    def backward(self, loss):
        if self.activated:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def step(self, colbert, optimizer, scheduler=None):
        if self.activated:
            # Unscale before clipping so the norm is computed on the true gradients.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)

            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        optimizer.zero_grad()

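
# Illustrative sketch (not part of the original module): how MixedPrecisionManager is
# intended to be driven inside a training step. `model`, `optimizer`, `scheduler`,
# `loader`, and `compute_loss` below are hypothetical placeholders for the caller's
# own objects.
#
#     amp = MixedPrecisionManager(activated=torch.cuda.is_available())
#     for batch in loader:
#         with amp.context():
#             loss = compute_loss(model, batch)
#         amp.backward(loss)
#         amp.step(model, optimizer, scheduler)
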
class ConstBERT(BertPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.

    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
    """
    _keys_to_ignore_on_load_unexpected = [r"cls"]

    def __init__(self, config, colbert_config, verbose: int = 0):
        super().__init__(config)

        self.config = config
        self.colbert_config = colbert_config  # referenced later by encode_documents (keep_dims='flatten')
        self.dim = colbert_config.dim
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)
        # Learned pooling over the padded token axis: doc_maxlen -> 32 and query_maxlen -> 64 output slots.
        self.doc_project = nn.Linear(colbert_config.doc_maxlen, 32, bias=False)
        self.query_project = nn.Linear(colbert_config.query_maxlen, 64, bias=False)

        # Fetch tokenizer files from the Hub when they are not present locally.
        if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer.json")):
            hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer.json")
        if not os.path.exists(os.path.join(colbert_config.name_or_path, "vocab.txt")):
            hf_hub_download(repo_id=colbert_config.name_or_path, filename="vocab.txt")
        if not os.path.exists(os.path.join(colbert_config.name_or_path, "tokenizer_config.json")):
            hf_hub_download(repo_id=colbert_config.name_or_path, filename="tokenizer_config.json")
        if not os.path.exists(os.path.join(colbert_config.name_or_path, "special_tokens_map.json")):
            hf_hub_download(repo_id=colbert_config.name_or_path, filename="special_tokens_map.json")

        self.query_tokenizer = QueryTokenizer(colbert_config, verbose=verbose)
        self.doc_tokenizer = DocTokenizer(colbert_config)
        self.amp_manager = MixedPrecisionManager(True)

        self.raw_tokenizer = AutoTokenizer.from_pretrained(colbert_config.checkpoint)
        self.pad_token = self.raw_tokenizer.pad_token_id
        self.use_gpu = colbert_config.total_visible_gpus > 0

        # Instantiate the underlying BERT encoder under the expected attribute name (e.g. `self.bert`).
        setattr(self, self.base_model_prefix, BertModel(config))

        self.init_weights()

    @property
    def LM(self):
        base_model_prefix = getattr(self, "base_model_prefix")
        return getattr(self, base_model_prefix)

    @classmethod
    def from_pretrained(cls, name_or_path, config=None, *args, **kwargs):
        colbert_config = ColBERTConfig(name_or_path)
        colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)
        obj = super().from_pretrained(name_or_path, colbert_config=colbert_config, config=config)
        obj.base = name_or_path

        return obj

    @staticmethod
    def raw_tokenizer_from_pretrained(name_or_path):
        obj = AutoTokenizer.from_pretrained(name_or_path)
        obj.base = name_or_path

        return obj

    def _query(self, input_ids, attention_mask):
        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        # Zero out padding positions, then L2-normalise each token embedding.
        mask = torch.tensor(self.mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
        Q = Q * mask

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def _doc(self, input_ids, attention_mask, keep_dims=True):
        assert keep_dims in [True, False, 'return_mask']

        input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
        D = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Pool the token axis to a constant number of vectors: (B, L, H) -> (B, H, L),
        # project L (doc_maxlen) down to 32 slots, restore (B, 32, H), then reduce the
        # hidden size to the output embedding dim.
        D = D.permute(0, 2, 1)
        D = self.doc_project(D)
        D = D.permute(0, 2, 1)
        D = self.linear(D)

        # After pooling every slot is kept, so the mask is all ones.
        mask = torch.ones(D.shape[0], D.shape[1], device=self.device).unsqueeze(2).float()

        D = D * mask
        D = torch.nn.functional.normalize(D, p=2, dim=2)
        if self.use_gpu:
            D = D.half()

        if keep_dims is False:
            D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
            D = [d[mask[idx]] for idx, d in enumerate(D)]

        elif keep_dims == 'return_mask':
            return D, mask.bool()

        return D

    def mask(self, input_ids, skiplist):
        mask = [[(x not in skiplist) and (x != self.pad_token) for x in d] for d in input_ids.cpu().tolist()]
        return mask

    def query(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self._query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                D = self._doc(*args, **kw_args)

                if to_cpu:
                    return (D[0].cpu(), *D[1:]) if isinstance(D, tuple) else D.cpu()

                return D

    def encode_queries(self, queries, bsize=None, to_cpu=False, context=None, full_length_search=False):
        if type(queries) == str:
            queries = [queries]
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, context=context, bsize=bsize, full_length_search=full_length_search)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries, context=context, full_length_search=full_length_search)
        return self.query(input_ids, attention_mask)

    def encode_documents(self, docs, bsize=None, keep_dims=True, to_cpu=False, showprogress=False, return_tokens=False):
        if type(docs) == str:
            docs = [docs]
        assert keep_dims in [True, False, 'flatten']

        if bsize:
            text_batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            returned_text = []
            if return_tokens:
                returned_text = [text for batch in text_batches for text in batch[0]]
                returned_text = [returned_text[idx] for idx in reverse_indices.tolist()]
                returned_text = [returned_text]

            keep_dims_ = 'return_mask' if keep_dims == 'flatten' else keep_dims
            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims_, to_cpu=to_cpu)
                       for input_ids, attention_mask in tqdm(text_batches, disable=not showprogress)]

            if keep_dims is True:
                D = _stack_3D_tensors(batches)
                return (D[reverse_indices], *returned_text)

            elif keep_dims == 'flatten':
                D, mask = [], []

                for D_, mask_ in batches:
                    D.append(D_)
                    mask.append(mask_)

                D, mask = torch.cat(D)[reverse_indices], torch.cat(mask)[reverse_indices]

                doclens = mask.squeeze(-1).sum(-1).tolist()

                D = D.view(-1, self.colbert_config.dim)
                D = D[mask.bool().flatten()].cpu()

                return (D, doclens, *returned_text)

            assert keep_dims is False

            D = [d for batch in batches for d in batch]
            return ([D[idx] for idx in reverse_indices.tolist()], *returned_text)

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output
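

# Usage sketch (illustrative, not part of the original module). The checkpoint name
# below is a hypothetical placeholder; any local directory or Hub repo that carries
# both the HF weights and a ColBERT-style configuration should work.
if __name__ == "__main__":
    checkpoint = "path/to/constbert-checkpoint"  # hypothetical
    model = ConstBERT.from_pretrained(checkpoint)
    model.eval()

    # Queries keep one vector per (padded) query token; documents are pooled down to a
    # fixed number of vectors by doc_project before scoring.
    Q = model.encode_queries(["what is multi-vector retrieval?"])
    D = model.encode_documents(["ColBERT-style models keep one vector per token."], bsize=8)
    print(Q.shape, D[0].shape)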