import base64
import os
from functools import lru_cache
from typing import Optional

import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer

import tiktoken

# Language code -> name; extends Whisper's language list with dialect and
# code-switch tags ("minnan", "wuyu", "dialect", "zh/en", "en/zh").
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# Reverse map (language name -> code), plus common alternate names.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

# Audio-event tags; "/"-prefixed entries mark the end of a paired event.
AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

# TTS vocal tokens; the comprehension adds TTS/SP01 through TTS/SP13.
TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}

@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    # Load the base64-encoded BPE ranks shipped as "assets/<name>.tiktoken"
    # next to this module.
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as f:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    # Special tokens are appended after the base vocabulary in a fixed order,
    # so their ids are stable across runs.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
        # timestamp tokens <|0.00|> ... <|30.00|> in 0.02 s steps
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )

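# A minimal usage sketch (illustrative, not part of the module); assumes the
# matching "assets/<name>.tiktoken" file is present:
#
#     enc = get_encoding("multilingual_zh_ja_yue_char_del")
#     ids = enc.encode("<|startoftranscript|><|zh|><|transcribe|>", allowed_special="all")
#     text = enc.decode(ids)
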
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        # Accept either a language code ("zh") or a full name ("mandarin").
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        # The English-only tokenizer carries no language/task prefix tokens.
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )

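# A minimal usage sketch (illustrative): get_tokenizer returns whisper's
# Tokenizer, so the usual encode/decode API applies.
#
#     tokenizer = get_tokenizer(multilingual=True, language="mandarin", task="transcribe")
#     ids = tokenizer.encode("hello world")
#     text = tokenizer.decode(ids)
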
class QwenTokenizer:
    """Wrapper around a HuggingFace Qwen tokenizer that registers extra
    paralinguistic tags (e.g. [breath], [laughter]) as special tokens."""

    def __init__(self, token_path, skip_special_tokens=True):
        special_tokens = {
            "eos_token": "<|endoftext|>",
            "pad_token": "<|endoftext|>",
            "additional_special_tokens": [
                "<|im_start|>", "<|im_end|>", "<|endofprompt|>",
                "[breath]", "<strong>", "</strong>", "[noise]",
                "[laughter]", "[cough]", "[clucking]", "[accent]",
                "[quick_breath]",
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]",
            ],
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        # Registering the tags as special tokens makes each one tokenize as a
        # single id instead of being split by BPE.
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode(
            [tokens], skip_special_tokens=self.skip_special_tokens
        )[0]
        return text

@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool,
) -> QwenTokenizer:
    # Cached so repeated calls with the same arguments reuse one instance.
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
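
# A minimal round-trip sketch (illustrative; "Qwen/Qwen2-0.5B" is a stand-in
# for whichever tokenizer path the project actually loads):
#
#     qwen = get_qwen_tokenizer("Qwen/Qwen2-0.5B", skip_special_tokens=True)
#     ids = qwen.encode("hello [breath] world")  # "[breath]" becomes a single added token id
#     text = qwen.decode(ids)                    # added tags are dropped when skip_special_tokens=True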