File size: 6,086 Bytes
8a17806 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | import os
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from contextlib import asynccontextmanager
from openai import OpenAI
from dotenv import load_dotenv
from models import NL2CypherRequest, CypherResponse, ValidationRequest, ValidationResponse
from schemas import EXAMPLE_SCHEMA
from prompts import create_system_prompt, create_validation_prompt
from validators import CypherValidator, RuleBasedValidator
# Load environment variables (python-dotenv) before any of them are read.
load_dotenv()
# OpenAI API key; None when OPENAI_API_KEY is not set in the environment.
openai_api_key = os.environ.get("OPENAI_API_KEY")
# Application lifespan management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up the query validator on startup and release it on shutdown.

    When NEO4J_URI, NEO4J_USER and NEO4J_PASSWORD are all set, a
    ``CypherValidator`` is built from them; otherwise a ``RuleBasedValidator``
    is used. The chosen validator is stored on ``app.state.validator``.

    Args:
        app: The FastAPI application whose state receives the validator.

    Yields:
        None. Control returns to the framework while the application runs.
    """
    # Startup: pick the validator based on the available Neo4j settings.
    neo4j_uri = os.getenv("NEO4J_URI")
    neo4j_user = os.getenv("NEO4J_USER")
    neo4j_password = os.getenv("NEO4J_PASSWORD")
    if all([neo4j_uri, neo4j_user, neo4j_password]):
        app.state.validator = CypherValidator(neo4j_uri, neo4j_user, neo4j_password)
    else:
        app.state.validator = RuleBasedValidator()
    try:
        yield
    finally:
        # Shutdown: an exception or cancellation is re-raised at the `yield`,
        # so cleanup must live in `finally` to run in those cases as well.
        close = getattr(app.state.validator, "close", None)
        if callable(close):
            close()
# Create the FastAPI application; `lifespan` is registered as its lifespan handler.
app = FastAPI(title="NL2Cypher API", lifespan=lifespan)
# Initialise the OpenAI client used by the functions in this module.
client = OpenAI(
    api_key=openai_api_key,  # OpenAI API key read from OPENAI_API_KEY
    base_url="https://api.openai.com/v1",  # OpenAI API endpoint
)
# Add the CORS middleware.
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled — confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def clean_cypher_output(raw_output: str) -> str:
    """Strip wrapper text that an LLM may put around a Cypher query.

    Removes markdown code fences (with or without a ``cypher`` tag, in any
    letter case), stray leading/trailing backticks, a leading ``Cypher:``
    label, and one pair of matching quotes wrapping the whole query.

    Quotes are removed only as a balanced pair around the entire text, so a
    quote that belongs to the query itself (for example a query ending in
    ``"Alice"``) is never cut off.

    Args:
        raw_output: The raw text returned by the model. ``None`` or an empty
            string yields an empty string.

    Returns:
        The cleaned query with no leading or trailing whitespace.
    """
    import re

    def strip_wrapping_quotes(value: str) -> str:
        # Only a matching pair at both ends counts as wrapper quotes.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            return value[1:-1].strip()
        return value

    text = (raw_output or "").strip()
    # Drop markdown fences: ```cypher ... ``` or ``` ... ```
    text = re.sub(r'```(?:cypher)?\s*', '', text, flags=re.IGNORECASE)
    # Removing the closing fence can expose trailing whitespace; strip again.
    text = text.strip('`').strip()
    # Unwrap quotes first so that a quoted '"Cypher: ..."' also loses its label.
    text = strip_wrapping_quotes(text)
    label = re.match(r'[Cc]ypher:\s*', text)
    if label:
        text = strip_wrapping_quotes(text[label.end():].strip())
    return text
def generate_cypher_query(natural_language: str, query_type: str = None) -> str:
    """Generate a Cypher query from natural language using OpenAI.

    The system prompt is built from the string form of ``EXAMPLE_SCHEMA``'s
    dump. When ``query_type`` is given it is prepended to the user prompt.

    Args:
        natural_language: The user's request in natural language.
        query_type: Optional query-type label placed in front of the request.

    Returns:
        The model output after ``clean_cypher_output`` has removed wrappers.

    Raises:
        HTTPException: With status 500 when the OpenAI call fails or when the
            model returns no usable content.
    """
    system_prompt = create_system_prompt(str(EXAMPLE_SCHEMA.model_dump()))
    if query_type:
        user_prompt = f"{query_type}查询: {natural_language}"
    else:
        user_prompt = natural_language

    # Keep the try body limited to the remote call so that only real API
    # failures are reported as API errors.
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            max_tokens=2048,
            stream=False,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OpenAI API错误: {str(e)}") from e

    # The message content may be missing; treat that (and an output that is
    # empty after cleaning) as an explicit error instead of passing an empty
    # query downstream.
    choices = response.choices
    content = choices[0].message.content if choices else None
    cypher_query = clean_cypher_output(content) if content else ""
    if not cypher_query:
        raise HTTPException(status_code=500, detail="OpenAI API错误: 模型返回内容为空")
    return cypher_query
def explain_cypher_query(cypher_query: str) -> str:
    """Ask OpenAI for a plain-language explanation of a Cypher query.

    This is best-effort: failures never raise, they are reported inside the
    returned text instead.

    Args:
        cypher_query: The Cypher query to explain.

    Returns:
        The explanation text, or a message starting with ``无法生成解释`` when
        the call fails or the model returns no content.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "你是一个Neo4j专家, 请用简单明了的语言解释Cypher查询."},
                {"role": "user", "content": f"请解释以下Cypher查询: {cypher_query}"},
            ],
            temperature=0.1,
            max_tokens=1024,
            stream=False,
        )
        content = response.choices[0].message.content
    except Exception as e:
        return f"无法生成解释: {str(e)}"

    # Content may be missing; report that clearly rather than leaking an
    # internal attribute error into the user-facing explanation.
    explanation = (content or "").strip()
    if not explanation:
        return "无法生成解释: 模型返回内容为空"
    return explanation
@app.post("/generate", response_model=CypherResponse)
async def generate_cypher(request: NL2CypherRequest):
    """Generate, explain and validate a Cypher query for a request.

    Args:
        request: Carries the natural-language query and an optional query type.

    Returns:
        A ``CypherResponse`` with the query, its explanation, a confidence
        score, the validation flag and any validation errors.
    """
    # Imported locally so this handler does not depend on the file's import header.
    from fastapi.concurrency import run_in_threadpool

    query_type = request.query_type.value if request.query_type else None

    # The OpenAI helpers are synchronous and block on network I/O; running
    # them in the threadpool keeps the event loop free for other requests.
    cypher_query = await run_in_threadpool(
        generate_cypher_query,
        request.natural_language_query,
        query_type,
    )
    explanation = await run_in_threadpool(explain_cypher_query, cypher_query)

    # Validate the query against the example schema.
    # NOTE(review): the validator is still called on the event loop; if it
    # performs blocking I/O, consider offloading it too — confirm it is safe
    # to call from a worker thread first.
    is_valid, errors = app.state.validator.validate_against_schema(cypher_query, EXAMPLE_SCHEMA)

    # Start from a base confidence of 0.9 and subtract 0.1 per reported
    # error, never dropping below 0.3.
    confidence = 0.9
    if errors:
        confidence = max(0.3, confidence - len(errors) * 0.1)

    return CypherResponse(
        cypher_query=cypher_query,
        explanation=explanation,
        confidence=confidence,
        validated=is_valid,
        validation_errors=errors,
    )
@app.post("/validate", response_model=ValidationResponse)
async def validate_cypher(request: ValidationRequest):
    """Validate a Cypher query and, when it has errors, suggest improvements.

    Args:
        request: Carries the Cypher query to validate.

    Returns:
        A ``ValidationResponse`` with the validation flag, the errors, and a
        list of suggestions (empty when there are no errors).
    """
    # Imported locally so this handler does not depend on the file's import header.
    import logging
    from fastapi.concurrency import run_in_threadpool

    fallback_suggestion = "无法生成建议"

    is_valid, errors = app.state.validator.validate_against_schema(request.cypher_query, EXAMPLE_SCHEMA)

    # Suggestions are only requested when validation reported errors.
    suggestions = []
    if errors:
        try:
            # The OpenAI client call is synchronous; run it in the threadpool
            # so the event loop is not blocked while waiting for the reply.
            response = await run_in_threadpool(
                client.chat.completions.create,
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "你是一个Neo4j专家, 请提供Cypher查询的改进建议."},
                    {"role": "user", "content": create_validation_prompt(request.cypher_query)},
                ],
                temperature=0.1,
                max_tokens=1024,
                stream=False,
            )
            content = (response.choices[0].message.content or "").strip()
            suggestions = [content] if content else [fallback_suggestion]
        except Exception:
            # Best-effort: keep serving the validation result, but record the
            # failure. `Exception` (not a bare except) lets task cancellation
            # and interpreter-exit signals propagate.
            logging.getLogger(__name__).exception("Failed to generate Cypher improvement suggestions")
            suggestions = [fallback_suggestion]

    return ValidationResponse(
        is_valid=is_valid,
        errors=errors,
        suggestions=suggestions,
    )
@app.get("/schema")
async def get_schema():
    """Return the example graph schema as dumped by ``model_dump()``."""
    schema_dump = EXAMPLE_SCHEMA.model_dump()
    return schema_dump
if __name__ == "__main__":
    # The project's main Agent service is started on port 8103, so this
    # Neo4j service simply uses a different port: 8101.
    bind_host, bind_port = "0.0.0.0", 8101
    uvicorn.run(app, host=bind_host, port=bind_port)
|