| |
|
|
| import html |
| import os |
| import string |
| from typing import List, Optional, Union |
| import warnings |
|
|
| try: |
| import ftfy |
| except ImportError: |
| ftfy = None |
| import torch |
|
|
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| DEFAULT_CONTEXT_LENGTH = 77 |
|
|
|
|
def basic_clean(text):
    """Return `text` with encoding artifacts repaired and HTML entities unescaped.

    Uses ``ftfy.fix_text`` when ftfy is installed; otherwise the text passes
    through unchanged. (The original had a dead `else: text` no-op statement
    here — removed.)
    """
    # Local, optional import keeps this function usable without ftfy installed.
    try:
        import ftfy
    except ImportError:
        ftfy = None
    if ftfy is not None:
        text = ftfy.fix_text(text)
    # Unescape twice to also handle double-encoded entities such as "&amp;amp;".
    text = html.unescape(html.unescape(text))
    return text.strip()
|
|
|
|
def whitespace_clean(text):
    """Collapse every run of whitespace in `text` to a single space and trim."""
    collapsed = " ".join(text.split())
    return collapsed.strip()
|
|
|
|
def _clean_canonicalize(x):
    """Basic-clean `x`, then canonicalize it (lowercase, punctuation stripped)."""
    cleaned = basic_clean(x)
    return canonicalize_text(cleaned)
|
|
|
|
def _clean_lower(x):
    """Basic-clean and whitespace-normalize `x`, then lowercase the result."""
    normalized = whitespace_clean(basic_clean(x))
    return normalized.lower()
|
|
|
|
def _clean_whitespace(x):
    """Basic-clean `x` and normalize its whitespace (case preserved)."""
    cleaned = basic_clean(x)
    return whitespace_clean(cleaned)
|
|
|
|
def get_clean_fn(type: str):
    """Look up a text-cleaning function by mode name.

    Args:
        type: One of "canonicalize", "lower", or "whitespace".
            (Parameter name shadows the builtin but is kept for
            backward compatibility with keyword callers.)

    Returns:
        The corresponding cleaning callable (str -> str).

    Raises:
        ValueError: If `type` is not a recognized clean mode.
    """
    if type == "canonicalize":
        return _clean_canonicalize
    if type == "lower":
        return _clean_lower
    if type == "whitespace":
        return _clean_whitespace
    # Raise instead of `assert False`: asserts are stripped under `python -O`.
    raise ValueError(f"Invalid clean function ({type}).")
|
|
|
|
def canonicalize_text(
    text,
    *,
    keep_punctuation_exact_string=None,
    trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
    """Return canonicalized `text`: lowercased, punctuation removed, whitespace collapsed.

    Args:
        text: Input string; underscores are treated as spaces.
        keep_punctuation_exact_string: If given, this exact substring is
            preserved verbatim while punctuation is stripped around it.
        trans_punctuation: Translation table deleting punctuation; built once
            at definition time and never mutated (deliberate caching).
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # Strip punctuation from the pieces between occurrences of the kept
        # string, then stitch the pieces back together with it.
        pieces = text.split(keep_punctuation_exact_string)
        stripped = (piece.translate(trans_punctuation) for piece in pieces)
        text = keep_punctuation_exact_string.join(stripped)
    else:
        text = text.translate(trans_punctuation)
    return " ".join(text.lower().split()).strip()
|
|
|
|
class HFTokenizer:
    """HuggingFace tokenizer wrapper with support for custom tokenization modes.

    Wraps ``transformers.AutoTokenizer`` and adds: configurable text cleaning,
    fixed-length padded output as a ``torch.Tensor``, optional SEP-token
    stripping, and a custom "clips" mode that re-packs tokens as
    BOS + tokens + EOS, padded, with a trailing CLS token.
    """

    def __init__(
        self,
        tokenizer_name: str,
        context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
        clean: str = "whitespace",
        strip_sep_token: bool = False,
        language: Optional[str] = None,
        cache_dir: Optional[str] = None,
        tokenizer_mode: Optional[str] = None,
        **kwargs,
    ):
        """Load the underlying HF tokenizer and configure wrapper behavior.

        Args:
            tokenizer_name: Model name or path for ``AutoTokenizer.from_pretrained``.
            context_length: Default output sequence length; overridable per call.
            clean: Cleaning mode — "canonicalize", "lower", or "whitespace".
            strip_sep_token: If True, SEP token ids in the output are zeroed.
            language: Optional source language for tokenizers that support it.
            cache_dir: Optional HuggingFace cache directory.
            tokenizer_mode: "clips" selects the custom BOS/EOS/CLS packing;
                any other value (or None) uses standard HF tokenization.
            **kwargs: Forwarded to ``AutoTokenizer.from_pretrained``.
        """
        self.tokenizer_mode = tokenizer_mode or ""
        self.context_length = context_length
        self.clean_fn = get_clean_fn(clean)
        self.strip_sep_token = strip_sep_token

        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name, cache_dir=cache_dir, **kwargs
        )

        # Only some tokenizers (e.g. multilingual ones) expose this hook;
        # store it only when present so set_language() can warn otherwise.
        set_lang_fn = getattr(self.tokenizer, "set_src_lang_special_tokens", None)
        if callable(set_lang_fn):
            self.set_lang_fn = set_lang_fn
        if language is not None:
            self.set_language(language)

    def save_pretrained(self, dest):
        """Save the wrapped tokenizer's files to directory `dest`."""
        self.tokenizer.save_pretrained(dest)

    def __call__(
        self, texts: Union[str, List[str]], context_length: Optional[int] = None
    ) -> torch.Tensor:
        """Tokenize `texts` into a ``(batch, context_length)`` tensor of ids.

        Args:
            texts: A single string or a list of strings.
            context_length: Overrides the instance default when given.

        Returns:
            Token-id tensor, padded/truncated to `context_length`.

        Raises:
            ValueError: If no positive context length is available.
        """
        if isinstance(texts, str):
            texts = [texts]

        context_length = context_length or self.context_length
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if not context_length:
            raise ValueError(
                "Please set a valid context length in class init or call."
            )

        texts = [self.clean_fn(text) for text in texts]

        if self.tokenizer_mode == "clips":
            return self._clips_tokenize(texts, context_length)

        output = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=context_length,
            padding="max_length",
            truncation=True,
        )
        input_ids = output.input_ids

        if self.strip_sep_token:
            # NOTE(review): SEP ids are replaced with 0, which assumes 0 is
            # the pad id — confirm for tokenizers whose pad_token_id differs.
            input_ids = torch.where(
                input_ids == self.tokenizer.sep_token_id,
                torch.zeros_like(input_ids),
                input_ids,
            )

        return input_ids

    def set_language(self, src_lang):
        """Set the tokenizer's source language, if the tokenizer supports it."""
        if hasattr(self, "set_lang_fn"):
            self.set_lang_fn(src_lang)
        else:
            warnings.warn("Cannot set language for the tokenizer.")

    def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
        """Tokenize with custom packing: BOS + tokens + EOS, padded, ending in CLS.

        Each text is tokenized without special tokens, truncated to leave room
        for BOS/EOS/CLS (hence ``context_length - 3``), then padded so the final
        position always holds the CLS token.
        """
        encoded_outputs = self.tokenizer(
            texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
            return_tensors=None,
        )

        encoded = []
        for tokens in encoded_outputs["input_ids"]:
            # Reserve 3 slots: BOS, EOS, and the trailing CLS added later.
            tokens = tokens[: context_length - 3]
            tokens = (
                [self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
            )
            encoded.append(tokens)

        result = torch.zeros(len(encoded), context_length, dtype=torch.long)
        for i, tokens in enumerate(encoded):
            padded_tokens = self._pad_and_add_class_token(
                tokens,
                max_length=context_length,
                pad_token_id=self.tokenizer.pad_token_id,
                cls_token_id=self.tokenizer.cls_token_id,
            )
            result[i, : len(padded_tokens)] = torch.tensor(padded_tokens)

        return result

    def _pad_and_add_class_token(
        self,
        tokens: List[int],
        max_length: int,
        pad_token_id: int = 0,
        cls_token_id: int = 101,
    ) -> List[int]:
        """Pad (or truncate) `tokens` to ``max_length - 1`` and append the CLS id.

        The returned list is always exactly `max_length` long with the CLS
        token in the final position.
        """
        if len(tokens) > max_length - 1:
            tokens = tokens[: max_length - 1]
        if len(tokens) < max_length - 1:
            tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
        tokens = tokens + [cls_token_id]
        return tokens
|
|