Spaces:
Running
Running
| """ | |
| TokenizerBench — Hugging Face Space | |
| Evaluate and compare tokenizers on the TokenizerBench dataset. | |
| Supports: HuggingFace AutoTokenizer (Hub ID or uploaded files), | |
| tiktoken encodings, SentencePiece .model files. | |
| """ | |
| import io | |
| import shutil | |
| import tempfile | |
| import traceback | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import numpy as np | |
| import pandas as pd | |
| matplotlib.use("Agg") | |
| # ───────────────────────────────────────────────────────────────── | |
| # Dataset (inline subset of TokenizerBench) | |
| # ───────────────────────────────────────────────────────────────── | |
# Inline benchmark corpus, organised as category -> subcategory -> list of
# sample strings. The top-level keys must match the keys of CATEGORY_LABELS.
DATASET: dict[str, dict[str, list[str]]] = {
    "human_languages": {
        "english": [
            "The quick brown fox jumps over the lazy dog.",
            "Artificial intelligence is transforming every industry.",
            "Natural language processing enables machines to understand text.",
            "Tokenization is the first step in most NLP pipelines.",
            "The model achieved state-of-the-art results on all benchmarks.",
        ],
        "hindi": [
            "कृत्रिम बुद्धिमत्ता दुनिया को तेजी से बदल रही है।",
            "मुझे नई तकनीकें सीखना पसंद है।",
            "यह एक परीक्षण वाक्य है।",
            "संख्याएँ 12345 और चिह्नों को सही ढंग से संसाधित किया जाना चाहिए।",
            "प्राकृतिक भाषा प्रसंस्करण कृत्रिम बुद्धिमत्ता का एक महत्वपूर्ण क्षेत्र है।",
        ],
        "chinese": [
            "人工智能正在迅速改变世界。",
            "我喜欢学习新技术。",
            "这是一个测试句子。",
            "数字12345和符号需要正确处理。",
            "自然语言处理是人工智能的重要领域。",
        ],
        "arabic": [
            "الذكاء الاصطناعي يغير العالم بسرعة.",
            "أحب تعلم التقنيات الجديدة.",
            "هذه جملة اختبارية.",
            "معالجة اللغة الطبيعية مجال مهم في الذكاء الاصطناعي.",
            "يجب معالجة الأرقام 12345 والرموز بشكل صحيح.",
        ],
        "japanese": [
            "人工知能は世界を急速に変えています。",
            "私は新しい技術を学ぶのが好きです。",
            "これはテスト用の文です。",
            "数字12345と記号を正しく処理する必要があります。",
            "自然言語処理は人工知能の重要な分野です。",
        ],
        "german": [
            "Künstliche Intelligenz verändert die Welt schnell.",
            "Ich lerne gerne neue Technologien.",
            "Donaudampfschifffahrtsgesellschaft ist ein langes deutsches Wort.",
            "Dies ist ein Testsatz.",
            "Natürliche Sprachverarbeitung ist ein wichtiges Forschungsgebiet.",
        ],
        "russian": [
            "Искусственный интеллект быстро меняет мир.",
            "Мне нравится изучать новые технологии.",
            "Это тестовое предложение.",
            "Обработка естественного языка — важная область ИИ.",
            "Числа 12345 и символы должны обрабатываться корректно.",
        ],
        "korean": [
            "인공지능은 세상을 빠르게 변화시키고 있습니다.",
            "나는 새로운 기술을 배우는 것을 좋아합니다.",
            "이것은 테스트 문장입니다.",
            "자연어 처리는 인공지능의 중요한 분야입니다.",
            "숫자 12345와 기호를 올바르게 처리해야 합니다.",
        ],
    },
    "programming_languages": {
        "python": [
            "def greet(name): return f'Hello, {name}!'",
            "numbers = [1,2,3]; squared = [x**2 for x in numbers]",
            "import torch\nmodel = torch.nn.Linear(128, 64)",
            "async def fetch(url):\n async with aiohttp.ClientSession() as s:\n return await s.get(url)",
            "class Tokenizer:\n def __init__(self, vocab):\n self.vocab = vocab",
        ],
        "javascript": [
            "const greet = name => `Hello, ${name}!`;",
            "const nums = [1,2,3]; const sq = nums.map(x => x**2);",
            "async function fetchData(url) { const res = await fetch(url); return res.json(); }",
            "const obj = { key: 'value', nested: { a: 1 } };",
        ],
        "sql": [
            "SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name;",
            "CREATE INDEX idx_users_email ON users(email);",
            "WITH ranked AS (SELECT *, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) rn FROM emp) SELECT * FROM ranked WHERE rn=1;",
        ],
        "rust": [
            "fn main() { println!(\"Hello, world!\"); }",
            "let v: Vec<i32> = (1..=10).collect();",
            "impl fmt::Display for Point { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, \"({}, {})\", self.x, self.y) } }",
        ],
    },
    "scientific_formulas": {
        "algebra": [
            "x² + y² = z²",
            "x = (-b ± √(b² - 4ac)) / 2a",
            "e^(iπ) + 1 = 0",
            "∑ᵢ₌₁ⁿ i = n(n+1)/2",
        ],
        "calculus": [
            "∫₀¹ x² dx = 1/3",
            "d/dx (x²) = 2x",
            "lim(x→0) sin(x)/x = 1",
            "∂²u/∂x² + ∂²u/∂y² = 0",
        ],
        "physics": [
            "E = mc²",
            "∇·E = ρ/ε₀",
            "ψ(x,t) = Ae^{i(kx - ωt)}",
            "|ψ⟩ = α|0⟩ + β|1⟩",
        ],
        "statistics": [
            "P(A|B) = P(A∩B)/P(B)",
            "H(X) = -∑ p(x) log p(x)",
            "KL(P||Q) = ∑ P(x) log(P(x)/Q(x))",
            "E[X] = ∑ xP(x), Var(X) = E[X²] - (E[X])²",
        ],
    },
    "edge_cases": {
        "whitespace_control": [
            "word1\t\tword2\t\tword3",
            "line1\nline2\nline3",
            " leading spaces",
            "trailing spaces ",
        ],
        "long_tokens": [
            # The query-string separator must be a literal "&"; an HTML-entity
            # decode had turned "&param2" into "¶m2".
            "https://www.example.com/very/long/path/to/some/resource?param1=value1&param2=value2",
            "thisIsAReallyLongCamelCaseIdentifierThatMightAppearInCode",
            "SGVsbG8gV29ybGQhIFRoaXMgaXMgYSBiYXNlNjQgZW5jb2RlZCBzdHJpbmc=",
            "550e8400-e29b-41d4-a716-446655440000",
        ],
        "mixed_scripts": [
            "Hello 世界 مرحبا Привет こんにちは",
            "AI模型 and NLP技术 are transforming الذكاء الاصطناعي",
            "math: α + β = γ, code: x += 1",
        ],
    },
}
# Human-readable labels (with emoji) for each top-level DATASET category key.
# Used for chart legends and tick labels, the results table, and the
# Playground sample accordions.
CATEGORY_LABELS: dict[str, str] = {
    "human_languages": "🌍 Human languages",
    "programming_languages": "💻 Programming languages",
    "scientific_formulas": "🧮 Scientific formulas",
    "edge_cases": "⚠️ Edge cases",
}
| # ───────────────────────────────────────────────────────────────── | |
| # Tokenizer loaders | |
| # ───────────────────────────────────────────────────────────────── | |
| def _hf_wrapper(tok): | |
| class W: | |
| def encode(self, text): | |
| return tok.encode(text, add_special_tokens=False) | |
| def decode(self, ids): | |
| return tok.decode(ids, skip_special_tokens=True, | |
| clean_up_tokenization_spaces=False) | |
| return W() | |
def load_from_hub(model_id: str):
    """Load a tokenizer by Hugging Face Hub ID.

    Surrounding whitespace in ``model_id`` is ignored. Returns a
    ``(wrapper, name)`` pair where ``wrapper`` is the ``_hf_wrapper`` adapter
    around the loaded tokenizer and ``name`` is the stripped model ID.
    """
    from transformers import AutoTokenizer

    repo_id = model_id.strip()
    hf_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return _hf_wrapper(hf_tokenizer), repo_id
def load_from_uploaded_files(files: list, display_name: str):
    """Build a tokenizer wrapper from files uploaded through Gradio.

    Supported uploads, checked in this order:

    * any ``*.model`` file: loaded with SentencePiece (the first such file
      wins, and the other uploads are ignored);
    * any of the well-known HuggingFace tokenizer files (``tokenizer.json``,
      ``tokenizer_config.json``, ``vocab.txt``, ``vocab.json``,
      ``merges.txt``, ...): copied into a scratch directory and loaded with
      ``AutoTokenizer``.

    Args:
        files: Uploaded files. Each item may be a path string, a ``Path``, or
            an object with a ``.name`` attribute holding the path.
        display_name: Name to report for the tokenizer. When empty, the
            SentencePiece file stem or ``"uploaded-tokenizer"`` is used.

    Returns:
        ``(wrapper, name)`` where ``wrapper`` exposes ``encode(text)`` and
        ``decode(ids)``.

    Raises:
        ValueError: If no files were given or none of them is recognised.
    """
    if not files:
        raise ValueError("No files uploaded.")

    # Accept both tempfile-like upload objects (path in ``.name``) and plain
    # paths, so the loader works whichever form the File component delivers.
    paths = [
        Path(f) if isinstance(f, (str, Path)) else Path(f.name)
        for f in files
    ]

    # ── SentencePiece .model ───────────────────────────────────
    sp_models = [p for p in paths if p.suffix == ".model"]
    if sp_models:
        import sentencepiece as spm

        sp = spm.SentencePieceProcessor()
        sp.Load(str(sp_models[0]))

        class SPWrapper:
            def encode(self, text):
                return sp.EncodeAsIds(text)

            def decode(self, ids):
                return sp.DecodeIds(ids)

        return SPWrapper(), display_name or sp_models[0].stem

    # ── HuggingFace file set ───────────────────────────────────
    HF_FILES = {
        "tokenizer.json", "tokenizer_config.json",
        "vocab.txt", "vocab.json", "merges.txt",
        "special_tokens_map.json", "added_tokens.json", "spiece.model",
    }
    hf_uploads = [p for p in paths if p.name in HF_FILES]
    if hf_uploads:
        from transformers import AutoTokenizer

        # AutoTokenizer needs all files side by side in one directory. Use a
        # self-cleaning scratch directory so repeated uploads do not pile up
        # "tok_*" folders on disk.
        with tempfile.TemporaryDirectory(prefix="tok_") as tmp_dir:
            for p in hf_uploads:
                shutil.copy(p, Path(tmp_dir) / p.name)
            tok = AutoTokenizer.from_pretrained(tmp_dir)
        return _hf_wrapper(tok), display_name or "uploaded-tokenizer"

    raise ValueError(
        f"Unrecognised file(s): {', '.join(p.name for p in paths)}.\n"
        "Expected: tokenizer.json, *.model, or vocab.json + merges.txt"
    )
def load_tiktoken(encoding: str):
    """Load a tiktoken encoding by name.

    Surrounding whitespace in ``encoding`` is ignored. Returns a
    ``(wrapper, name)`` pair where ``wrapper`` exposes ``encode(text)`` and
    ``decode(ids)`` and ``name`` is the stripped encoding name.
    """
    import tiktoken

    encoding_name = encoding.strip()
    enc = tiktoken.get_encoding(encoding_name)

    class _TiktokenAdapter:
        def encode(self, text):
            return enc.encode(text)

        def decode(self, ids):
            return enc.decode(ids)

    return _TiktokenAdapter(), encoding_name
def resolve_tokenizer(source, hub_id, uploaded_files, upload_name, tiktoken_enc):
    """Load the tokenizer selected in the UI and return ``(wrapper, name)``.

    Args:
        source: One of ``"HuggingFace Hub ID"``, ``"Upload files"`` or
            ``"tiktoken encoding"``.
        hub_id: Hub model ID (used for the Hub source).
        uploaded_files: Uploaded file list (used for the upload source).
        upload_name: Optional display name for an uploaded tokenizer.
        tiktoken_enc: Encoding name (used for the tiktoken source).

    Text inputs may be ``None``; they are then treated as empty.

    Raises:
        ValueError: If the input required by the selected source is missing
            or ``source`` is not one of the known values.
    """
    if source == "HuggingFace Hub ID":
        hub_id = (hub_id or "").strip()
        if not hub_id:
            raise ValueError("Please enter a Hub model ID (e.g. bert-base-multilingual-cased).")
        return load_from_hub(hub_id)

    if source == "Upload files":
        if not uploaded_files:
            raise ValueError("Please upload at least one tokenizer file.")
        return load_from_uploaded_files(uploaded_files, (upload_name or "").strip())

    if source == "tiktoken encoding":
        tiktoken_enc = (tiktoken_enc or "").strip()
        if not tiktoken_enc:
            raise ValueError("Please enter a tiktoken encoding (e.g. cl100k_base).")
        return load_tiktoken(tiktoken_enc)

    raise ValueError(f"Unknown source: {source}")
| # ───────────────────────────────────────────────────────────────── | |
| # Metrics | |
| # ───────────────────────────────────────────────────────────────── | |
def fertility_score(tok, text):
    """Return tokens per whitespace-separated word, or 0.0 if there are no words."""
    word_count = len(text.split())
    if word_count == 0:
        return 0.0
    token_count = len(tok.encode(text))
    return token_count / word_count
def compression_ratio(tok, text):
    """Return tokens per character, or 0.0 for empty text."""
    if not text:
        return 0.0
    token_count = len(tok.encode(text))
    return token_count / len(text)
def byte_compression_ratio(tok, text):
    """Return tokens per UTF-8 byte, or 0.0 when the text encodes to no bytes."""
    byte_count = len(text.encode("utf-8"))
    if byte_count == 0:
        return 0.0
    token_count = len(tok.encode(text))
    return token_count / byte_count
def roundtrip_fidelity(tok, text):
    """Return True if encode -> decode reproduces ``text``.

    Leading and trailing whitespace is ignored on both sides. Any exception
    raised while encoding or decoding counts as a failed roundtrip.
    """
    try:
        restored = tok.decode(tok.encode(text))
        return restored.strip() == text.strip()
    except Exception:
        return False
def evaluate_tokenizer(tok, dataset):
    """Compute per-subcategory and overall tokenizer metrics for ``dataset``.

    Args:
        tok: Object exposing ``encode(text)`` and ``decode(ids)``.
        dataset: Mapping ``category -> subcategory -> list of sample strings``.

    Returns:
        A dict with one entry per category, each mapping subcategory names to
        ``n_samples``, ``avg_tokens``, ``avg_fertility`` (tokens/word),
        ``avg_compression_ratio`` (tokens/char), ``avg_byte_compression``
        (tokens/UTF-8 byte), ``fidelity_failures`` and ``errors``, plus a
        ``"__summary__"`` entry with the overall averages and counts.

    Empty or whitespace-only samples are skipped. A sample whose encoding
    raises is excluded from the averages and counted under ``errors`` /
    ``error_count`` instead of being dropped silently. ``total_samples`` is
    the number of samples that were actually evaluated.
    """

    def _mean(values):
        return round(sum(values) / len(values), 4) if values else 0.0

    results: dict[str, Any] = {}
    all_fertility, all_compression = [], []
    total_evaluated = total_failures = total_errors = 0

    for category, subcategories in dataset.items():
        results[category] = {}
        for subcategory, samples in subcategories.items():
            token_counts, ferts, comps, byte_comps = [], [], [], []
            sub_fails = sub_errors = 0
            for text in samples:
                if not text or not text.strip():
                    continue
                try:
                    # Encode once and derive every metric from the same ids.
                    # The text is known to be non-blank here, so none of the
                    # denominators can be zero.
                    ids = tok.encode(text)
                    n_tokens = len(ids)
                    fert = n_tokens / len(text.split())
                    comp = n_tokens / len(text)
                    byte_comp = n_tokens / len(text.encode("utf-8"))
                except Exception:
                    sub_errors += 1
                    continue
                # Same comparison as roundtrip_fidelity, reusing the ids: a
                # decode error counts as a fidelity failure.
                try:
                    faithful = text.strip() == tok.decode(ids).strip()
                except Exception:
                    faithful = False

                # Append only after everything succeeded so that the metric
                # lists always stay the same length.
                token_counts.append(n_tokens)
                ferts.append(fert)
                comps.append(comp)
                byte_comps.append(byte_comp)
                if not faithful:
                    sub_fails += 1

            all_fertility.extend(ferts)
            all_compression.extend(comps)
            total_evaluated += len(token_counts)
            total_failures += sub_fails
            total_errors += sub_errors
            results[category][subcategory] = {
                "n_samples": len(token_counts),
                "avg_tokens": _mean(token_counts),
                "avg_fertility": _mean(ferts),
                "avg_compression_ratio": _mean(comps),
                "avg_byte_compression": _mean(byte_comps),
                "fidelity_failures": sub_fails,
                "errors": sub_errors,
            }

    results["__summary__"] = {
        "overall_avg_fertility": _mean(all_fertility) if all_fertility else 0,
        "overall_avg_compression": _mean(all_compression) if all_compression else 0,
        "total_samples": total_evaluated,
        "fidelity_failure_count": total_failures,
        "error_count": total_errors,
    }
    return results
| # ───────────────────────────────────────────────────────────────── | |
| # Plots | |
| # ───────────────────────────────────────────────────────────────── | |
# Hex colours cycled through (index modulo length) for the per-category
# scatter points and the per-tokenizer comparison bars.
PALETTE: list[str] = ["#3b82f6","#8b5cf6","#ec4899","#f59e0b","#10b981",
           "#ef4444","#06b6d4","#84cc16"]
def fig_to_pil(fig):
    """Render a matplotlib figure to PNG in memory and return it as a PIL image.

    The figure is saved at 130 dpi with a tight bounding box and its own face
    colour. The returned image is a copy of the one opened from the buffer.
    """
    png_buffer = io.BytesIO()
    save_options = {
        "format": "png",
        "dpi": 130,
        "bbox_inches": "tight",
        "facecolor": fig.get_facecolor(),
    }
    fig.savefig(png_buffer, **save_options)
    png_buffer.seek(0)

    from PIL import Image

    rendered = Image.open(png_buffer)
    return rendered.copy()
def _dark_fig(w, h):
    """Create a ``w`` x ``h`` inch figure with a single dark-themed axes.

    Figure and axes backgrounds are ``#0f1117``, tick colours are white and
    all four spines are ``#333``. Returns ``(fig, ax)``.
    """
    background = "#0f1117"
    fig, ax = plt.subplots(figsize=(w, h), facecolor=background)
    ax.set_facecolor(background)
    ax.tick_params(colors="white")
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_color("#333")
    return fig, ax
def plot_fertility_heatmap(result, title):
    """Draw a category x subcategory heatmap of average fertility.

    Args:
        result: Output of ``evaluate_tokenizer``. Keys starting with ``"__"``
            and non-dict values are ignored.
        title: Tokenizer name shown in the chart title.

    Returns:
        A PIL image, or ``None`` when ``result`` contains no categories.
        Subcategories missing from a category are shown as 0.
    """
    cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
    if not cats:
        return None

    # Import before creating the figure so that a missing seaborn cannot
    # leave an open figure behind.
    import seaborn as sns

    data = {
        cat: {
            sub: v.get("avg_fertility", 0)
            for sub, v in result[cat].items()
            if isinstance(v, dict)
        }
        for cat in cats
    }
    df = pd.DataFrame(data).T.fillna(0)

    fig, ax = _dark_fig(max(10, len(df.columns) * 0.8), max(4, len(df) * 0.6))
    try:
        sns.heatmap(df, ax=ax, cmap="YlOrRd", annot=True, fmt=".2f",
                    linewidths=0.5, linecolor="#1e2130",
                    cbar_kws={"label": "Avg fertility (tokens/word)"})
        ax.set_title(f"Fertility heatmap — {title}", fontsize=12, color="white", pad=10)
        ax.tick_params(colors="white", labelsize=8)
        # Style the tick labels through the axes itself rather than pyplot's
        # implicit "current axes".
        plt.setp(ax.get_xticklabels(), rotation=40, ha="right", color="white")
        plt.setp(ax.get_yticklabels(), color="white")
        # The last axes of the figure is the colour bar added by the heatmap.
        cbar_ax = fig.axes[-1]
        cbar_ax.tick_params(colors="white", labelsize=8)
        cbar_ax.yaxis.label.set_color("white")
        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
def plot_language_fertility_bar(result, title):
    """Draw a horizontal bar chart of average fertility per human language.

    Args:
        result: Output of ``evaluate_tokenizer``; only the
            ``"human_languages"`` category is used.
        title: Tokenizer name shown in the chart title.

    Returns:
        A PIL image, or ``None`` when there is no ``"human_languages"`` data.
        Bars are sorted by ascending fertility and coloured red above 3,
        orange above 2 and green otherwise.
    """
    lang_data = result.get("human_languages", {})
    if not lang_data:
        return None

    langs = {
        lang: v["avg_fertility"]
        for lang, v in lang_data.items()
        if isinstance(v, dict) and "avg_fertility" in v
    }
    langs = dict(sorted(langs.items(), key=lambda item: item[1]))
    colors = [
        "#d73027" if v > 3 else "#fdae61" if v > 2 else "#1a9850"
        for v in langs.values()
    ]

    fig, ax = _dark_fig(9, max(4, len(langs) * 0.35))
    try:
        bars = ax.barh(list(langs.keys()), list(langs.values()), color=colors, height=0.7)
        for bar, val in zip(bars, langs.values()):
            ax.text(val + 0.02, bar.get_y() + bar.get_height() / 2,
                    f"{val:.2f}", va="center", fontsize=8, color="white")
        ax.axvline(1.0, color="#aaa", linestyle="--", lw=0.8, label="Ideal (1.0)")
        ax.axvline(2.0, color="#fdae61", linestyle="--", lw=0.8, label="Acceptable (2.0)")
        ax.axvline(4.0, color="#d73027", linestyle="--", lw=0.8, label="Poor (≥4.0)")
        ax.set_xlabel("Avg fertility (tokens/word)", color="white")
        ax.set_title(f"Per-language fertility — {title}", color="white", fontsize=11)
        ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
def plot_compression_scatter(result, title):
    """Scatter average fertility against byte compression, one point per subcategory.

    Args:
        result: Output of ``evaluate_tokenizer``. Keys starting with ``"__"``
            and non-dict values are ignored.
        title: Tokenizer name shown in the chart title.

    Returns:
        A PIL image, or ``None`` when no subcategory has both metrics.
        Categories are drawn in the order they appear in ``result`` so the
        legend order is stable between runs.
    """
    cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
    cat_colors = {cat: PALETTE[i % len(PALETTE)] for i, cat in enumerate(cats)}

    # category -> list of (byte_compression, fertility, subcategory)
    points = {}
    for cat in cats:
        for sub, vals in result[cat].items():
            if not isinstance(vals, dict):
                continue
            fert = vals.get("avg_fertility")
            comp = vals.get("avg_byte_compression")
            if fert is not None and comp is not None:
                points.setdefault(cat, []).append((comp, fert, sub))
    if not points:
        return None

    fig, ax = _dark_fig(9, 6)
    try:
        for cat, cat_points in points.items():
            ax.scatter([p[0] for p in cat_points], [p[1] for p in cat_points],
                       color=cat_colors[cat], label=CATEGORY_LABELS.get(cat, cat),
                       alpha=0.85, s=70, edgecolors="white", linewidths=0.3)
        for cat_points in points.values():
            for comp, fert, sub in cat_points:
                ax.annotate(sub, (comp, fert), fontsize=6.5, color="#ccc",
                            xytext=(4, 3), textcoords="offset points")
        ax.axhline(1.0, color="#aaa", linestyle="--", lw=0.8, label="Fertility=1.0")
        ax.axhline(2.0, color="#fdae61", linestyle="--", lw=0.8, label="Fertility=2.0")
        ax.set_xlabel("Byte compression (tokens/byte) — lower is better", color="white")
        ax.set_ylabel("Fertility (tokens/word) — lower is better", color="white")
        ax.set_title(f"Fertility vs byte compression — {title}", color="white", fontsize=11)
        ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
def plot_comparison_bar(results_dict, metric="avg_fertility"):
    """Draw grouped bars comparing tokenizers on ``metric``, one group per category.

    Args:
        results_dict: Mapping ``tokenizer name -> evaluate_tokenizer output``.
        metric: Per-subcategory metric key to average within each category.

    Returns:
        A PIL image, or ``None`` when ``results_dict`` is empty. A tokenizer
        without data for a category gets a bar of height 0 there.
    """
    if not results_dict:
        return None

    cats, data = set(), {}
    for tok_name, result in results_dict.items():
        data[tok_name] = {}
        for cat, subcats in result.items():
            if cat.startswith("__") or not isinstance(subcats, dict):
                continue
            vals = [v[metric] for v in subcats.values()
                    if isinstance(v, dict) and metric in v]
            if vals:
                data[tok_name][cat] = round(sum(vals) / len(vals), 4)
                cats.add(cat)
    cats = sorted(cats)
    tok_names = list(data.keys())
    x = np.arange(len(cats))
    width = 0.75 / max(len(tok_names), 1)
    metric_title = metric.replace("_", " ").title()

    fig, ax = _dark_fig(max(9, len(cats) * 1.8), 5.5)
    try:
        for i, name in enumerate(tok_names):
            vals = [data[name].get(cat, 0) for cat in cats]
            # Centre the group of bars on each category tick.
            offset = x + i * width - (len(tok_names) - 1) * width / 2
            bars = ax.bar(offset, vals, width * 0.9, label=name,
                          color=PALETTE[i % len(PALETTE)], alpha=0.88)
            for bar, val in zip(bars, vals):
                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                        f"{val:.2f}", ha="center", va="bottom", fontsize=7.5, color="white")
        ax.set_xticks(x)
        ax.set_xticklabels([CATEGORY_LABELS.get(c, c) for c in cats],
                           rotation=20, ha="right", color="white", fontsize=9)
        ax.set_ylabel(metric_title, color="white")
        ax.set_title(f"Tokenizer comparison — {metric_title}", color="white", fontsize=11)
        ax.legend(fontsize=9, facecolor="#1e2130", labelcolor="white")
        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
def plot_fidelity_summary(results_dict):
    """Draw one bar per tokenizer showing its roundtrip fidelity failure count.

    Args:
        results_dict: Mapping ``tokenizer name -> evaluate_tokenizer output``.
            The count is read from ``__summary__.fidelity_failure_count`` and
            defaults to 0 when absent.

    Returns:
        A PIL image. Bars and value labels are green for 0 failures and red
        otherwise.
    """
    names = list(results_dict.keys())
    failures = [r.get("__summary__", {}).get("fidelity_failure_count", 0)
                for r in results_dict.values()]
    colors = ["#d73027" if f > 0 else "#1a9850" for f in failures]

    fig, ax = _dark_fig(max(5, len(names) * 1.4), 4.5)
    try:
        bars = ax.bar(names, failures, color=colors, width=0.5)
        for bar, val, color in zip(bars, failures, colors):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.05,
                    str(val), ha="center", va="bottom", fontsize=10, color=color)
        ax.set_ylabel("Fidelity failure count", color="white")
        ax.set_title("Roundtrip fidelity failures", color="white", fontsize=11)
        ax.set_ylim(bottom=0)
        ax.legend(handles=[
            mpatches.Patch(color="#1a9850", label="0 failures (pass)"),
            mpatches.Patch(color="#d73027", label="Has failures"),
        ], fontsize=8, facecolor="#1e2130", labelcolor="white")
        fig.tight_layout()
        return fig_to_pil(fig)
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)
| # ───────────────────────────────────────────────────────────────── | |
| # Shared tokenizer source block builder | |
| # ───────────────────────────────────────────────────────────────── | |
def tokenizer_source_block(prefix=""):
    """Render the three-way tokenizer source picker and return its components.

    Must be called inside an active ``gr.Blocks`` context. Creates a radio
    selector plus one column per source (Hub ID, file upload, tiktoken
    encoding); only the column matching the selected source is visible.

    Args:
        prefix: Text placed before "Load tokenizer" in the section heading.

    Returns:
        Dict with the input components under the keys ``source``, ``hub_id``,
        ``uploaded_files``, ``upload_name`` and ``tiktoken_enc``.
    """
    gr.Markdown(f"#### {prefix}Load tokenizer")
    source = gr.Radio(
        ["HuggingFace Hub ID", "Upload files", "tiktoken encoding"],
        value="HuggingFace Hub ID",
        label="Source",
    )
    with gr.Column(visible=True) as hub_col:
        hub_id = gr.Textbox(
            label="Hub model ID",
            placeholder="bert-base-multilingual-cased",
            value="bert-base-multilingual-cased",
        )
        gr.Markdown(
            "<small>Examples: `xlm-roberta-base` · `google/mt5-base` · "
            "`facebook/mbart-large-50` · `ai4bharat/indic-bert`</small>"
        )
    with gr.Column(visible=False) as upload_col:
        uploaded_files = gr.File(
            label="Upload tokenizer file(s)",
            file_count="multiple",
            file_types=[".json", ".txt", ".model", ".bpe", ".vocab"],
        )
        upload_name = gr.Textbox(
            label="Display name (optional)",
            placeholder="my-custom-tokenizer",
        )
        gr.Markdown(
            "<small>"
            "**HuggingFace fast tokenizer** → upload `tokenizer.json` "
            "(optionally also `tokenizer_config.json`, `vocab.txt`, `merges.txt`)<br>"
            "**SentencePiece** → upload the `.model` file<br>"
            "**BPE (GPT-2 style)** → upload `vocab.json` + `merges.txt`"
            "</small>"
        )
    with gr.Column(visible=False) as tiktoken_col:
        tiktoken_enc = gr.Textbox(
            label="Encoding name",
            placeholder="cl100k_base",
            value="cl100k_base",
        )
        gr.Markdown(
            "<small>Available encodings: "
            "`cl100k_base` (GPT-3.5/4) · `o200k_base` (GPT-4o) · `p50k_base` (Codex)</small>"
        )

    def _toggle(selected):
        # One visibility update per column, in the order of the outputs below.
        return (
            gr.update(visible=selected == "HuggingFace Hub ID"),
            gr.update(visible=selected == "Upload files"),
            gr.update(visible=selected == "tiktoken encoding"),
        )

    source.change(_toggle, source, [hub_col, upload_col, tiktoken_col])
    return dict(
        source=source,
        hub_id=hub_id,
        uploaded_files=uploaded_files,
        upload_name=upload_name,
        tiktoken_enc=tiktoken_enc,
    )
| # ───────────────────────────────────────────────────────────────── | |
| # Tab logic | |
| # ───────────────────────────────────────────────────────────────── | |
def tokenize_live(source, hub_id, uploaded_files, upload_name, tiktoken_enc, text):
    """Tokenize ``text`` with the selected tokenizer for the Playground tab.

    Returns ``(info_markdown, ids_string)``. The info line reports the token
    count, fertility, compression and the roundtrip status; the ID string
    lists at most the first 120 token IDs. On blank input or any failure the
    first element carries the message (with the traceback for errors) and the
    second is empty.
    """
    if not text.strip():
        return "Enter some text above to tokenize.", ""

    try:
        tok, name = resolve_tokenizer(source, hub_id, uploaded_files, upload_name, tiktoken_enc)
    except Exception:
        return f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", ""

    max_shown = 120
    try:
        ids = tok.encode(text)
        if roundtrip_fidelity(tok, text):
            fid = "✅ Roundtrip OK"
        else:
            fid = "⚠️ Roundtrip mismatch"
        fertility = len(ids) / max(1, len(text.split()))
        compression = len(ids) / max(1, len(text))
        info = (
            f"**Tokenizer:** `{name}` \n"
            f"**Token count:** {len(ids)} | "
            f"**Fertility:** {fertility:.2f} | "
            f"**Compression:** {compression:.3f} | {fid}"
        )
        ids_str = " ".join(str(i) for i in ids[:max_shown])
        hidden = len(ids) - max_shown
        if hidden > 0:
            ids_str += f" … (+{hidden} more)"
        return info, ids_str
    except Exception:
        return f"❌ Tokenization error:\n```\n{traceback.format_exc()}\n```", ""
def run_single_eval(source, hub_id, uploaded_files, upload_name, tiktoken_enc, categories):
    """Benchmark one tokenizer on the selected dataset categories.

    Returns a 5-tuple ``(status_markdown, heatmap, language_bar, scatter,
    table)``. On any failure the status carries the message and the other
    four items are ``None``. The language bar is only produced when the
    ``human_languages`` category is selected.
    """
    failure = (None, None, None, None)

    # Check the cheap input first so an empty selection does not trigger a
    # tokenizer load (possibly a download).
    selected = categories or []
    dataset_subset = {k: v for k, v in DATASET.items() if k in selected}
    if not dataset_subset:
        return ("⚠️ Select at least one dataset category.", *failure)

    try:
        tok, name = resolve_tokenizer(source, hub_id, uploaded_files, upload_name, tiktoken_enc)
    except Exception:
        return (f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", *failure)

    try:
        result = evaluate_tokenizer(tok, dataset_subset)
    except Exception:
        return (f"❌ Evaluation error:\n```\n{traceback.format_exc()}\n```", *failure)

    s = result["__summary__"]
    status = (
        f"✅ **{name}** — {s['total_samples']} samples evaluated\n\n"
        f"| Metric | Value |\n|--------|-------|\n"
        f"| Overall avg fertility | `{s['overall_avg_fertility']}` |\n"
        f"| Overall avg compression | `{s['overall_avg_compression']}` |\n"
        f"| Fidelity failures | `{s['fidelity_failure_count']}` |"
    )

    # One table row per subcategory; summary entries ("__...") are skipped.
    rows = [
        {
            "Category": CATEGORY_LABELS.get(cat, cat),
            "Subcategory": sub,
            "Avg tokens": vals.get("avg_tokens", 0),
            "Avg fertility": vals.get("avg_fertility", 0),
            "Avg compression": vals.get("avg_compression_ratio", 0),
            "Fidelity fails": vals.get("fidelity_failures", 0),
        }
        for cat, subcats in result.items()
        if not cat.startswith("__") and isinstance(subcats, dict)
        for sub, vals in subcats.items()
        if isinstance(vals, dict)
    ]

    return (
        status,
        plot_fertility_heatmap(result, name),
        plot_language_fertility_bar(result, name) if "human_languages" in dataset_subset else None,
        plot_compression_scatter(result, name),
        pd.DataFrame(rows),
    )
def run_compare_eval(
    src_a, hub_a, files_a, name_a, tt_a,
    src_b, hub_b, files_b, name_b, tt_b,
    metric, categories,
):
    """Benchmark two tokenizers on the same categories and compare them.

    Returns a 4-tuple ``(status_markdown, comparison_bar, fidelity_chart,
    leaderboard)``. On any failure the status carries the message and the
    other three items are ``None``. The leaderboard is sorted by ascending
    overall fertility. ``metric`` is the dropdown label; unknown labels fall
    back to average fertility.
    """
    failure = (None, None, None)

    # The subset does not depend on the tokenizer: build and validate it once,
    # before any (possibly slow) tokenizer load.
    selected = categories or []
    dataset_subset = {k: v for k, v in DATASET.items() if k in selected}
    if not dataset_subset:
        return ("⚠️ Select at least one dataset category.", *failure)

    results_dict = {}
    for src, hub, files, uname, tt in [
        (src_a, hub_a, files_a, name_a, tt_a),
        (src_b, hub_b, files_b, name_b, tt_b),
    ]:
        try:
            tok, dname = resolve_tokenizer(src, hub, files, uname, tt)
        except Exception:
            return (f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", *failure)

        # Two tokenizers can resolve to the same display name (for example
        # two unnamed uploads); give the later one a numbered suffix so it
        # does not overwrite the earlier result.
        label, n = dname, 2
        while label in results_dict:
            label = f"{dname} ({n})"
            n += 1

        try:
            results_dict[label] = evaluate_tokenizer(tok, dataset_subset)
        except Exception:
            return (f"❌ Eval error for `{dname}`:\n```\n{traceback.format_exc()}\n```", *failure)

    metric_key = {
        "Fertility (lower = better)": "avg_fertility",
        "Compression ratio": "avg_compression_ratio",
        "Byte compression": "avg_byte_compression",
    }.get(metric, "avg_fertility")

    rows = []
    for name, result in results_dict.items():
        s = result.get("__summary__", {})
        rows.append({
            "Tokenizer": name,
            "Avg fertility": s.get("overall_avg_fertility"),
            "Avg compression": s.get("overall_avg_compression"),
            "Samples evaluated": s.get("total_samples"),
            "Fidelity failures": s.get("fidelity_failure_count"),
        })
    df = pd.DataFrame(rows).sort_values("Avg fertility")

    status = "✅ Comparison complete.\n\n**Leaderboard (lower fertility = better)**\n\n" + "".join(
        f"- **{row['Tokenizer']}** — fertility `{row['Avg fertility']}`, failures `{row['Fidelity failures']}`\n"
        for _, row in df.iterrows()
    )
    return (
        status,
        plot_comparison_bar(results_dict, metric_key),
        plot_fidelity_summary(results_dict),
        df,
    )
| # ───────────────────────────────────────────────────────────────── | |
| # Gradio UI | |
| # ───────────────────────────────────────────────────────────────── | |
# Category keys offered in the "categories" checkbox groups (all selected by default).
CATEGORY_CHOICES = list(DATASET.keys())

# Top-level Gradio app: three tabs (Playground, Evaluate, Compare) followed by
# a shared help footer.
# NOTE(review): the nesting of the layout blocks below is reconstructed from
# statement order — confirm it against the intended layout.
with gr.Blocks(title="TokenizerBench", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """# 🤗 TokenizerBench
Evaluate and compare tokenizers on multilingual text, code, scientific formulas, and edge cases.
Load from the **Hugging Face Hub**, **upload your own files**, or use a **tiktoken** encoding.
"""
    )
    with gr.Tabs():
        # ── Tab 1: Playground ──────────────────────────────────
        with gr.Tab("🧪 Playground"):
            gr.Markdown("Type or paste any text and see instant tokenization results.")
            with gr.Row():
                with gr.Column(scale=1):
                    pg = tokenizer_source_block()
                    pg_btn = gr.Button("Tokenize ▶", variant="primary")
                with gr.Column(scale=2):
                    pg_text = gr.Textbox(
                        label="Input text",
                        placeholder="Type or paste anything…",
                        lines=5,
                        value="The quick brown fox jumps over the lazy dog. 快速的棕色狐狸跳过了懒狗。",
                    )
                    pg_info = gr.Markdown("_Results will appear here._")
                    pg_ids = gr.Textbox(label="Token IDs", lines=2, interactive=False)
            # Input order must match the parameter order of tokenize_live.
            pg_btn.click(
                tokenize_live,
                [pg["source"], pg["hub_id"], pg["uploaded_files"],
                 pg["upload_name"], pg["tiktoken_enc"], pg_text],
                [pg_info, pg_ids],
            )
            gr.Markdown("---\n### Browse dataset samples")
            gr.Markdown("Click any sample below to load it into the text box above.")
            for cat_key, cat_label in CATEGORY_LABELS.items():
                with gr.Accordion(cat_label, open=False):
                    for sub, samples in DATASET[cat_key].items():
                        gr.Markdown(f"**{sub}**")
                        with gr.Row():
                            # Only the first three samples of each subcategory get a button.
                            for s in samples[:3]:
                                # Button caption is the sample truncated to 65 characters.
                                btn = gr.Button(s[:65] + ("…" if len(s) > 65 else ""), size="sm")
                                # Bind the sample as a default argument so each
                                # button loads its own text, not the loop's last one.
                                btn.click(lambda t=s: t, outputs=pg_text)
        # ── Tab 2: Evaluate ────────────────────────────────────
        with gr.Tab("📊 Evaluate"):
            gr.Markdown("Run a full benchmark on a single tokenizer across all dataset categories.")
            with gr.Row():
                with gr.Column(scale=1):
                    ev = tokenizer_source_block()
                    ev_cats = gr.CheckboxGroup(
                        CATEGORY_CHOICES, value=CATEGORY_CHOICES,
                        label="Dataset categories to evaluate",
                    )
                    ev_btn = gr.Button("Run evaluation ▶", variant="primary")
                with gr.Column(scale=2):
                    ev_status = gr.Markdown("_Results will appear here after you click Run evaluation._")
                    ev_table = gr.Dataframe(label="Per-subcategory breakdown", wrap=True)
                    with gr.Tabs():
                        with gr.Tab("Fertility heatmap"):
                            ev_heatmap = gr.Image(type="pil")
                        with gr.Tab("Language fertility bar"):
                            ev_langbar = gr.Image(type="pil")
                        with gr.Tab("Fertility vs compression"):
                            ev_scatter = gr.Image(type="pil")
            # Output order must match the tuple returned by run_single_eval:
            # status, heatmap, language bar, scatter, table.
            ev_btn.click(
                run_single_eval,
                [ev["source"], ev["hub_id"], ev["uploaded_files"],
                 ev["upload_name"], ev["tiktoken_enc"], ev_cats],
                [ev_status, ev_heatmap, ev_langbar, ev_scatter, ev_table],
            )
        # ── Tab 3: Compare ─────────────────────────────────────
        with gr.Tab("⚖️ Compare"):
            gr.Markdown("Compare two tokenizers side-by-side on the same dataset.")
            with gr.Row():
                with gr.Column():
                    cmp_a = tokenizer_source_block("Tokenizer A — ")
                with gr.Column():
                    cmp_b = tokenizer_source_block("Tokenizer B — ")
            with gr.Row():
                # These labels are the keys run_compare_eval maps to metric names.
                cmp_metric = gr.Dropdown(
                    ["Fertility (lower = better)", "Compression ratio", "Byte compression"],
                    value="Fertility (lower = better)",
                    label="Comparison metric",
                )
                cmp_cats = gr.CheckboxGroup(
                    CATEGORY_CHOICES, value=CATEGORY_CHOICES,
                    label="Dataset categories",
                )
            cmp_btn = gr.Button("Compare ▶", variant="primary")
            cmp_status = gr.Markdown("_Results will appear here._")
            cmp_table = gr.Dataframe(label="Leaderboard", wrap=True)
            with gr.Tabs():
                with gr.Tab("Category comparison bar"):
                    cmp_bar_img = gr.Image(type="pil")
                with gr.Tab("Fidelity failures"):
                    cmp_fid_img = gr.Image(type="pil")
            # Inputs: tokenizer A's five fields, tokenizer B's five fields, then
            # metric and categories, matching run_compare_eval's signature.
            # Outputs match its return tuple: status, bar chart, fidelity chart, table.
            cmp_btn.click(
                run_compare_eval,
                [
                    cmp_a["source"], cmp_a["hub_id"], cmp_a["uploaded_files"],
                    cmp_a["upload_name"], cmp_a["tiktoken_enc"],
                    cmp_b["source"], cmp_b["hub_id"], cmp_b["uploaded_files"],
                    cmp_b["upload_name"], cmp_b["tiktoken_enc"],
                    cmp_metric, cmp_cats,
                ],
                [cmp_status, cmp_bar_img, cmp_fid_img, cmp_table],
            )
    # Help footer shown below the tabs.
    gr.Markdown(
        """---
**Metrics explained** — Fertility = tokens/word (lower = better, ≥4 = poor) · Compression = tokens/char · Fidelity = encode→decode must reproduce original text exactly
**Upload guide**
| File(s) to upload | Tokenizer type |
|-------------------|----------------|
| `tokenizer.json` | Any HuggingFace fast tokenizer (BERT, RoBERTa, GPT-2, LLaMA…) |
| `tokenizer.json` + `tokenizer_config.json` + `vocab.txt` | Full HF tokenizer folder |
| `vocab.json` + `merges.txt` | BPE tokenizer (GPT-2 style) |
| `*.model` | SentencePiece (T5, mT5, XLM-R, mBERT…) |
"""
    )

# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()