import gradio as gr
import os
import json
import logging
import transformers
import huggingface_hub
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
import torch

try:
    import tokenizers
except Exception:  # pragma: no cover - diagnostics only
    tokenizers = None

# Load the CodeT5+ model
model_name = "Salesforce/codet5p-220m"

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def log_runtime_versions() -> None:
    """Log runtime package versions to simplify Space startup debugging."""
    # `tokenizers` may be absent (see guarded import above), hence getattr.
    tokenizers_version = getattr(tokenizers, "__version__", "not-installed")
    logger.info("transformers version: %s", transformers.__version__)
    logger.info("huggingface_hub version: %s", huggingface_hub.__version__)
    logger.info("tokenizers version: %s", tokenizers_version)


def sanitize_added_tokens_file(added_tokens_file: str) -> None:
    """Normalize added_tokens.json to the dict format expected by slow tokenizers.

    Slow (Python) tokenizers expect ``{token: id}`` mappings. Older workarounds
    sometimes wrote a list instead; any non-dict content, unreadable file, or
    malformed JSON is replaced with an empty mapping. Missing files are left
    untouched. Only ``str -> int`` entries are kept from an existing dict.
    """
    if not os.path.exists(added_tokens_file):
        return
    try:
        with open(added_tokens_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except (OSError, json.JSONDecodeError):
        # Unreadable or malformed file: fall back to an empty mapping.
        data = {}
    if isinstance(data, dict):
        sanitized = {
            k: v
            for k, v in data.items()
            if isinstance(k, str) and isinstance(v, int)
        }
    else:
        # A list written by an older workaround (or any other shape) is
        # converted to an empty mapping.
        sanitized = {}
    with open(added_tokens_file, "w", encoding="utf-8") as fp:
        json.dump(sanitized, fp, ensure_ascii=True)


def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download *repo_id* into *local_dir* and sanitize tokenizer metadata.

    Returns the local directory path so callers can load from disk.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    # Work around tokenizer metadata incompatibility in some runtime combos.
    added_tokens_file = os.path.join(local_dir, "added_tokens.json")
    sanitize_added_tokens_file(added_tokens_file)
    return local_dir


log_runtime_versions()
local_model_dir = prepare_local_model(model_name)

auto_error = None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        local_model_dir, use_fast=False, trust_remote_code=False
    )
    logger.info("Tokenizer loaded with AutoTokenizer (slow mode).")
except Exception as exc:
    auto_error = exc
    logger.warning(
        "AutoTokenizer load failed, trying RobertaTokenizer fallback: %s", exc
    )
    # Fallback to explicit slow tokenizer class to bypass tokenizers
    # fast-path issues.
    try:
        tokenizer = RobertaTokenizer.from_pretrained(
            local_model_dir, trust_remote_code=False
        )
        logger.info("Tokenizer loaded with RobertaTokenizer fallback.")
    except Exception as fallback_exc:
        raise RuntimeError(
            "Tokenizer initialization failed for both AutoTokenizer and RobertaTokenizer. "
            f"AutoTokenizer error: {auto_error}; RobertaTokenizer error: {fallback_exc}"
        ) from fallback_exc

model = AutoModelForSeq2SeqLM.from_pretrained(
    local_model_dir, trust_remote_code=False
)


def generate_code(prompt: str, max_length: int = 128) -> str:
    """Generate / complete code from *prompt* with beam search.

    Returns an empty string for blank prompts. ``max_length`` caps the
    generated sequence length; the prompt itself is truncated to 512 tokens.
    """
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Gradio sliders may deliver floats; generate() expects an int.
            max_length=int(max_length),
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Build the Gradio interface
demo = gr.Interface(
    fn=generate_code,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
            lines=5,
        ),
        gr.Slider(32, 512, value=128, step=32, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Generated Code", lines=10),
    title="CodeT5+ Code Generation",
    description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=[
        ["def fibonacci(n):", 128],
        ["# Python function to calculate factorial", 128],
        ["def quick_sort(arr):", 128],
    ],
)

if __name__ == "__main__":
    demo.launch()