# Tokenization diff visualizer (Hugging Face Space).
import html
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer
def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """Greedily group student/teacher token indices whose decoded text spans match.

    Adapted from TRL's GoldTrainer._build_alignment_groups_from_ids.

    Both id sequences are decoded incrementally into per-token text pieces; the
    algorithm then extends whichever side currently covers less text until the
    two buffers spell the same string, at which point the accumulated indices on
    each side are closed out as one aligned group pair.

    Args:
        student_tokenizer: object exposing ``decode(ids, ...)`` for the student ids.
        teacher_tokenizer: object exposing ``decode(ids, ...)`` for the teacher ids.
        student_token_ids: sequence of student token ids for one text.
        teacher_token_ids: sequence of teacher token ids for (presumably) the
            same text — the algorithm assumes both decode to the same string;
            a trailing mismatch is emitted as a final (unequal-text) group pair.

    Returns:
        ``(s_groups, t_groups)``: equal-length lists of index lists; entry ``k``
        on each side covers the same decoded substring.
    """

    def to_canonical_pieces(tok, ids):
        # Decode successive prefixes and diff them, so each piece is exactly the
        # text this token contributed *in context* (robust to tokenizers whose
        # single-token decode differs from its in-sequence contribution).
        pieces = []
        prev = ""
        for k in range(len(ids)):
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(teacher_tokenizer, teacher_token_ids)

    i = j = 0
    s_buf = t_buf = ""
    s_group, t_group = [], []
    s_groups, t_groups = [], []

    while i < len(s_pieces) or j < len(t_pieces):
        if s_buf == t_buf and s_buf != "":
            # Buffers cover identical non-empty text: close the aligned group.
            # A non-empty buffer implies a non-empty index group, so both
            # appends are safe without further guards.
            s_groups.append(s_group)
            t_groups.append(t_group)
            s_buf = t_buf = ""
            s_group, t_group = [], []
            continue
        if s_buf == "" and i < len(s_pieces):
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
            continue
        if t_buf == "" and j < len(t_pieces):
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1
            continue
        # Both buffers non-empty but unequal: extend the shorter side so the two
        # spans converge; fall back to the other side once one is exhausted.
        # (The while-condition guarantees at least one side can still advance.)
        if len(s_buf) <= len(t_buf):
            if i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1
            elif j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
        else:
            if j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
            elif i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1

    # Flush any remainder, including a trailing mismatch where only one side has
    # leftover indices. This replaces the original dead-code tail (the
    # ``if not s_group: s_group = []`` branches were no-ops): in every case the
    # original ultimately appended both groups iff at least one was non-empty.
    if s_group or t_group:
        s_groups.append(s_group)
        t_groups.append(t_group)

    return s_groups, t_groups
def _decode_pieces(tokenizer, token_ids, indices):
"""Decode individual token pieces for a group of token indices."""
return [
tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
for idx in indices
]
def _format_pieces(pieces):
"""Format token pieces as a list, e.g. '["hel", "lo"]'."""
inner = ", ".join(f'"{p}"' for p in pieces)
return f"[{inner}]"
def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
"""Build an HTML string with highlighted misalignment regions."""
parts = []
first_purple = True
for k in range(len(s_groups)):
s_ids = [student_token_ids[idx] for idx in s_groups[k]]
text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
escaped = html.escape(text)
s_multi = len(s_groups[k]) > 1
t_multi = len(t_groups[k]) > 1
if s_multi and t_multi:
if first_purple:
s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k])
t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k])
tooltip = html.escape(f'Student: {_format_pieces(s_pieces)} / Teacher: {_format_pieces(t_pieces)}')
parts.append(f'<span style="background-color: #b388ff;" title="{tooltip}">{escaped}</span>')
first_purple = False
else:
parts.append(f'<span style="background-color: #b388ff;">{escaped}</span>')
elif s_multi:
s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k])
tooltip = html.escape(f'Student: {_format_pieces(s_pieces)}')
parts.append(f'<span style="background-color: #ffcc80;" title="{tooltip}">{escaped}</span>')
elif t_multi:
t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k])
tooltip = html.escape(f'Teacher: {_format_pieces(t_pieces)}')
parts.append(f'<span style="background-color: #90caf9;" title="{tooltip}">{escaped}</span>')
else:
parts.append(escaped)
return "".join(parts)
def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
    """Process a single text and return its highlighted HTML block."""
    s_ids = student_tokenizer.encode(text, add_special_tokens=False)
    t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
    s_groups, t_groups = build_alignment_groups_from_ids(
        student_tokenizer, teacher_tokenizer, s_ids, t_ids
    )
    highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)

    # Alternating backgrounds make individual token boundaries visible.
    color1 = "#fff9c4"
    color2 = "#b2ebf2"

    def striped(tokenizer, ids):
        # Render each token as its own colored span, alternating by position.
        spans = []
        for pos, tid in enumerate(ids):
            piece = tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            spans.append(
                f'<span style="background-color:{color1 if pos % 2 == 0 else color2};">{html.escape(piece)}</span>'
            )
        return "".join(spans)

    s_tokens_html = striped(student_tokenizer, s_ids)
    t_tokens_html = striped(teacher_tokenizer, t_ids)

    tokenized_section = f'''
<div style="margin-bottom:15px;">
<details style="margin-bottom:10px;">
<summary style="cursor:pointer; font-weight:bold; user-select:none;">Show tokenization details</summary>
<div style="display:grid; grid-template-columns:1fr 1fr; gap:15px; margin-top:10px;">
<div style="border:1px solid #ddd; padding:10px; border-radius:5px;">
<strong style="color:#f57c00;">Student Tokens ({len(s_ids)})</strong>
<div style="margin-top:8px; font-size:12px; word-break:break-word;">{s_tokens_html}</div>
</div>
<div style="border:1px solid #ddd; padding:10px; border-radius:5px;">
<strong style="color:#1976d2;">Teacher Tokens ({len(t_ids)})</strong>
<div style="margin-top:8px; font-size:12px; word-break:break-word;">{t_tokens_html}</div>
</div>
</div>
</details>
</div>
'''
    return (
        f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
        f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
        f"<strong>Text {idx + 1}</strong> "
        f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
        f"{tokenized_section}"
        f"{highlighted}"
        f"</div>"
    )
def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
    """Load tokenizers and dataset, compute first row only."""
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)

    progress(0.5, desc="Loading dataset...")
    # A blank / whitespace-only config box means "no named config".
    config = (dataset_config or "").strip() or None
    ds = load_dataset(dataset_id, name=config, split="train")

    # Keep at most the first 10 rows; concatenate each row's message contents.
    # NOTE(review): assumes a chat-style dataset with a "messages" column of
    # {"content": ...} dicts — confirm before pointing at other datasets.
    texts = []
    for row in ds.select(range(min(10, len(ds)))):
        texts.append("".join(msg["content"] for msg in row["messages"]))

    progress(0.8, desc="Processing first text...")
    # Only the first page is rendered eagerly; go_next fills the cache lazily.
    cache = {0: make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)}
    progress(1, desc="Done!")
    return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))
# Static color legend shown above every rendered page; the colors match the
# span backgrounds emitted by highlight_groups.
LEGEND = (
    '<div style="margin-bottom:15px; font-family:sans-serif;">'
    "<strong>Legend:</strong> "
    '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student token split (orange)</span>'
    '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher token split (blue)</span>'
    '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
    "</div>"
)
def render_page(cache, idx, total):
    """Return legend + position counter + cached HTML for page *idx* ('' if nothing cached)."""
    if cache:
        counter = f'<div style="font-family:sans-serif; margin-bottom:10px;">Text {idx + 1} of {total}</div>'
        return LEGEND + counter + cache[idx]
    return ""
def go_prev(cache, idx, texts):
    """Step back one text (clamped at the first) and re-render the page."""
    new_idx = idx - 1 if idx > 0 else 0
    return cache, new_idx, render_page(cache, new_idx, len(texts))
def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
    """Advance one text (clamped at the last), rendering its HTML on first visit."""
    new_idx = idx + 1
    last = len(texts) - 1
    if new_idx > last:
        new_idx = last
    if new_idx not in cache:
        # Pages are rendered lazily so only visited texts pay the tokenization cost.
        cache[new_idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[new_idx], new_idx)
    return cache, new_idx, render_page(cache, new_idx, len(texts))
# Gradio UI: model/dataset inputs, per-session state, and prev/next paging over
# the rendered tokenization-diff pages.
with gr.Blocks(title="Tokenization Diff") as demo:
    gr.Markdown("# Tokenization Diff\nVisualize where two tokenizers differ in how they tokenize text.")
    with gr.Row():
        student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
        teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
        dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
        dataset_config = gr.Textbox(label="Dataset Config", value="default")
    submit_btn = gr.Button("Submit", variant="primary")
    # Per-session state: loaded tokenizers, source texts, rendered-page cache,
    # and the index of the page currently shown.
    student_tok_state = gr.State(None)
    teacher_tok_state = gr.State(None)
    texts_state = gr.State([])
    cache_state = gr.State({})
    idx_state = gr.State(0)
    output = gr.HTML(label="Tokenization Diff Output")
    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
    # Submit loads everything and renders page 0; prev/next page through texts.
    submit_btn.click(
        fn=process_texts,
        inputs=[student_model, teacher_model, dataset_id, dataset_config],
        outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
    )
    prev_btn.click(
        fn=go_prev,
        inputs=[cache_state, idx_state, texts_state],
        outputs=[cache_state, idx_state, output],
    )
    next_btn.click(
        fn=go_next,
        inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
        outputs=[cache_state, idx_state, output],
    )
if __name__ == "__main__":
    demo.launch()