from __future__ import annotations

from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


def make_config_class(model_args: dict, model_type: str) -> type[PretrainedConfig]:
    # Rebind under a new name: inside the class body, `model_type = model_type`
    # would raise a NameError because class scopes do not see enclosing locals.
    model_type_ = model_type

    class Config(PretrainedConfig):
        model_type = model_type_

        def __init__(self, **kwargs):
            # Use the values in `model_args` as defaults, overridable per key
            # via `kwargs`.
            for k, v in model_args.items():
                setattr(self, k, kwargs.get(k, v))
            super().__init__(**kwargs)

    return Config
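
# Example usage (a minimal sketch; `hidden_size` and the "tiny" model type
# are hypothetical values):
#
#     TinyConfig = make_config_class({"hidden_size": 16}, model_type="tiny")
#     config = TinyConfig(hidden_size=32)  # overrides the default of 16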


def make_model_class(base_class: type[nn.Module]) -> type[PreTrainedModel]:
    class Model(PreTrainedModel):
        # Assigned later, e.g. by `register_auto_classes`.
        config_class: type[PretrainedConfig]

        def __init__(self, config: PretrainedConfig, *args, **kwargs):
            super().__init__(config)
            # Wrap the plain nn.Module and delegate the forward pass to it.
            self._model = base_class(config, *args, **kwargs)

        def forward(self, *args, **kwargs):
            return self._model(*args, **kwargs)

    return Model
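
# Example usage (a sketch; `TinyNet` is a hypothetical nn.Module whose
# __init__ takes the config as its first argument):
#
#     TinyModel = make_model_class(TinyNet)
#     TinyModel.config_class = TinyConfig  # from make_config_class above
#     model = TinyModel(TinyConfig())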


def make_tokenizer_class(
    vocab: list[str],
    special_tokens: dict[str, str],
) -> type[PreTrainedTokenizer]:
    for key in special_tokens:
        if key not in ["unk", "pad", "bos", "eos", "sep", "cls", "mask"]:
            raise ValueError(f"unrecognized special token key: `{key}`")

    # Out-of-vocabulary tokens map to the `unk` token, falling back to the
    # first vocabulary entry if no `unk` token was provided.
    unk_token = special_tokens.get("unk", vocab[0])
    token_to_idx = {token: idx for idx, token in enumerate(vocab)}
    idx_to_token = {idx: token for token, idx in token_to_idx.items()}

    class Tokenizer(PreTrainedTokenizer):
        model_input_names = ["input_ids"]

        def __init__(
            self,
            model_max_length: int | None = None,
            split_special_tokens: bool = True,
            **kwargs,
        ):
            # The vocab must be in place before calling `super().__init__`,
            # which may invoke `get_vocab` and `_tokenize` while setting up.
            self._vocab = token_to_idx
            self._inv_vocab = idx_to_token
            tokens = dict(
                unk_token=special_tokens.get("unk"),
                pad_token=special_tokens.get("pad"),
                bos_token=special_tokens.get("bos"),
                eos_token=special_tokens.get("eos"),
                sep_token=special_tokens.get("sep"),
                cls_token=special_tokens.get("cls"),
                mask_token=special_tokens.get("mask"),
            )
            tokens = {k: v for k, v in tokens.items() if v is not None}
            super().__init__(
                model_max_length=model_max_length,
                split_special_tokens=split_special_tokens,
                **tokens,
                **kwargs,
            )

        def _tokenize(self, seq: str) -> list[str]:
            # Character-level tokenization: every character is a token.
            return list(seq)

        def _convert_token_to_id(self, token: str) -> int:
            return self._vocab.get(token, self._vocab[unk_token])

        def _convert_id_to_token(self, idx: int) -> str:
            return self._inv_vocab[idx]

        @property
        def vocab_size(self) -> int:
            return len(self._vocab)

        def get_vocab(self) -> dict[str, int]:
            return self._vocab

        def save_vocabulary(
            self, save_directory: str, filename_prefix: str | None = None
        ) -> tuple:
            # The vocabulary is baked into the class, so nothing is written.
            return ()

    return Tokenizer
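
# Example usage (a sketch of a character-level tokenizer over a toy alphabet;
# the vocab and special tokens here are illustrative):
#
#     CharTokenizer = make_tokenizer_class(
#         vocab=["<unk>", "<pad>", "a", "b", "c"],
#         special_tokens={"unk": "<unk>", "pad": "<pad>"},
#     )
#     tokenizer = CharTokenizer(model_max_length=128)
#     tokenizer("abc")["input_ids"]  # -> [2, 3, 4]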


def register_auto_classes(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel] | None = None,
    tokenizer_class: type[PreTrainedTokenizer] | None = None,
    force_registration: bool = False,
) -> None:
    # `PretrainedConfig.model_type` defaults to "", so reject empty values too.
    model_type = getattr(config_class, "model_type", None)
    if not model_type:
        raise ValueError("`config_class` must define a non-empty `model_type` attribute")

    already_registered = check_auto_class_registered(
        *(c for c in [config_class, model_class, tokenizer_class] if c is not None)
    )
    if already_registered and not force_registration:
        raise RuntimeError(
            "One or more classes are already registered. "
            "Set `force_registration=True` to override."
        )

    AutoConfig.register(model_type, config_class)
    config_class.register_for_auto_class()

    if model_class is not None:
        if getattr(model_class, "config_class", None) is None:
            model_class.config_class = config_class
        AutoModel.register(config_class, model_class)
        model_class.register_for_auto_class("AutoModel")

    if tokenizer_class is not None:
        AutoTokenizer.register(config_class, tokenizer_class)
        tokenizer_class.register_for_auto_class("AutoTokenizer")
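
# Example usage (a sketch, assuming `TinyConfig` and `TinyModel` were built
# with the factories above):
#
#     register_auto_classes(TinyConfig, model_class=TinyModel)
#     config = AutoConfig.for_model("tiny", hidden_size=16)
#     model = AutoModel.from_config(config)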


def check_auto_class_registered(*classes) -> bool:
    """Report whether any of `classes` is already registered with an auto class.

    Placeholder that conservatively returns False. Note that
    `AutoConfig.register` itself raises if a `model_type` clashes with a
    built-in Transformers config.
    """
    return False


def push_model_to_hub(
    config_class: type[PretrainedConfig],
    model_class: type[PreTrainedModel],
    model_args: dict,
    state_dict: dict,
    id_: str,
    commit_message: str = "Upload model",
) -> str:
    config = config_class(**model_args)
    huggingface_model = model_class(config)
    # Load the trained weights into the wrapped nn.Module, not the wrapper.
    huggingface_model._model.load_state_dict(state_dict)
    config.save_pretrained(id_)
    huggingface_model.save_pretrained(id_)
    return huggingface_model.push_to_hub(id_, commit_message=commit_message)
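
# Example usage (a sketch; the repo id and `trained_net` are hypothetical):
#
#     url = push_model_to_hub(
#         TinyConfig,
#         TinyModel,
#         model_args={"hidden_size": 16},
#         state_dict=trained_net.state_dict(),
#         id_="your-username/tiny-model",
#     )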


def push_tokenizer_to_hub(
    tokenizer_class: type[PreTrainedTokenizer],
    id_: str,
    commit_message: str = "Upload tokenizer",
    **kwargs,
) -> str:
    tokenizer = tokenizer_class(**kwargs)
    tokenizer.save_pretrained(id_)
    return tokenizer.push_to_hub(id_, commit_message=commit_message)
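
# Example usage (a sketch; the repo id is hypothetical):
#
#     push_tokenizer_to_hub(
#         CharTokenizer,
#         id_="your-username/tiny-tokenizer",
#         model_max_length=128,
#     )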