# CodeT5 / app.py
# Author: Sixparticle
# Last commit: "Fix tokenizer added_tokens format and numpy compatibility" (1e66cc5)
import gradio as gr
import os
import json
import logging
import transformers
import huggingface_hub
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
import torch
# `tokenizers` is optional at import time: it is only consulted for its
# version string in log_runtime_versions(), so a missing or broken install
# must not prevent the Space from starting.
try:
    import tokenizers
except Exception:  # pragma: no cover - diagnostics only
    tokenizers = None

# Hugging Face Hub repo id of the CodeT5+ model to load.
model_name = "Salesforce/codet5p-220m"

# Module-level logger; INFO level so startup diagnostics show up in Space logs.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def log_runtime_versions() -> None:
    """Emit the versions of key runtime packages to ease Space startup debugging."""
    # `tokenizers` may be None (its import is guarded above), so read the
    # version attribute defensively instead of assuming the module exists.
    tok_version = getattr(tokenizers, "__version__", "not-installed")
    logger.info("transformers version: %s", transformers.__version__)
    logger.info("huggingface_hub version: %s", huggingface_hub.__version__)
    logger.info("tokenizers version: %s", tok_version)
def sanitize_added_tokens_file(added_tokens_file: str) -> None:
    """Normalize ``added_tokens.json`` to the dict format slow tokenizers expect.

    Slow (pure-Python) tokenizers require ``added_tokens.json`` to be a mapping
    of ``{token_string: token_id}``.  Some exports instead contain a list of
    AddedToken records (``[{"content": ..., "id": ...}, ...]``); recover the
    mapping from those when possible, and fall back to an empty mapping for
    anything unrecognized or unreadable.  A missing file is left untouched.

    Args:
        added_tokens_file: Path to the ``added_tokens.json`` file to rewrite
            in place.
    """
    if not os.path.exists(added_tokens_file):
        return
    try:
        with open(added_tokens_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # Unreadable or corrupt file: treat as "no added tokens" rather than
        # letting tokenizer loading crash on bad metadata.
        data = {}
    if isinstance(data, dict):
        # Keep only well-formed {str: int} entries.
        sanitized = {
            k: v for k, v in data.items() if isinstance(k, str) and isinstance(v, int)
        }
    elif isinstance(data, list):
        # Newer tokenizer exports may serialize a list of AddedToken dicts;
        # convert them back to the {content: id} mapping instead of discarding
        # them outright (malformed entries are skipped).
        sanitized = {
            item["content"]: item["id"]
            for item in data
            if isinstance(item, dict)
            and isinstance(item.get("content"), str)
            and isinstance(item.get("id"), int)
        }
    else:
        sanitized = {}
    with open(added_tokens_file, "w", encoding="utf-8") as fp:
        json.dump(sanitized, fp, ensure_ascii=True)
def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download *repo_id* into *local_dir* and sanitize its tokenizer metadata.

    Returns the local directory path so callers can hand it straight to
    ``from_pretrained``.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    # Some runtime combinations choke on added_tokens.json metadata, so
    # normalize it before any tokenizer attempts to read it.
    sanitize_added_tokens_file(os.path.join(local_dir, "added_tokens.json"))
    return local_dir
# Log environment details first so any load failure below is easy to diagnose.
log_runtime_versions()
# Materialize the model repo locally (with sanitized added_tokens.json) before
# handing the directory to transformers.
local_model_dir = prepare_local_model(model_name)

# Tokenizer loading strategy: prefer AutoTokenizer in slow mode (use_fast=False
# sidesteps fast-tokenizer binary incompatibilities); if that still fails, fall
# back to the explicit slow RobertaTokenizer class.
auto_error = None
try:
    tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=False, trust_remote_code=False)
    logger.info("Tokenizer loaded with AutoTokenizer (slow mode).")
except Exception as exc:
    # Remember the first failure so the final error message can report both.
    auto_error = exc
    logger.warning("AutoTokenizer load failed, trying RobertaTokenizer fallback: %s", exc)
    # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
    try:
        tokenizer = RobertaTokenizer.from_pretrained(local_model_dir, trust_remote_code=False)
        logger.info("Tokenizer loaded with RobertaTokenizer fallback.")
    except Exception as fallback_exc:
        # Surface both causes in a single error so Space logs tell the full story.
        raise RuntimeError(
            "Tokenizer initialization failed for both AutoTokenizer and RobertaTokenizer. "
            f"AutoTokenizer error: {auto_error}; RobertaTokenizer error: {fallback_exc}"
        ) from fallback_exc

# Model weights load only after a usable tokenizer exists.
model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=False)
def generate_code(prompt: str, max_length: int = 128) -> str:
    """Code generation / completion: produce code from *prompt* with CodeT5+."""
    # Whitespace-only prompts produce nothing useful; short-circuit them.
    if not prompt.strip():
        return ""
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: disable autograd bookkeeping while generating.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded
# Build the Gradio UI: one prompt textbox plus a length slider mapped onto
# generate_code, with canned examples for quick smoke-testing.
demo = gr.Interface(
    fn=generate_code,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
            lines=5
        ),
        # Slider bounds mirror generate_code's generation budget (max_length).
        gr.Slider(32, 512, value=128, step=32, label="Max Length")
    ],
    outputs=gr.Textbox(label="Generated Code", lines=10),
    title="CodeT5+ Code Generation",
    description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=[
        ["def fibonacci(n):", 128],
        ["# Python function to calculate factorial", 128],
        ["def quick_sort(arr):", 128],
    ]
)

# Launch only when executed directly (Spaces run this file as __main__).
if __name__ == "__main__":
    demo.launch()