import gradio as gr
import os
import json
import logging
import transformers
import huggingface_hub
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
import torch
try:
import tokenizers
except Exception: # pragma: no cover - diagnostics only
tokenizers = None
# Load the CodeT5+ model: Hub checkpoint id used for the snapshot download below.
model_name = "Salesforce/codet5p-220m"
logger = logging.getLogger(__name__)
# INFO level so the startup/version diagnostics below are visible in Space logs.
logging.basicConfig(level=logging.INFO)
def log_runtime_versions() -> None:
    """Log versions of key HF packages to simplify Space startup debugging."""
    logger.info("transformers version: %s", transformers.__version__)
    logger.info("huggingface_hub version: %s", huggingface_hub.__version__)
    # tokenizers is import-guarded above and may be None, hence the fallback.
    logger.info(
        "tokenizers version: %s",
        getattr(tokenizers, "__version__", "not-installed"),
    )
def sanitize_added_tokens_file(added_tokens_file: str) -> None:
    """Normalize added_tokens.json to the dict[str, int] format slow tokenizers expect.

    Some runtime combinations write this file as a list (or with non-integer
    values), which slow tokenizers cannot parse.  Rewrites the file in place so
    it contains only string-token -> integer-id entries; no-op if the file is
    absent.

    Args:
        added_tokens_file: Path to the added_tokens.json file to sanitize.
    """
    if not os.path.exists(added_tokens_file):
        return
    try:
        with open(added_tokens_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except (OSError, ValueError):
        # Unreadable or malformed JSON: rewrite as an empty mapping below.
        data = {}
    if isinstance(data, dict):
        # Keep only well-formed entries.  Exclude bools explicitly: bool is a
        # subclass of int in Python, but True/False are not valid token ids.
        sanitized = {
            token: token_id
            for token, token_id in data.items()
            if isinstance(token, str)
            and isinstance(token_id, int)
            and not isinstance(token_id, bool)
        }
    else:
        # A list (written by an older workaround) or any other shape becomes
        # an empty mapping.
        sanitized = {}
    with open(added_tokens_file, "w", encoding="utf-8") as fp:
        json.dump(sanitized, fp, ensure_ascii=True)
def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download a Hub snapshot into *local_dir* and patch its tokenizer metadata.

    Args:
        repo_id: Hugging Face Hub repository id to download.
        local_dir: Destination directory for the snapshot.

    Returns:
        The local directory containing the prepared snapshot.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    # Normalize added_tokens.json before any tokenizer is built from this
    # directory, to work around metadata incompatibility in some runtime combos.
    sanitize_added_tokens_file(os.path.join(local_dir, "added_tokens.json"))
    return local_dir
# ---- Module-level startup (runs at import time) ----
log_runtime_versions()
# Download the checkpoint and sanitize added_tokens.json first, so the
# tokenizer loads below read the patched files.
local_model_dir = prepare_local_model(model_name)
auto_error = None  # remembered so the final error can report both failures
try:
    # use_fast=False forces the slow (pure-Python) tokenizer to sidestep
    # fast-tokenizer metadata issues seen in some runtime combinations.
    tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=False, trust_remote_code=False)
    logger.info("Tokenizer loaded with AutoTokenizer (slow mode).")
except Exception as exc:
    auto_error = exc
    logger.warning("AutoTokenizer load failed, trying RobertaTokenizer fallback: %s", exc)
    # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
    try:
        tokenizer = RobertaTokenizer.from_pretrained(local_model_dir, trust_remote_code=False)
        logger.info("Tokenizer loaded with RobertaTokenizer fallback.")
    except Exception as fallback_exc:
        # Surface both errors; chain from the fallback failure for a full traceback.
        raise RuntimeError(
            "Tokenizer initialization failed for both AutoTokenizer and RobertaTokenizer. "
            f"AutoTokenizer error: {auto_error}; RobertaTokenizer error: {fallback_exc}"
        ) from fallback_exc
model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=False)
def generate_code(prompt: str, max_length: int = 128) -> str:
    """Generate or complete code from a prompt with beam search.

    Args:
        prompt: Code description or code fragment to continue.
        max_length: Cap on the generated sequence length.

    Returns:
        The decoded model output, or "" for a blank prompt.
    """
    if not prompt.strip():
        return ""
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Build the Gradio UI: components first, then the Interface wiring.
prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
    lines=5,
)
length_slider = gr.Slider(32, 512, value=128, step=32, label="Max Length")
output_box = gr.Textbox(label="Generated Code", lines=10)
demo = gr.Interface(
    fn=generate_code,
    inputs=[prompt_box, length_slider],
    outputs=output_box,
    title="CodeT5+ Code Generation",
    description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=[
        ["def fibonacci(n):", 128],
        ["# Python function to calculate factorial", 128],
        ["def quick_sort(arr):", 128],
    ],
)
if __name__ == "__main__":
    demo.launch()