# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
import html
import os
import string
from typing import List, Optional, Union
import warnings
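# ftfy (text fixing) is optional; basic_clean degrades gracefully without it.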
try:
import ftfy
except ImportError:
ftfy = None
import torch
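# Disable HF tokenizers' parallelism to avoid fork-related warnings.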
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_CONTEXT_LENGTH = 77
def basic_clean(text):
    # Repair mojibake with ftfy when available, then unescape HTML entities
    # (twice, to handle double-escaped input).
    if ftfy is not None:
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()
def whitespace_clean(text):
text = " ".join(text.split())
text = text.strip()
return text
def _clean_canonicalize(x):
return canonicalize_text(basic_clean(x))
def _clean_lower(x):
return whitespace_clean(basic_clean(x)).lower()
def _clean_whitespace(x):
return whitespace_clean(basic_clean(x))
def get_clean_fn(type: str):
    if type == "canonicalize":
        return _clean_canonicalize
    elif type == "lower":
        return _clean_lower
    elif type == "whitespace":
        return _clean_whitespace
    else:
        raise ValueError(f"Invalid clean function ({type}).")
def canonicalize_text(
text,
*,
keep_punctuation_exact_string=None,
trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
"""Returns canonicalized `text` (lowercase and punctuation removed)."""
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(trans_punctuation)
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(trans_punctuation)
text = text.lower()
text = " ".join(text.split())
return text.strip()
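# e.g. canonicalize_text("Hello_World!  foo") -> "hello world foo";
# with keep_punctuation_exact_string="{}" a literal "{}" is preserved
# while all other punctuation is stripped.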
class HFTokenizer:
"""HuggingFace tokenizer wrapper with support for custom tokenization modes"""
def __init__(
self,
tokenizer_name: str,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = "whitespace",
strip_sep_token: bool = False,
language: Optional[str] = None,
cache_dir: Optional[str] = None,
tokenizer_mode: Optional[str] = None,
**kwargs,
):
self.tokenizer_mode = tokenizer_mode or ""
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.strip_sep_token = strip_sep_token
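        # Imported lazily so `transformers` is only a dependency when this
        # wrapper is actually instantiated.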
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, cache_dir=cache_dir, **kwargs
)
set_lang_fn = getattr(self.tokenizer, "set_src_lang_special_tokens", None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
if language is not None:
self.set_language(language)
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(
self, texts: Union[str, List[str]], context_length: Optional[int] = None
) -> torch.Tensor:
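        """Clean, tokenize, and pad `texts` into a (batch, context_length) id tensor."""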
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, (
"Please set a valid context length in class init or call."
)
texts = [self.clean_fn(text) for text in texts]
if self.tokenizer_mode == "clips":
return self._clips_tokenize(texts, context_length)
else:
output = self.tokenizer(
texts,
return_tensors="pt",
max_length=context_length,
padding="max_length",
truncation=True,
)
input_ids = output.input_ids
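            # Optionally blank out SEP tokens (replace their ids with 0).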
if self.strip_sep_token:
input_ids = torch.where(
input_ids == self.tokenizer.sep_token_id,
torch.zeros_like(input_ids),
input_ids,
)
return input_ids
def set_language(self, src_lang):
if hasattr(self, "set_lang_fn"):
self.set_lang_fn(src_lang)
else:
warnings.warn("Cannot set language for the tokenizer.")
def _clips_tokenize(self, texts: List[str], context_length: int) -> torch.Tensor:
encoded_outputs = self.tokenizer(
texts,
add_special_tokens=False,
padding=False,
truncation=False,
return_tensors=None,
)
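        # Per-sequence layout: [BOS] tokens [EOS] [PAD ...] [CLS], padded to
        # exactly `context_length` ids.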
encoded = []
for tokens in encoded_outputs["input_ids"]:
tokens = tokens[: context_length - 3]
tokens = (
[self.tokenizer.bos_token_id] + tokens + [self.tokenizer.eos_token_id]
)
encoded.append(tokens)
result = torch.zeros(len(encoded), context_length, dtype=torch.long)
for i, tokens in enumerate(encoded):
padded_tokens = self._pad_and_add_class_token(
tokens,
max_length=context_length,
pad_token_id=self.tokenizer.pad_token_id,
cls_token_id=self.tokenizer.cls_token_id,
)
result[i, : len(padded_tokens)] = torch.tensor(padded_tokens)
return result
def _pad_and_add_class_token(
self,
tokens: List[int],
max_length: int,
pad_token_id: int = 0,
cls_token_id: int = 101,
) -> List[int]:
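        # The 0 / 101 defaults match BERT-style [PAD] / [CLS] ids; callers
        # pass the wrapped tokenizer's actual special-token ids.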
if len(tokens) > max_length - 1:
tokens = tokens[: max_length - 1]
if len(tokens) < max_length - 1:
tokens = tokens + [pad_token_id] * (max_length - 1 - len(tokens))
tokens = tokens + [cls_token_id]
return tokens
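if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # name "bert-base-uncased" is an illustrative assumption; any HF tokenizer
    # with the special tokens required by the chosen mode will do.
    tok = HFTokenizer("bert-base-uncased", context_length=DEFAULT_CONTEXT_LENGTH)
    ids = tok(["A photo of a cat.", "A photo of a dog."])
    print(ids.shape)  # expected: torch.Size([2, 77])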