agentV2 / GraphDatabase /validators.py
shenli
Add GraphDatabase module with Neo4j + Redis caching
8a17806
import re
from typing import List, Tuple
from neo4j import GraphDatabase
import os
class CypherValidator:
def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str):
self.driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
def validate_syntax(self, cypher_query: str) -> Tuple[bool, List[str]]:
"""验证Cypher查询的语法"""
errors = []
# 基本语法检查
if not cypher_query.strip().upper().startswith(('MATCH', 'CREATE', 'MERGE', 'CALL')):
errors.append("查询必须以MATCH, CREATE, MERGE 或 CALL开头!!!")
# 检查是否有潜在的注入风险
if any(keyword in cypher_query.upper() for keyword in ['DROP', 'DELETE', 'DETACH', 'REMOVE']):
if not any(keyword in cypher_query.upper() for keyword in ['DELETE', 'DETACH']):
errors.append("查询包含可能危险的操作符")
# 检查RETURN语句是否存在 (对于MATCH查询)
if cypher_query.upper().startswith('MATCH') and 'RETURN' not in cypher_query.upper():
errors.append("MATCH查询必须包含RETURN语句!!!")
# 使用Neo4j解释计划验证查询
try:
with self.driver.session() as session:
result = session.run(f"EXPLAIN {cypher_query}")
# 如果解释成功, 语法基本正确
return True, errors
except Exception as e:
errors.append(f"语法错误: {str(e)}")
return False, errors
def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
"""根据模式验证查询"""
errors = []
# 提取所有节点标签
node_labels = [node.label for node in schema.nodes]
node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
matches = re.findall(node_pattern, cypher_query)
for match in matches:
if match[1] and match[1] not in node_labels:
errors.append(f"使用了不存在的节点标签: {match[1]}")
# 提取所有关系类型
rel_types = [rel.type for rel in schema.relationships]
rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
rel_matches = re.findall(rel_pattern, cypher_query)
for match in rel_matches:
if match[1] and match[1] not in rel_types:
errors.append(f"使用了不存在的关系类型: {match[1]}")
return len(errors) == 0, errors
def close(self):
self.driver.close()
# 简单的基于规则的验证器 (当无法连接Neo4j时使用)
class RuleBasedValidator:
def validate(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
errors = []
# 检查基本结构
if not cypher_query.strip():
errors.append("查询不能为空!!!")
return False, errors
# 检查是否包含潜在危险操作
dangerous_patterns = [
(r'(?i)drop\s+', "DROP操作可能危险"),
(r'(?i)delete\s+', "DELETE操作需要谨慎"),
(r'(?i)detach\s+delete', "DETACH DELETE操作非常危险!!"),
(r'(?i)remove\s+', "REMOVE操作需要谨慎"),
]
for pattern, message in dangerous_patterns:
if re.search(pattern, cypher_query):
errors.append(message)
# 检查MATCH查询是否包含RETURN
if re.match(r'(?i)match', cypher_query) and not re.search(r'(?i)return', cypher_query):
errors.append("MATCH查询必须包含RETURN子句")
# 检查CREATE查询是否合理
if re.match(r'(?i)create', cypher_query) and not re.search(r'(?i)(node|relationship|label|index)', cypher_query):
errors.append("CREATE查询应该明确创建节点或关系")
return len(errors) == 0, errors
def validate_against_schema(self, cypher_query: str, schema) -> Tuple[bool, List[str]]:
"""兼容CypherValidator的接口, 先做规则验证再做schema验证"""
is_valid, errors = self.validate(cypher_query, schema)
# 额外进行schema验证
node_labels = [node.label for node in schema.nodes]
node_pattern = r'\(([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\)'
matches = re.findall(node_pattern, cypher_query)
for match in matches:
if match[1] and match[1] not in node_labels:
errors.append(f"使用了不存在的节点标签: {match[1]}")
rel_types = [rel.type for rel in schema.relationships]
rel_pattern = r'\[([a-zA-Z0-9_]+)?:?([a-zA-Z0-9_]+)\]'
rel_matches = re.findall(rel_pattern, cypher_query)
for match in rel_matches:
if match[1] and match[1] not in rel_types:
errors.append(f"使用了不存在的关系类型: {match[1]}")
return len(errors) == 0, errors