shenli
Add GraphDatabase module with Neo4j + Redis caching
8a17806
import logging
import os
from contextlib import asynccontextmanager

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI

from models import NL2CypherRequest, CypherResponse, ValidationRequest, ValidationResponse
from schemas import EXAMPLE_SCHEMA
from prompts import create_system_prompt, create_validation_prompt
from validators import CypherValidator, RuleBasedValidator
# Populate the process environment (e.g. from a local .env file) before any settings are read.
load_dotenv()
# OpenAI API key taken from OPENAI_API_KEY; None when the variable is absent.
openai_api_key = os.environ.get("OPENAI_API_KEY")
# Lifespan management: build the validator at startup, release it at shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create ``app.state.validator`` at startup and close it at shutdown.

    ``CypherValidator`` (constructed with the Neo4j credentials) is used when
    ``NEO4J_URI``, ``NEO4J_USER`` and ``NEO4J_PASSWORD`` are all set to
    non-empty values. Otherwise ``RuleBasedValidator`` is used.
    """
    # Startup: choose the validator implementation from the environment.
    neo4j_uri = os.getenv("NEO4J_URI")
    neo4j_user = os.getenv("NEO4J_USER")
    neo4j_password = os.getenv("NEO4J_PASSWORD")
    if all([neo4j_uri, neo4j_user, neo4j_password]):
        app.state.validator = CypherValidator(neo4j_uri, neo4j_user, neo4j_password)
    else:
        app.state.validator = RuleBasedValidator()
    try:
        yield
    finally:
        # Shutdown. asynccontextmanager re-raises an exception from the
        # ``async with`` body at the ``yield``. The finally guarantees the
        # cleanup below still runs in that case instead of being skipped.
        # Only validators that expose close() need releasing.
        if hasattr(app.state.validator, 'close'):
            app.state.validator.close()
# Create the FastAPI application.
app = FastAPI(title="NL2Cypher API", lifespan=lifespan)
# Initialise the OpenAI client.
# The endpoint can be overridden via OPENAI_BASE_URL, the same variable the
# OpenAI SDK itself reads. Passing a hard-coded base_url would silently ignore
# it. This allows routing through a proxy or an OpenAI-compatible gateway.
# The default is the public OpenAI API.
client = OpenAI(
    api_key=openai_api_key,  # API key loaded from OPENAI_API_KEY
    base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
)
# Add the CORS middleware.
# NOTE(review): with a wildcard origin plus allow_credentials=True, Starlette
# echoes the caller's Origin header back. Effectively any site can then make
# credentialed requests. Restrict allow_origins before exposing this service
# beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def clean_cypher_output(raw_output: str) -> str:
    """Strip the wrapper text an LLM commonly puts around a Cypher query.

    Removes, in order:

    * markdown code fences (```` ``` ```` or ```` ```cypher ````, any case);
    * a leading ``Cypher:`` label (any case);
    * one pair of matching outer quotes (``"..."`` or ``'...'``).

    Quotes are only removed as a matched pair around the whole text. A query
    that merely ends in a string literal, such as ``RETURN 'Alice'``, keeps
    its closing quote.

    Args:
        raw_output: Raw text returned by the model.

    Returns:
        The cleaned query, with surrounding whitespace removed.
    """
    import re

    text = raw_output.strip()
    # Remove markdown fences: ```cypher ... ``` or ``` ... ```
    text = re.sub(r'```(?:cypher)?\s*', '', text, flags=re.IGNORECASE)
    # Drop stray backticks, then any whitespace that sat between the query
    # and a closing fence that was just removed.
    text = text.strip('`').strip()
    # Drop a "Cypher:" label. Quotes are deliberately not consumed here: an
    # optional trailing-quote group would also eat the closing quote of a
    # string literal at the end of the query. Paired quotes are handled below.
    match = re.match(r'cypher:\s*(.*)', text, re.DOTALL | re.IGNORECASE)
    if match:
        text = match.group(1).strip()
    # Remove one pair of matching outer quotes.
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ('"', "'"):
        text = text[1:-1].strip()
    return text
def generate_cypher_query(natural_language: str, query_type: str = None) -> str:
    """Translate a natural-language question into a Cypher query using OpenAI.

    Args:
        natural_language: The user's question.
        query_type: Optional query category. When given, it is prefixed to
            the user prompt as ``"<query_type>查询: <question>"``.

    Returns:
        The model's answer, cleaned by ``clean_cypher_output``.

    Raises:
        HTTPException: status 500 if the OpenAI call fails or the model
            returns no message content.
    """
    system_prompt = create_system_prompt(str(EXAMPLE_SCHEMA.model_dump()))
    user_prompt = f"{query_type}查询: {natural_language}" if query_type else natural_language
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0.1,
            max_tokens=2048,
            stream=False
        )
        raw_output = response.choices[0].message.content
    except Exception as e:
        # Chain the original error so server-side tracebacks keep the root cause.
        raise HTTPException(status_code=500, detail=f"OpenAI API错误: {str(e)}") from e
    # The SDK types message.content as optional. Report a missing answer
    # explicitly rather than surfacing an AttributeError from .strip().
    if raw_output is None:
        raise HTTPException(status_code=500, detail="OpenAI API错误: 模型未返回内容")
    # clean_cypher_output strips surrounding whitespace itself.
    return clean_cypher_output(raw_output)
def explain_cypher_query(cypher_query: str) -> str:
    """Ask the model for a plain-language explanation of *cypher_query*.

    Best-effort: on any exception, a message containing the error text is
    returned in place of the explanation; nothing is raised.
    """
    try:
        chat_messages = [
            {"role": "system", "content": "你是一个Neo4j专家, 请用简单明了的语言解释Cypher查询."},
            {"role": "user", "content": f"请解释以下Cypher查询: {cypher_query}"},
        ]
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=chat_messages,
            temperature=0.1,
            max_tokens=1024,
            stream=False,
        )
        explanation = completion.choices[0].message.content
        return explanation.strip()
    except Exception as exc:
        return f"无法生成解释: {str(exc)}"
@app.post("/generate", response_model=CypherResponse)
async def generate_cypher(request: NL2CypherRequest):
    """Generate a Cypher query from natural language, explain it, and validate it.

    ``generate_cypher_query`` and ``explain_cypher_query`` use the synchronous
    OpenAI client. Calling them directly inside this ``async`` endpoint would
    block the event loop, and with it every other request, for the whole
    model round-trip. Both calls are therefore dispatched to the thread pool.

    Raises:
        HTTPException: propagated from ``generate_cypher_query`` when the
            OpenAI call fails.
    """
    query_type = request.query_type.value if request.query_type else None

    # Generate the Cypher query with OpenAI (off the event loop).
    cypher_query = await run_in_threadpool(
        generate_cypher_query, request.natural_language_query, query_type
    )
    # Generate the explanation with OpenAI (off the event loop).
    explanation = await run_in_threadpool(explain_cypher_query, cypher_query)

    # Validate the query.
    # NOTE(review): this call still runs on the event loop. It may do blocking
    # I/O when the Neo4j-configured CypherValidator is active. Consider
    # offloading it too once the validator's thread-safety is confirmed.
    is_valid, errors = app.state.validator.validate_against_schema(cypher_query, EXAMPLE_SCHEMA)

    # Base confidence is 0.9. Each validation error lowers it by 0.1, with a
    # floor of 0.3. Rounding removes float noise such as 0.5999999999999999.
    confidence = 0.9
    if errors:
        confidence = max(0.3, confidence - len(errors) * 0.1)
    return CypherResponse(
        cypher_query=cypher_query,
        explanation=explanation,
        confidence=round(confidence, 2),
        validated=is_valid,
        validation_errors=errors
    )
def _suggest_cypher_improvements(cypher_query: str) -> list:
    """Ask the model for improvement suggestions for *cypher_query*.

    Best-effort: any failure is logged, and a fixed placeholder is returned
    instead so the /validate endpoint still returns its validation result.

    Returns:
        A one-element list holding either the model's suggestion text or the
        placeholder.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "你是一个Neo4j专家, 请提供Cypher查询的改进建议."},
                {"role": "user", "content": create_validation_prompt(cypher_query)}
            ],
            temperature=0.1,
            max_tokens=1024,
            stream=False
        )
        return [response.choices[0].message.content.strip()]
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt and SystemExit
        # still propagate. Logged so failures are no longer silent.
        logging.getLogger(__name__).exception("Failed to generate Cypher improvement suggestions")
        return ["无法生成建议"]


@app.post("/validate", response_model=ValidationResponse)
async def validate_cypher(request: ValidationRequest):
    """Validate a Cypher query against the schema.

    When validation reports errors, also ask the model for improvement
    suggestions. That synchronous OpenAI call runs in the thread pool so it
    does not block the event loop.
    """
    is_valid, errors = app.state.validator.validate_against_schema(request.cypher_query, EXAMPLE_SCHEMA)
    # Suggestions are only generated when there is something to fix.
    suggestions = []
    if errors:
        suggestions = await run_in_threadpool(_suggest_cypher_improvements, request.cypher_query)
    return ValidationResponse(
        is_valid=is_valid,
        errors=errors,
        suggestions=suggestions
    )
@app.get("/schema")
async def get_schema():
    """Return the example graph schema as a plain dict (via ``model_dump``)."""
    schema = EXAMPLE_SCHEMA.model_dump()
    return schema
if __name__ == "__main__":
    # The project's main Agent service already listens on port 8103, so this
    # Neo4j / NL2Cypher service runs on a separate port, 8101.
    uvicorn.run(app, port=8101, host="0.0.0.0")