| import logging |
| import os |
| import json |
| from typing import Optional, Dict, List, Set, Tuple, Union, Literal, Type |
| from pydantic.dataclasses import dataclass |
|
|
| import numpy as np |
| from numpy.typing import NDArray |
|
|
| from transformers import PreTrainedTokenizerFast |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# File name of the tag-category configuration, keyed by the name of the
# tokenizer constructor argument that receives it.
VOCAB_FILES_NAMES = {"tag_category": "tag_category.json"}

# Download location of the tag-category configuration for each published
# checkpoint, keyed first by the vocab-file key and then by the repo id.
PRETRAINED_VOCAB_FILES_MAP = {
    "tag_category": {
        "p1atdev/tokenizer_test_1": "https://huggingface.co/p1atdev/tokenizer_test_1/resolve/main/tag_category.json",
    },
}
|
|
|
|
@dataclass
class Category:
    """Settings of one tag category, as stored in the tag-category JSON."""

    # Human-readable name of the category.
    name: str
    # Optional count limit; not read by any code visible in this file.
    max_count: Optional[int]
    # Ids of the categories whose BOS token may follow this category's EOS.
    next_category: List[int]
    # Whether the sequence-level EOS token is allowed right after this
    # category's EOS token.
    can_end: bool
    # Token id that opens this category.
    bos_token_id: int
    # Token id that closes this category.
    eos_token_id: int
    # Fill value used when the category's initial vocab mask is created.
    default_mask: int
|
|
|
|
@dataclass
class SpecialMapping:
    """Mask overrides that a token applies to one category's vocab mask."""

    # Token ids whose mask entry is set to 1.
    allow: List[int]
    # Token ids whose mask entry is set to 0 (applied after ``allow``).
    disallow: List[int]
|
|
|
|
@dataclass
class TagCategoryConfig:
    """Top-level schema of the tag-category JSON file."""

    # Id of the category whose BOS token follows the sequence BOS token.
    start_category: int
    # Category settings, keyed by the category id as a string.
    categories: Dict[str, Category]
    # token id (as a string) -> category id (as a string) -> mask overrides
    # applied to that category once the token has been seen.
    special_mapping: Dict[str, Dict[str, SpecialMapping]]
    # Category id (as a string) -> ids of the tokens belonging to it.
    category_tags_pairs: Dict[str, List[int]]
|
|
|
|
class OverrideMask:
    """Plain holder for a pair of arrays named ``allow`` and ``disallow``."""

    allow: np.ndarray
    disallow: np.ndarray

    def __init__(self, allow: np.ndarray, disallow: np.ndarray) -> None:
        # Store both arrays as given; no copy or validation is performed.
        self.allow, self.disallow = allow, disallow
|
|
|
|
def load_tag_category(config_json: str) -> "TagCategoryConfig":
    """Read a tag-category JSON file and build a ``TagCategoryConfig`` from it.

    Args:
        config_json: Path of the JSON file to read.

    Returns:
        The configuration built from the top-level keys of the JSON object.
    """
    # The file is read as bytes; ``json.loads`` accepts bytes directly.
    with open(config_json, "rb") as fp:
        payload = json.loads(fp.read())

    return TagCategoryConfig(**payload)
|
|
|
|
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer

    A fast tokenizer that also loads a tag-category configuration and uses it
    to build 0/1 vocabulary masks describing which token ids may come next.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    def __init__(self, tag_category, **kwargs):
        """Initialise the tokenizer and load the tag-category configuration.

        Args:
            tag_category: Path of the tag-category JSON file.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizerFast``.
        """
        super().__init__(**kwargs)

        self.tag_category_config = load_tag_category(tag_category)
        categories = self.tag_category_config.categories

        # Reverse lookups: category BOS / EOS token id -> category id.
        self.category_bos_map = {
            category.bos_token_id: category_id
            for category_id, category in categories.items()
        }
        self.category_eos_map = {
            category.eos_token_id: category_id
            for category_id, category in categories.items()
        }

        # Token id -> category id lookup table; ids not listed in any
        # category keep the initial value 0.
        # NOTE(review): sized by ``vocab_size``; confirm every special and
        # category token id is below it (added tokens may not be counted).
        self._id_to_category_map = np.zeros(self.vocab_size, dtype="uint8")
        for category_id, tokens in self.tag_category_config.category_tags_pairs.items():
            self._id_to_category_map[tokens] = int(category_id)

        # Default per-category masks; treated as read-only after this point.
        self.category_mask = self.create_category_vocab_mask()

    def create_vocab_mask(self, value: int = 1):
        """Create an array of vocab size filled with specified value"""
        return np.full(self.vocab_size, value).astype("uint8")

    def create_category_vocab_mask(self):
        """Create vocab masks for each category"""
        return {
            category_id: self.create_vocab_mask(value=category.default_mask)
            for category_id, category in self.tag_category_config.categories.items()
        }

    def get_token_ids_in_category(self, category_id: Union[int, str]):
        """Get token ids in the specified category"""
        return self.tag_category_config.category_tags_pairs[str(category_id)]

    def get_category(self, category_id: Union[int, str]):
        """Get the specified category config"""
        return self.tag_category_config.categories[str(category_id)]

    def get_special_mapping(self, token_id: Union[int, str]):
        """Get the special mapping of specified token id"""
        return self.tag_category_config.special_mapping[str(token_id)]

    def get_banned_tokens_mask(self, tokens: Union[str, List[str], int, List[int]]):
        """Create a vocab mask that is 0 at the given tokens and 1 elsewhere.

        Args:
            tokens: A single token (string or id) or a list of them. Strings
                are converted to ids with ``convert_tokens_to_ids``.

        Returns:
            A uint8 array of vocab size.

        Raises:
            TypeError: If ``tokens`` is not a str, int or list, or if any
                element does not resolve to an int token id.
        """
        # Wrap a single token first so that a lone string is converted to an
        # id by the same code path as the elements of a list.
        if isinstance(tokens, (str, int)):
            tokens = [tokens]
        if not isinstance(tokens, list):
            raise TypeError(
                f"tokens must be a str, an int or a list of them, got {type(tokens).__name__}"
            )

        token_ids = [
            self.convert_tokens_to_ids(token) if isinstance(token, str) else token
            for token in tokens
        ]
        if not all(isinstance(token_id, int) for token_id in token_ids):
            raise TypeError("every token must resolve to an int token id")

        mask = self.create_vocab_mask(value=1)
        mask[token_ids] = 0

        return mask

    def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
        """Look up the category id of each given token id.

        Returns the result of indexing the internal uint8 lookup table, i.e. a
        NumPy scalar for a single id or an array for a list of ids.
        """
        return self._id_to_category_map[token_ids]

    def get_next_tokens_mask(
        self,
        input_ids: List[int],
        category_mask: Optional[Dict[str, np.ndarray]] = None,
    ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Get the next token's vocab mask and a category mask

        Args:
            input_ids: Token ids of the sequence so far.
            category_mask: Per-category vocab masks carried over from the
                previous step, keyed by category id string. When omitted, a
                fresh copy of the tokenizer's default masks is used. A mask
                passed in explicitly is updated in place.

        Returns:
            ``(vocab_mask, category_mask)``. ``vocab_mask`` is a uint8 array
            of vocab size holding 1 for every token id allowed next;
            ``category_mask`` is the (possibly updated) per-category mask to
            pass to the next call.
        """
        if category_mask is None:
            # Copy the defaults: the special-mapping step below writes into
            # these arrays, and those writes must not leak into
            # ``self.category_mask`` and affect later, unrelated sequences.
            category_mask = {
                category_id: mask.copy()
                for category_id, mask in self.category_mask.items()
            }

        vocab_mask = self.create_vocab_mask(value=0)

        if len(input_ids) == 0:
            # Nothing generated yet: only the sequence BOS token is allowed.
            vocab_mask[self.bos_token_id] = 1
            return vocab_mask, category_mask

        # Only the most recent token decides what may come next.
        last_token_id = input_ids[-1]

        if last_token_id == self.unk_token_id:
            # The category of an unknown token cannot be determined, so every
            # token is allowed.
            logger.warning(
                "The unk_token was provided! The vocab mask could not be created properly."
            )
            return self.create_vocab_mask(value=1), category_mask

        # Apply the allow / disallow overrides attached to the last token.
        if str(last_token_id) in self.tag_category_config.special_mapping:
            for category_id, mapping in self.get_special_mapping(last_token_id).items():
                category_mask[category_id][mapping.allow] = 1
                category_mask[category_id][mapping.disallow] = 0

        if last_token_id == self.bos_token_id:
            # After the sequence BOS only the start category's BOS may follow.
            start_category_id = self.tag_category_config.start_category
            start_category = self.get_category(start_category_id)
            vocab_mask[start_category.bos_token_id] = 1
            return vocab_mask, category_mask

        if last_token_id == self.eos_token_id:
            # After the sequence EOS only padding may follow.
            vocab_mask[self.pad_token_id] = 1
            return vocab_mask, category_mask

        if last_token_id in self.category_bos_map:
            # A category was just opened: allow its tokens, filtered by the
            # category mask, and always allow closing the category.
            current_category_id = self.category_bos_map[last_token_id]
            category = self.get_category(current_category_id)

            tokens_in_category = self.get_token_ids_in_category(current_category_id)
            vocab_mask[tokens_in_category] = 1
            vocab_mask *= category_mask[str(current_category_id)]
            vocab_mask[category.eos_token_id] = 1
            return vocab_mask, category_mask

        if last_token_id in self.category_eos_map:
            # A category was just closed: allow the sequence EOS when the
            # category permits it, plus the BOS of each follow-up category.
            current_category_id = self.category_eos_map[last_token_id]
            category = self.get_category(current_category_id)

            if category.can_end:
                vocab_mask[self.eos_token_id] = 1

            for next_category_id in category.next_category:
                vocab_mask[self.get_category(next_category_id).bos_token_id] = 1
            return vocab_mask, category_mask

        # Otherwise the last token is an ordinary token inside a category.
        current_category_id = self.convert_ids_to_category_ids(last_token_id).item()
        tokens_in_category = self.get_token_ids_in_category(current_category_id)

        vocab_mask[tokens_in_category] = 1
        vocab_mask[self.get_category(current_category_id).eos_token_id] = 1
        # NOTE(review): here the category EOS is set before the mask is
        # applied, so the mask can clear it; the category-BOS branch sets it
        # afterwards. Confirm the difference is intended.
        vocab_mask *= category_mask[str(current_category_id)]
        # Tokens already present in the sequence are not offered again.
        vocab_mask[input_ids] = 0

        return vocab_mask, category_mask
|
|