# UniXCoder / app.py
# Last commit 6000b90 by Sixparticle: "Fix Gradio launch for Spaces runtime changes"
import os
import re
from typing import Dict, Any, List
import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
# Hugging Face model identifier that UniXcoderAnalyzer loads by default.
MODEL_ID: str = "microsoft/unixcoder-base-nine"
# Maximum number of tokens per tokenizer call; longer inputs are truncated.
MAX_TOKENS: int = 512
def _ensure_localhost_no_proxy() -> None:
    """Append the loopback hosts to the ``NO_PROXY`` and ``no_proxy`` variables.

    Existing entries are kept in their original order, with surrounding
    whitespace stripped and empty entries dropped. ``localhost``,
    ``127.0.0.1`` and ``::1`` are then appended if they are not already
    present. Both the upper- and lower-case variable names are rewritten.

    NOTE(review): presumably this keeps requests to the local Gradio server
    from being routed through a proxy configured in the environment — confirm.
    """
    loopback_hosts = ("localhost", "127.0.0.1", "::1")
    for env_var in ("NO_PROXY", "no_proxy"):
        raw_value = os.environ.get(env_var, "")
        entries = [piece.strip() for piece in raw_value.split(",")]
        entries = [piece for piece in entries if piece]
        entries += [host for host in loopback_hosts if host not in entries]
        # The list always ends up containing the loopback hosts, so it is
        # never empty and the variable is always (re)written.
        os.environ[env_var] = ",".join(entries)


# Applied at import time, before the model is loaded and the Gradio app is
# launched further down in this module.
_ensure_localhost_no_proxy()
def _safe_float(v: float, ndigits: int = 4) -> float:
    """Round *v* to *ndigits* decimal places and return a builtin ``float``.

    *v* is converted with ``float()`` first, so numeric strings and NumPy
    scalars are accepted. The final ``float()`` keeps the return type a
    float even when ``ndigits`` is ``None`` (where ``round`` yields an int).
    """
    rounded = round(float(v), ndigits)
    return float(rounded)
class UniXcoderAnalyzer:
    """Code analyzer combining UniXcoder embeddings with lightweight heuristics.

    ``analyze`` returns a plain dict (the payload of the JSON API) holding:
    - a semantic-alignment score: the cosine similarity between the prompt
      and code embeddings, rescaled to [0, 1];
    - regex/line-count scores for complexity and maintainability;
    - human-readable key points, risks and suggestions.
    """

    # Keyword heuristics, compiled once instead of on every ``analyze`` call.
    # They are not a parser: matches inside strings and comments count too.
    _FUNCTION_LIKE_RE = re.compile(r"\b(def|function|public|private|protected|class)\b")
    _CONTROL_FLOW_RE = re.compile(r"\b(if|else|for|while|switch|try|catch)\b")
    # Counts comment *markers*, not comments: a ``/* ... */`` block yields two hits.
    _COMMENT_MARKER_RE = re.compile(r"//|/\*|\*/|#")
    # Non-blank lines longer than this many characters count as "long lines".
    _LONG_LINE_CHARS = 120
    # Rescaled alignment below this value is reported as a possible mismatch.
    _LOW_ALIGNMENT_THRESHOLD = 0.55
    # Complexity above this value triggers the "add unit tests" risk.
    _HIGH_COMPLEXITY_THRESHOLD = 0.7

    def __init__(self, model_id: str = MODEL_ID):
        """Load the tokenizer and encoder for *model_id*.

        The model is placed on CUDA when available, otherwise on the CPU.
        The chosen device is stored in ``self.device``.

        Args:
            model_id: Hugging Face model identifier passed to ``from_pretrained``.
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()

    @torch.no_grad()
    def _embed(self, text: str) -> np.ndarray:
        """Return the L2-normalized, mask-aware mean-pooled embedding of *text*.

        Input longer than ``MAX_TOKENS`` tokens is truncated, so only the
        beginning of a long snippet influences the embedding. Runs under
        ``torch.no_grad`` because no gradients are needed.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_TOKENS,
            padding=True,
        )
        encoded = {k: v.to(self.device) for k, v in encoded.items()}
        outputs = self.model(**encoded)
        # Assumes last_hidden_state is (batch, seq, hidden) — the usual
        # Hugging Face layout; pooling below reduces over dim=1.
        token_embeddings = outputs.last_hidden_state
        # Broadcast the attention mask to the hidden size so padding
        # positions contribute nothing to the average.
        attention_mask = encoded["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        masked = token_embeddings * attention_mask
        pooled = masked.sum(dim=1) / torch.clamp(attention_mask.sum(dim=1), min=1e-9)
        vec = pooled[0].detach().cpu().numpy()
        # The epsilon avoids division by zero for an all-zero vector.
        norm = np.linalg.norm(vec) + 1e-9
        return vec / norm

    def analyze(self, prompt: str, language: str, code: str, analysis_type: str) -> Dict[str, Any]:
        """Analyze *code* against the natural-language requirement *prompt*.

        Args:
            prompt: Requirement description. It may be empty or ``None``; in
                that case a generic ``"Analyze {language} code"`` text is
                embedded instead, and no requirement-mismatch risk is reported.
            language: Language label. It is echoed in ``meta`` and used in the
                fallback prompt.
            code: Source code to analyze. Surrounding whitespace is stripped.
            analysis_type: ``"risk"`` or ``"quality"`` select a focused
                summary. Any other value yields the general semantic summary.

        Returns:
            On success, a dict with ``status == "ok"``, ``scores`` and ``meta``.
            When *code* is empty, a dict with ``status == "error"``,
            ``analysisError == "EMPTY_CODE"`` and ``meta``.
        """
        prompt = (prompt or "").strip()
        code = (code or "").strip()
        # Included in both the success and the error response so clients can
        # rely on it. "device" reports where the model was actually placed.
        meta = {
            "language": language,
            "analysisType": analysis_type,
            "device": self.device,
        }

        if not code:
            return {
                "modelStrategy": "unixcoder-hf-space",
                "enabled": True,
                "model": self.model_id,
                "status": "error",
                "message": "code 不能为空",
                "analysisError": "EMPTY_CODE",
                "summary": "未提供待分析代码,无法执行语义分析。",
                "keyPoints": [],
                "risks": ["输入代码为空"],
                "suggestions": ["请传入完整代码片段后重试"],
                "meta": meta,
            }

        prompt_vec = self._embed(prompt if prompt else f"Analyze {language} code")
        code_vec = self._embed(code)
        # _embed returns (approximately) unit vectors, so the dot product is
        # their cosine similarity in [-1, 1]; rescale it to [0, 1].
        semantic_alignment = (float(np.dot(prompt_vec, code_vec)) + 1.0) / 2.0

        lines = [ln for ln in code.splitlines() if ln.strip()]
        line_count = len(lines)
        char_count = len(code)
        function_like = len(self._FUNCTION_LIKE_RE.findall(code))
        control_flow = len(self._CONTROL_FLOW_RE.findall(code))
        long_lines = sum(1 for ln in lines if len(ln) > self._LONG_LINE_CHARS)
        comments = len(self._COMMENT_MARKER_RE.findall(code))

        # Ad-hoc weights: every control-flow keyword, declaration keyword and
        # non-blank line adds to complexity, capped at 1.0.
        complexity_score = min(1.0, (control_flow * 0.08) + (function_like * 0.05) + (line_count / 300.0))
        # Start from 1.0, subtract a penalty for the share of long lines and
        # add up to 0.2 for comment-marker density, then clamp to [0, 1].
        # NOTE(review): with no long lines the clamp already yields 1.0, so
        # the comment bonus only matters when long lines are present.
        per_line = max(1, line_count)
        maintainability = max(0.0, min(1.0, 1.0 - (long_lines / per_line) * 0.7 + min(comments / per_line, 0.2)))

        key_points: List[str] = [
            f"检测到约 {line_count} 行有效代码,{function_like} 个函数/类相关声明。",
            f"语义相关性得分 {semantic_alignment:.2f}(0-1 越高越贴合需求)。",
            f"控制流关键字出现 {control_flow} 次,复杂度评分 {complexity_score:.2f}。",
        ]

        risks: List[str] = []
        # A mismatch with the requirement is only meaningful when the caller
        # actually supplied one; otherwise the code was compared to a placeholder.
        if prompt and semantic_alignment < self._LOW_ALIGNMENT_THRESHOLD:
            risks.append("代码与需求语义相似度偏低,可能存在功能偏移。")
        if long_lines > 0:
            risks.append(f"存在 {long_lines} 行超长代码行,可读性和可维护性风险较高。")
        if comments == 0:
            risks.append("未检测到注释,后续维护和协作成本可能上升。")
        if complexity_score > self._HIGH_COMPLEXITY_THRESHOLD:
            risks.append("控制流较复杂,建议补充单元测试覆盖核心分支。")
        if not risks:
            risks.append("未发现明显高风险项,建议结合业务规则进行人工复核。")

        suggestions: List[str] = ["对关键逻辑分支补充单元测试,优先覆盖边界输入。"]
        # Only advise splitting long lines when some were actually found, so
        # the suggestions do not contradict the analysis above.
        if long_lines > 0:
            suggestions.append(f"将超过 {self._LONG_LINE_CHARS} 字符的长行拆分,提升可读性。")
        suggestions.append("为核心函数补充文档注释,标明输入、输出和异常行为。")

        if analysis_type == "risk":
            summary = (
                f"风险导向分析完成:复杂度 {complexity_score:.2f},可维护性 {maintainability:.2f},"
                f"语义相关性 {semantic_alignment:.2f}。"
            )
        elif analysis_type == "quality":
            summary = (
                f"质量导向分析完成:代码规模 {line_count} 行,复杂度 {complexity_score:.2f},"
                f"可维护性 {maintainability:.2f}。"
            )
        else:
            summary = (
                f"语义分析完成:代码与需求相关性 {semantic_alignment:.2f},"
                f"复杂度 {complexity_score:.2f},可维护性 {maintainability:.2f}。"
            )

        return {
            "modelStrategy": "unixcoder-hf-space",
            "enabled": True,
            "model": self.model_id,
            "status": "ok",
            "message": "analysis success",
            "analysisError": None,
            "summary": summary,
            "keyPoints": key_points,
            "risks": risks,
            "suggestions": suggestions,
            "scores": {
                "semanticAlignment": _safe_float(semantic_alignment),
                "complexity": _safe_float(complexity_score),
                "maintainability": _safe_float(maintainability),
                "lineCount": line_count,
                "charCount": char_count,
            },
            "meta": meta,
        }
# Module-level instance: the model is loaded once, at import time, and the
# same analyzer serves every call to analyze_for_ui.
analyzer = UniXcoderAnalyzer()


def analyze_for_ui(prompt: str, language: str, code: str, analysis_type: str):
    """Run the analyzer and return ``(result_dict, markdown_report)``.

    The dict is the analyzer's raw result, shown in the JSON panel. The
    Markdown report has the summary first, followed by Key Points, Risks and
    Suggestions sections. Each section is a bullet list, or ``- 无`` ("none")
    when the list is empty.
    """
    result = analyzer.analyze(prompt=prompt, language=language, code=code, analysis_type=analysis_type)

    def _bullets(key: str) -> str:
        # One "- item" line per entry of result[key]; "- 无" if there are none.
        return "\n".join(f"- {item}" for item in result.get(key, [])) or "- 无"

    report = [f"### 分析摘要\n{result.get('summary', '')}"]
    for heading, key in (
        ("Key Points", "keyPoints"),
        ("Risks", "risks"),
        ("Suggestions", "suggestions"),
    ):
        report.extend((f"### {heading}", _bullets(key)))
    return result, "\n".join(report)
# Gradio UI. Inside gr.Blocks, components are laid out in the order they are
# created, so the statement order below defines the page layout.
with gr.Blocks(title="UniXcoder Code Analyzer") as demo:
    gr.Markdown("# UniXcoder 代码理解与分析服务")
    # Intro text: code understanding, risk hints and quality suggestions,
    # usable from the page or through the Gradio API.
    gr.Markdown("用于代码语义理解、风险提示和质量建议。可通过页面交互,也可通过 Gradio API 调用。")
    # Language and analysis-type selectors share one row.
    with gr.Row():
        language = gr.Dropdown(
            choices=["java", "python", "javascript", "cpp", "go", "other"],
            value="java",
            label="Language",
        )
        # "risk" and "quality" select focused summaries in the analyzer;
        # "summary" (or any other value) gets the general summary.
        analysis_type = gr.Dropdown(
            choices=["summary", "risk", "quality"],
            value="summary",
            label="Analysis Type",
        )
    # Requirement description; may be left empty (the analyzer then embeds a
    # generic placeholder prompt).
    prompt = gr.Textbox(
        label="需求描述 (Prompt)",
        placeholder="例如:检查这段代码是否满足线程安全和异常处理要求",
        lines=3,
    )
    # The code to analyze; an empty value yields an EMPTY_CODE error result.
    code = gr.Textbox(
        label="待分析代码",
        placeholder="在这里粘贴代码...",
        lines=16,
    )
    run_btn = gr.Button("开始分析", variant="primary")
    # Raw result dict, intended for backend API integration.
    output_json = gr.JSON(label="结构化结果(用于后端API接入)")
    # Human-readable Markdown report.
    output_md = gr.Markdown(label="可读报告")
    # Inputs are passed positionally, so their order must match the
    # parameters of analyze_for_ui. api_name also exposes this handler as the
    # named "analyze" API endpoint.
    run_btn.click(
        fn=analyze_for_ui,
        inputs=[prompt, language, code, analysis_type],
        outputs=[output_json, output_md],
        api_name="analyze",
    )
if __name__ == "__main__":
    # Listen on all interfaces so the server is reachable from outside its
    # container (presumably the Spaces runtime), and turn server-side
    # rendering off explicitly.
    # NOTE(review): ssr_mode is a newer Gradio launch option; older Gradio
    # releases may reject it — confirm against the pinned Gradio version.
    launch_kwargs = {"server_name": "0.0.0.0", "ssr_mode": False}
    demo.launch(**launch_kwargs)