File size: 4,285 Bytes
718985b
e20ba09
99a461a
 
 
 
e20ba09
83d6a1c
8110e4e
718985b
99a461a
 
 
 
 
c2293c5
 
e20ba09
99a461a
 
 
 
 
 
 
 
 
 
 
 
 
1e66cc5
99a461a
 
 
 
 
 
 
1e66cc5
99a461a
1e66cc5
 
 
 
 
99a461a
1e66cc5
99a461a
 
 
 
e20ba09
 
 
 
99a461a
e20ba09
99a461a
e20ba09
 
 
 
99a461a
e20ba09
99a461a
83d6a1c
e20ba09
99a461a
 
 
 
83d6a1c
99a461a
 
 
 
 
 
 
 
b51e859
e20ba09
718985b
8110e4e
 
 
 
 
 
 
 
 
 
 
 
c2293c5
8110e4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2293c5
 
8110e4e
 
 
 
 
 
 
 
c2293c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import os
import json
import logging
import transformers
import huggingface_hub
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, RobertaTokenizer
import torch

# tokenizers is optional here: it is only inspected for version logging,
# so a missing/broken install must not crash startup.
try:
    import tokenizers
except Exception:  # pragma: no cover - diagnostics only
    tokenizers = None

# Hugging Face Hub repo id of the CodeT5+ 220M checkpoint to load.
model_name = "Salesforce/codet5p-220m"

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def log_runtime_versions() -> None:
    """Record key package versions at startup so Space failures are easy to trace."""
    # tokenizers may be None (optional import) or lack __version__; report either
    # case as "not-installed" rather than raising during diagnostics.
    if tokenizers is None:
        tok_version = "not-installed"
    else:
        tok_version = getattr(tokenizers, "__version__", "not-installed")
    logger.info("transformers version: %s", transformers.__version__)
    logger.info("huggingface_hub version: %s", huggingface_hub.__version__)
    logger.info("tokenizers version: %s", tok_version)


def sanitize_added_tokens_file(added_tokens_file: str) -> None:
    """Normalize added_tokens.json to the dict format expected by slow tokenizers.

    Some snapshots ship ``added_tokens.json`` as a list (an older workaround)
    or with malformed entries; slow tokenizers require a ``{token: id}``
    mapping. The file is rewritten in place; a missing file is left alone.

    Args:
        added_tokens_file: Path to the ``added_tokens.json`` file.
    """
    # EAFP: attempt the read directly instead of an exists-then-open check,
    # which is racy and needs a second stat call.
    try:
        with open(added_tokens_file, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except FileNotFoundError:
        return  # nothing to normalize
    except (OSError, ValueError):
        # Unreadable file or malformed JSON (json.JSONDecodeError is a
        # ValueError): fall back to an empty mapping rather than crash startup.
        data = {}

    if isinstance(data, dict):
        # Keep only well-formed {str: int} entries. bool is an int subclass,
        # so it must be excluded explicitly or True/False would pass as ids.
        sanitized = {
            token: token_id
            for token, token_id in data.items()
            if isinstance(token, str)
            and isinstance(token_id, int)
            and not isinstance(token_id, bool)
        }
    else:
        # Lists written by the older workaround (or any other shape)
        # collapse to an empty mapping.
        sanitized = {}

    with open(added_tokens_file, "w", encoding="utf-8") as fp:
        json.dump(sanitized, fp, ensure_ascii=True)


def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download *repo_id* into *local_dir* and patch its tokenizer metadata.

    Returns the local directory path, ready to pass to ``from_pretrained``.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    # Some runtime combos choke on the snapshot's added_tokens.json format,
    # so normalize it before any tokenizer is constructed from this dir.
    sanitize_added_tokens_file(os.path.join(local_dir, "added_tokens.json"))
    return local_dir


# Startup sequence: log environment versions, download/patch the model
# snapshot, then load the tokenizer with a two-stage fallback before
# loading the seq2seq model itself.
log_runtime_versions()
local_model_dir = prepare_local_model(model_name)
auto_error = None  # remembers the first failure for the combined error message
try:
    # use_fast=False forces the slow tokenizer path, avoiding the fast
    # tokenizers metadata incompatibility worked around in prepare_local_model.
    tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=False, trust_remote_code=False)
    logger.info("Tokenizer loaded with AutoTokenizer (slow mode).")
except Exception as exc:
    auto_error = exc
    logger.warning("AutoTokenizer load failed, trying RobertaTokenizer fallback: %s", exc)
    # Fallback to explicit slow tokenizer class to bypass tokenizers fast-path issues.
    try:
        tokenizer = RobertaTokenizer.from_pretrained(local_model_dir, trust_remote_code=False)
        logger.info("Tokenizer loaded with RobertaTokenizer fallback.")
    except Exception as fallback_exc:
        # Surface both failures so the Space log shows the full picture.
        raise RuntimeError(
            "Tokenizer initialization failed for both AutoTokenizer and RobertaTokenizer. "
            f"AutoTokenizer error: {auto_error}; RobertaTokenizer error: {fallback_exc}"
        ) from fallback_exc

model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=False)

def generate_code(prompt: str, max_length: int = 128) -> str:
    """Generate or complete code from *prompt* with the CodeT5+ model.

    Args:
        prompt: Natural-language description or code fragment to continue.
        max_length: Maximum total output length in tokens. Gradio's Slider
            may deliver this as a float, so it is coerced to int before use.

    Returns:
        The generated code, or "" when the prompt is blank.
    """
    if not prompt.strip():
        return ""

    # Gradio Slider values can arrive as floats; generate() expects an int.
    max_length = int(max_length)

    # Truncate the encoded prompt to 512 tokens to stay within the model's
    # input budget.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,          # beam search for more coherent completions
            early_stopping=True,  # stop once all beams have finished
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Build the Gradio UI: a prompt textbox plus a max-length slider feeding
# generate_code, with ready-made example prompts.
demo = gr.Interface(
    fn=generate_code,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="输入代码描述或代码片段,例如:def fibonacci(n):",
            lines=5
        ),
        # Slider range mirrors generate_code's token budget (and the
        # 512-token input truncation cap).
        gr.Slider(32, 512, value=128, step=32, label="Max Length")
    ],
    outputs=gr.Textbox(label="Generated Code", lines=10),
    title="CodeT5+ Code Generation",
    description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
    examples=[
        ["def fibonacci(n):", 128],
        ["# Python function to calculate factorial", 128],
        ["def quick_sort(arr):", 128],
    ]
)

if __name__ == "__main__":
    demo.launch()