#!/usr/bin/env python3 # tokenization_binaryllm.py # ============================================================ # BinaryLLMTokenizer (AutoTokenizer compatible) — EXACTEMENT la même # tokenisation/decodage que llmTalk (mode base=65536) + infer_tagged12/11: # # - Base: 65536 # - IDs radix: 0..65535 # - BOS: 65536 # - EOS: 65537 # - UNK: alias EOS (65537) (pas de nouveau token dans la base) # - Encodage: UTF-8 bytes -> digits base65536 BIG-ENDIAN (chunks 2 bytes) # * si longueur impaire: dernier byte encodé en valeur 0..255 (1 digit) # - Décodage: digits -> bytes BIG-ENDIAN -> UTF-8 (errors="replace") # # Important: # - build_inputs_with_special_tokens: [BOS] + seq + [EOS] (comme HF classique) # - encode(..., add_special_tokens=False) renvoie UNIQUEMENT les digits base65536 # - encode(..., add_special_tokens=True) ajoute BOS/EOS via build_inputs... # # Ce fichier suffit pour `trust_remote_code=True` côté repo HF. # ============================================================ from __future__ import annotations import json import os import re from typing import Dict, List, Optional, Tuple, Any from transformers import PreTrainedTokenizer class BinaryLLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] TOKEN_RE = re.compile(r"^$") def __init__( self, bos_token: str = "", eos_token: str = "", unk_token: str = "", pad_token: Optional[str] = None, **kwargs: Any, ): # radix strict self._base_vocab_size = 65536 # specials strict: base + 0/1 self._bos_id = 65536 self._eos_id = 65537 # UNK alias EOS (pas de token additionnel) self._unk_id = self._eos_id self._bos_str = bos_token self._eos_str = eos_token self._unk_str = unk_token self._pad_str = pad_token super().__init__( bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, **kwargs, ) # ---------- vocab / ids ---------- @property def vocab_size(self) -> int: # 65536 + BOS + EOS return 65538 def get_vocab(self) -> Dict[str, int]: # IMPORTANT: ne jamais appeler self.unk_token_id ici (boucle) v = { self._bos_str: self._bos_id, self._eos_str: self._eos_id, self._unk_str: self._unk_id, } if self.pad_token is not None: v[self.pad_token] = self._convert_token_to_id(self.pad_token) return v def _id_to_token_base(self, i: int) -> str: return f"" # ---------- core encode/decode (même logique que infer_tagged / llmTalk base) ---------- def _encode_to_base65536_big_endian(self, text: str) -> List[int]: b = bytearray(text.encode("utf-8", errors="strict")) if len(b) == 0: return [0] out: List[int] = [] i = 0 n = len(b) while i + 1 < n: # 2 bytes -> 1 digit base65536 big-endian out.append((b[i] << 8) | b[i + 1]) i += 2 if i < n: # dernier byte seul -> digit 0..255 out.append(int(b[i])) return out def _decode_from_base65536_big_endian(self, ids: List[int]) -> str: bb = bytearray() for x in ids: xi = int(x) & 0xFFFFFFFF if 0 <= xi <= 255: bb.append(xi) else: bb.append((xi >> 8) & 0xFF) bb.append(xi & 0xFF) return bytes(bb).decode("utf-8", errors="replace") # ---------- HF tokenizer API overrides ---------- def _tokenize(self, text: str) -> List[str]: ids = self._encode_to_base65536_big_endian(text) return [self._id_to_token_base(i) for i in ids] def _convert_token_to_id(self, token: str) -> int: if token == self._bos_str: return self._bos_id if token == self._eos_str: return self._eos_id if token == self._unk_str: return self._unk_id if self.pad_token is not None and token == self.pad_token: # pas de PAD dédié => alias EOS (compatible avec ton cadre) if self.pad_token == self._eos_str: return self._eos_id return self._eos_id m = self.TOKEN_RE.match(token) if m: return int(m.group(1), 16) return self._unk_id def _convert_id_to_token(self, index: int) -> str: if index == self._bos_id: return self._bos_str if index == self._eos_id: return self._eos_str if index == self._unk_id: return self._unk_str if self.pad_token is not None and index == self.pad_token_id: return self.pad_token if 0 <= index < self._base_vocab_size: return self._id_to_token_base(index) return self._unk_str def convert_tokens_to_string(self, tokens: List[str]) -> str: ids: List[int] = [] for t in tokens: if t in (self._bos_str, self._eos_str, self._unk_str): continue if self.pad_token is not None and t == self.pad_token: continue m = self.TOKEN_RE.match(t) if m: ids.append(int(m.group(1), 16)) return self._decode_from_base65536_big_endian(ids) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: # HF-style (simple): [BOS] seq [EOS] # Pair: [BOS] seq0 [EOS] seq1 [EOS] if token_ids_1 is None: return [self._bos_id] + token_ids_0 + [self._eos_id] return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id] def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: pad_id = self.pad_token_id if self.pad_token is not None else -1 if already_has_special_tokens: return [ 1 if t in (self._bos_id, self._eos_id, self._unk_id, pad_id) else 0 for t in token_ids_0 ] if token_ids_1 is None: return [1] + [0] * len(token_ids_0) + [1] return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1] def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: if token_ids_1 is None: return [0] * (len(token_ids_0) + 2) return [0] * (len(token_ids_0) + len(token_ids_1) + 3) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json" path = os.path.join(save_directory, name) data = { "base_vocab_size": 65536, "vocab_size": 65538, "bos_token": self._bos_str, "bos_token_id": self._bos_id, "eos_token": self._eos_str, "eos_token_id": self._eos_id, "unk_token": self._unk_str, "unk_token_id": self._unk_id, "pad_token": self.pad_token, "pad_token_id": self.pad_token_id, "encoding": "utf-8", "radix": 65536, "endianness": "big", "odd_length_rule": "last_byte_as_single_digit_0_255", } with open(path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) return (path,)