File size: 1,867 Bytes
8a17806 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | import os
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from enum import Enum
class QueryType(str, Enum):
MATCH = "MATCH"
CREATE = "CREATE"
MERGE = "MERGE"
DELETE = "DELETE"
SET = "SET"
REMOVE = "REMOVE"
class NL2CypherRequest(BaseModel):
natural_language_query: str = Field(
description="自然语言描述的需求",
examples=["查找'心血管和血栓栓塞综合征'建议服用什么药物?"]
)
query_type: Optional[QueryType] = Field(
default=None,
description="指定查询类型,如果不指定则由模型推断"
)
limit: Optional[int] = Field(
default=10,
description="结果限制数量",
ge=1,
le=1000
)
class CypherResponse(BaseModel):
cypher_query: str = Field(
...,
description="生成的Cypher查询语句"
)
explanation: str = Field(
...,
description="对生成的Cypher查询的解释"
)
confidence: float = Field(
...,
description="模型对生成查询的信心度(0-1)",
ge=0,
le=1
)
validated: bool = Field(
default=False,
description="查询是否通过验证"
)
validation_errors: List[str] = Field(
default_factory=list,
description="验证过程中发现的错误"
)
class ValidationRequest(BaseModel):
cypher_query: str = Field(
...,
description="需要验证的Cypher查询"
)
class ValidationResponse(BaseModel):
is_valid: bool = Field(
...,
description="查询是否有效"
)
errors: List[str] = Field(
default_factory=list,
description="发现的错误列表"
)
suggestions: List[str] = Field(
default_factory=list,
description="改进建议"
) |