| """ |
| Agent 核心模块 - GAIA LangGraph ReAct Agent |
| 包含:AgentState, System Prompt, Graph 构建, 答案提取 |
| """ |
|
|
| import re |
| from typing import Sequence, Literal, Annotated, Optional |
|
|
| from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage |
| from langchain_openai import ChatOpenAI |
| from langgraph.graph import StateGraph, END |
| from langgraph.graph.message import add_messages |
| from langgraph.prebuilt import ToolNode |
|
|
| try: |
| from typing import TypedDict |
| except ImportError: |
| from typing_extensions import TypedDict |
|
|
| from config import ( |
| OPENAI_BASE_URL, |
| OPENAI_API_KEY, |
| MODEL, |
| TEMPERATURE, |
| MAX_ITERATIONS, |
| DEBUG, |
| LLM_TIMEOUT, |
| RATE_LIMIT_RETRY_MAX, |
| RATE_LIMIT_RETRY_BASE_DELAY, |
| ) |
|
|
| |
| from tools import BASE_TOOLS |
|
|
| |
| try: |
| from extension_tools import EXTENSION_TOOLS |
| ALL_TOOLS = BASE_TOOLS + EXTENSION_TOOLS |
| except ImportError as e: |
| print(f"⚠️ 扩展工具加载失败: {e}") |
| print(" 提示: 请确保安装了 pandas 和 openpyxl (pip install pandas openpyxl)") |
| EXTENSION_TOOLS = [] |
| ALL_TOOLS = BASE_TOOLS |
|
|
| |
| try: |
| from rag import RAG_TOOLS |
| ALL_TOOLS = ALL_TOOLS + RAG_TOOLS |
| except ImportError: |
| RAG_TOOLS = [] |
|
|
| |
| try: |
| from rag import rag_lookup_answer |
| except ImportError: |
| rag_lookup_answer = None |
|
|
| |
| _tool_names = [t.name for t in ALL_TOOLS] |
| if DEBUG: |
| print(f"✓ 已加载 {len(ALL_TOOLS)} 个工具: {_tool_names}") |
| if 'parse_excel' not in _tool_names: |
| print("⚠️ 警告: parse_excel 工具未加载,Excel 文件处理将不可用!") |
|
|
|
|
| |
| |
| |
|
|
| SYSTEM_PROMPT = """你是一个专业的问答助手,专门解答GAIA基准测试中的各类问题。你需要准确、简洁地回答问题。 |
| |
| ## 你的能力 |
| |
| 你可以使用以下工具来获取信息和处理任务: |
| |
| ### 知识库工具(RAG) |
| - `rag_query(question)`: 查询知识库中的相似问题,获取解题策略建议。返回推荐的工具和解题步骤。**遇到复杂问题时优先使用!** |
| - `rag_retrieve(question)`: 仅检索相似问题,不生成建议。返回原始的相似问题和解法。 |
| - `rag_stats()`: 查看知识库状态(文档数量等)。 |
| |
| ### 信息获取工具 |
| - `web_search(query)`: 使用DuckDuckGo搜索网络信息。适用于查找人物、事件、地点、组织等外部知识。 |
| - `wikipedia_search(query)`: 在维基百科中搜索,返回简短摘要(3句话)。适用于快速确认人物/事件的基本信息。 |
| - `wikipedia_page(title, section)`: 获取维基百科页面的完整内容。**需要详细数据(如专辑列表、获奖记录、作品年表)时必须用此工具!** |
| - `tavily_search(query)`: 使用Tavily进行高质量网络搜索,返回最多3条结果。需要API Key。 |
| - `arxiv_search(query)`: 在arXiv上搜索学术论文,返回最多3条结果。适用于查找科学研究和学术文献。 |
| |
| ### 文件处理工具 |
| - `fetch_task_files(task_id)`: 从评分服务器下载任务附件。当问题涉及附件时必须先调用此工具。 |
| - `read_file(file_path)`: 读取本地文件内容,支持txt/csv/json/zip等格式。**注意:不支持Excel和PDF!** |
| - `parse_pdf(file_path)`: 解析PDF文件,提取文本内容。**PDF文件必须用此工具!** |
| - `parse_excel(file_path)`: 解析Excel文件(.xlsx/.xls),返回表格内容。**Excel文件必须用此工具!** |
| - `image_ocr(file_path)`: 对图片进行OCR文字识别。 |
| - `transcribe_audio(file_path)`: 将音频文件转写为文字。 |
| - `analyze_image(file_path, question)`: 使用AI分析图片内容。 |
| |
| ### 计算和代码工具 |
| - `calc(expression)`: 执行安全的数学计算,如 "2+3*4" 或 "sqrt(16)"。适用于简单算术。 |
| - `run_python(code)`: 在沙箱中执行Python代码。支持 import math/re/json/datetime/collections/random/string/itertools/functools 模块。适用于复杂数据处理、排序、过滤、日期计算等操作。 |
| |
| ## 工具使用策略 |
| |
| ### 优先级顺序 |
| 0. **先查知识库**【最高优先级】: |
| - 首先调用 `rag_query(question)` 查询知识库 |
| - 如果返回"知识库匹配成功",**直接使用该答案作为最终回答**,不需要再调用其他工具 |
| - 如果返回"知识库参考",参考答案和步骤选择后续工具 |
| - 如果无匹配,按后续优先级使用其他工具 |
| 1. **有附件的问题**【重要】: |
| - 第一步:用 `fetch_task_files(task_id)` 下载文件 |
| - 第二步:根据文件扩展名选择正确的读取工具: |
| * `.xlsx` / `.xls` → 必须用 `parse_excel(file_path)` |
| * `.pdf` → 必须用 `parse_pdf(file_path)` |
| * `.txt` / `.csv` / `.json` / `.md` → 用 `read_file(file_path)` |
| * `.png` / `.jpg` / `.jpeg` → 用 `image_ocr(file_path)` 或 `analyze_image(file_path, question)` |
| * `.mp3` / `.wav` → 用 `transcribe_audio(file_path)` |
| - 第三步:分析文件内容,进行必要的计算或处理 |
| - **禁止**:下载文件后不要用 web_search 搜索,文件内容已经本地可用! |
| 2. **需要外部信息**: |
| - **百科知识查询流程**【重要】: |
| * 第一步:用 `wikipedia_search(query)` 确认页面标题 |
| * 第二步:如果需要详细数据(专辑列表、作品年表、获奖记录等),必须用 `wikipedia_page(title, section)` 获取完整内容 |
| * 示例:查 Mercedes Sosa 专辑数 → `wikipedia_search("Mercedes Sosa")` → `wikipedia_page("Mercedes Sosa", "Discography")` |
| - 通用搜索: 使用 `web_search` 搜索其他网络信息 |
| - 学术论文: 使用 `arxiv_search` 查找研究文献 |
| - 高质量结果: 使用 `tavily_search` (如果配置了API Key) |
| 3. **需要计算**: 简单算术用 `calc`,复杂处理用 `run_python` |
| 4. **数据处理**: 使用 `run_python` 进行排序、过滤、统计等操作 |
| |
| ### 工具使用原则 |
| - **只有问题明确提到"attached file"或"附件"时才调用 `fetch_task_files`**,否则不要调用 |
| - 每次只调用一个必要的工具,分析结果后再决定下一步 |
| - 如果工具返回错误,尝试调整参数或换用其他工具 |
| - 搜索时使用精确的关键词,避免过于宽泛 |
| - 读取大文件时注意内容可能被截断,关注关键信息 |
| - **如果 `wikipedia_search` 返回的摘要不足以回答问题,立即使用 `wikipedia_page` 获取完整内容** |
| |
| ## 思考过程 |
| |
| 在回答问题前,请按以下步骤思考: |
| 1. **理解问题**: 问题在问什么?需要什么类型的信息? |
| 2. **咨询知识库**: 如果问题复杂或不确定解法,用 `rag_query` 查看相似问题的解题策略 |
| 3. **判断工具**: 根据问题类型和 RAG 建议,选择合适的工具 |
| 4. **执行获取**: 调用工具获取信息 |
| 5. **分析整合**: 分析工具返回的信息,提取关键答案 |
| 6. **格式化输出**: 按要求格式输出最终答案 |
| |
| ## 答案格式要求【非常重要】 |
| |
| 最终答案必须遵循以下格式: |
| - **数字答案**: 直接输出数字,如 `42` 而不是 "答案是42" |
| - **人名/地名**: 直接输出名称,如 `Albert Einstein` 而不是 "答案是Albert Einstein" |
| - **日期答案**: 使用标准格式 `YYYY-MM-DD` 或按问题要求的格式 |
| - **列表答案**: 用逗号分隔,如 `A, B, C` |
| - **是/否答案**: 输出 `Yes` 或 `No` |
| |
| ⚠️ 最终回答时,只输出答案本身,不要包含: |
| - 不要说"答案是..."、"The answer is..." |
| - 不要添加解释或推理过程 |
| - 不要使用"最终答案:"等前缀 |
| |
| ## 错误恢复 |
| |
| 如果遇到问题: |
| - 工具调用失败: 检查参数,尝试简化或换用其他工具 |
| - 搜索无结果: 尝试不同的关键词组合 |
| - 文件读取失败: 确认文件路径正确,检查文件格式 |
| - 计算错误: 检查表达式语法,考虑使用Python代码 |
| |
| ## 示例 |
| |
| 问题: "Who was the first person to walk on the moon?" |
| 正确答案: Neil Armstrong |
| 错误答案: The answer is Neil Armstrong. |
| |
| 问题: "What is 15% of 200?" |
| 正确答案: 30 |
| 错误答案: 15% of 200 is 30. |
| |
| ### 文件处理示例【重要】 |
| |
| 问题: "[Task ID: abc123] The attached Excel file contains sales data. What is the total revenue?" |
| |
| ✅ 正确流程: |
| 1. fetch_task_files("abc123") → 下载文件到本地路径 |
| 2. parse_excel("/path/to/file.xlsx") → 读取Excel内容,得到表格数据 |
| 3. calc("100+200+300") 或 run_python("...") → 计算总收入 |
| 4. 输出最终答案 |
| |
| ❌ 错误流程: |
| 1. fetch_task_files("abc123") → 下载文件 |
| 2. web_search("sales data total revenue") → 错!文件内容在本地,不需要搜索网络! |
| |
| ### RAG 辅助示例 |
| |
| 问题: "How many studio albums did Mercedes Sosa release between 2000 and 2009?" |
| |
| ✅ 推荐流程: |
| 1. rag_query("How many studio albums did Mercedes Sosa release between 2000 and 2009?") → 获取建议:使用 wikipedia_page 查 Discography |
| 2. wikipedia_search("Mercedes Sosa") → 确认页面存在 |
| 3. wikipedia_page("Mercedes Sosa", "Discography") → 获取完整专辑列表 |
| 4. run_python("...") → 筛选 2000-2009 年的专辑并计数 |
| 5. 输出最终答案 |
| |
| RAG 的价值:直接告诉你该用 wikipedia_page 而不是 web_search,节省试错时间。 |
| |
| 现在请回答用户的问题。""" |
|
|
|
|
| |
| |
| |
|
|
| class AgentState(TypedDict): |
| """Agent 状态定义""" |
| |
| messages: Annotated[Sequence[BaseMessage], add_messages] |
|
|
| |
| iteration_count: int |
|
|
|
|
| |
| |
| |
|
|
| |
| _llm_instance = None |
| _llm_with_tools = None |
|
|
|
|
| def get_llm(): |
| """获取 LLM 单例""" |
| global _llm_instance |
| if _llm_instance is None: |
| _llm_instance = ChatOpenAI( |
| model=MODEL, |
| temperature=TEMPERATURE, |
| base_url=OPENAI_BASE_URL, |
| api_key=OPENAI_API_KEY, |
| timeout=LLM_TIMEOUT, |
| max_retries=2, |
| ) |
| return _llm_instance |
|
|
|
|
| def get_llm_with_tools(): |
| """获取绑定工具的 LLM 单例""" |
| global _llm_with_tools |
| if _llm_with_tools is None: |
| _llm_with_tools = get_llm().bind_tools(ALL_TOOLS) |
| return _llm_with_tools |
|
|
|
|
| def invoke_llm_with_retry(llm, messages, max_retries=None, base_delay=None): |
| """ |
| 带重试逻辑的 LLM 调用(处理 429 速率限制错误) |
| |
| Args: |
| llm: LLM 实例 |
| messages: 消息列表 |
| max_retries: 最大重试次数,默认使用配置值 |
| base_delay: 基础延迟秒数,默认使用配置值 |
| |
| Returns: |
| LLM 响应 |
| |
| Raises: |
| 原始异常(如果重试耗尽) |
| """ |
| import time |
| from openai import RateLimitError |
|
|
| if max_retries is None: |
| max_retries = RATE_LIMIT_RETRY_MAX |
| if base_delay is None: |
| base_delay = RATE_LIMIT_RETRY_BASE_DELAY |
|
|
| last_error = None |
|
|
| for attempt in range(max_retries + 1): |
| try: |
| return llm.invoke(messages) |
| except RateLimitError as e: |
| last_error = e |
| if attempt < max_retries: |
| |
| delay = base_delay * (2 ** attempt) |
| print(f"[Rate Limit] 429 错误,第 {attempt + 1}/{max_retries + 1} 次尝试,等待 {delay:.1f} 秒后重试...") |
| time.sleep(delay) |
| else: |
| print(f"[Rate Limit] 重试次数已耗尽 ({max_retries + 1} 次),抛出异常") |
| raise |
| except Exception as e: |
| |
| raise |
|
|
| |
| if last_error: |
| raise last_error |
|
|
|
|
| def create_llm(): |
| """创建 LLM 实例(保留兼容性)""" |
| return get_llm() |
|
|
|
|
| |
| |
| |
|
|
| def assistant(state: AgentState) -> dict: |
| """ |
| LLM 推理节点 |
| |
| 职责: |
| 1. 接收当前状态 |
| 2. 构建完整消息(包含 System Prompt) |
| 3. 调用 LLM 生成响应 |
| 4. 更新迭代计数 |
| """ |
| messages = state["messages"] |
| iteration = state.get("iteration_count", 0) + 1 |
|
|
| |
| full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) |
|
|
| |
| if iteration >= MAX_ITERATIONS - 1: |
| print(f"[Iteration {iteration}] FORCING FINAL ANSWER (no tools)") |
| warning = f""" |
| |
| ⚠️ 【最后机会】已进行 {iteration} 次迭代,达到上限 {MAX_ITERATIONS}。 |
| 你必须立即给出最终答案!不要再调用任何工具! |
| 直接根据已有信息输出答案。如果信息不足,给出最佳估计。 |
| """ |
| full_messages.append(SystemMessage(content=warning)) |
| |
| llm = get_llm() |
| try: |
| response = invoke_llm_with_retry(llm, full_messages) |
| except Exception as e: |
| print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") |
| raise |
| elif iteration >= MAX_ITERATIONS - 2: |
| warning = f"\n\n⚠️ 警告:已进行 {iteration} 次迭代,接近上限 {MAX_ITERATIONS},请尽快给出最终答案,不要再搜索。" |
| full_messages.append(SystemMessage(content=warning)) |
| |
| llm_with_tools = get_llm_with_tools() |
| try: |
| response = invoke_llm_with_retry(llm_with_tools, full_messages) |
| except Exception as e: |
| print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") |
| raise |
| else: |
| |
| llm_with_tools = get_llm_with_tools() |
| try: |
| response = invoke_llm_with_retry(llm_with_tools, full_messages) |
| except Exception as e: |
| print(f"[ERROR] LLM 调用失败: {type(e).__name__}: {str(e)}") |
| raise |
|
|
| |
| print(f"[Iteration {iteration}] LLM Response: {response.content[:200] if response.content else '(empty)'}...") |
| if hasattr(response, 'tool_calls') and response.tool_calls: |
| print(f"[Iteration {iteration}] Tool calls: {[tc['name'] for tc in response.tool_calls]}") |
|
|
| return { |
| "messages": [response], |
| "iteration_count": iteration |
| } |
|
|
|
|
| def should_continue(state: AgentState) -> Literal["tools", "end"]: |
| """ |
| 路由判断:决定继续使用工具还是结束 |
| |
| 判断逻辑: |
| 1. 达到迭代上限 → 强制结束 |
| 2. 有工具调用 → 继续执行工具 |
| 3. 无工具调用 → 返回答案,结束 |
| """ |
| last_message = state["messages"][-1] |
| iteration = state.get("iteration_count", 0) |
|
|
| |
| if iteration >= MAX_ITERATIONS: |
| print(f"[Router] Reached max iterations ({MAX_ITERATIONS}), forcing end") |
| return "end" |
|
|
| |
| if hasattr(last_message, "tool_calls") and last_message.tool_calls: |
| print(f"[Router] Has tool calls, continuing to tools") |
| return "tools" |
|
|
| |
| print(f"[Router] No tool calls, ending") |
| return "end" |
|
|
|
|
| |
| |
| |
|
|
| def build_agent_graph(): |
| """ |
| 构建 Agent Graph |
| |
| 流程: |
| START → assistant → [should_continue] → tools → assistant → ... → END |
| """ |
| graph = StateGraph(AgentState) |
|
|
| |
| graph.add_node("assistant", assistant) |
| graph.add_node("tools", ToolNode(ALL_TOOLS)) |
|
|
| |
| graph.set_entry_point("assistant") |
|
|
| |
| graph.add_conditional_edges( |
| "assistant", |
| should_continue, |
| {"tools": "tools", "end": END} |
| ) |
|
|
| |
| graph.add_edge("tools", "assistant") |
|
|
| return graph.compile() |
|
|
|
|
| |
| |
| |
|
|
| def extract_final_answer(result: dict) -> str: |
| """ |
| 从 Agent 结果中提取最终答案 |
| |
| 处理步骤: |
| 1. 获取最后一条消息 |
| 2. 移除常见前缀 |
| 3. 移除尾部解释 |
| 4. 提取 JSON 格式答案 |
| 5. 清理格式 |
| """ |
| messages = result.get("messages", []) |
| if not messages: |
| print("[extract_final_answer] No messages in result") |
| return "无法获取答案" |
|
|
| |
| content = None |
|
|
| |
| for msg in reversed(messages): |
| if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip(): |
| if not (hasattr(msg, "tool_calls") and msg.tool_calls): |
| content = msg.content |
| break |
|
|
| |
| if content is None: |
| for msg in reversed(messages): |
| if isinstance(msg, AIMessage) and msg.content and str(msg.content).strip(): |
| content = msg.content |
| break |
|
|
| |
| if content is None: |
| for msg in reversed(messages): |
| if hasattr(msg, "content") and msg.content and str(msg.content).strip(): |
| content = msg.content |
| break |
|
|
| print(f"[extract_final_answer] Raw content: {content[:500] if content else '(empty)'}...") |
|
|
| if not content: |
| print("[extract_final_answer] Empty content in all messages") |
| return "无法获取答案" |
|
|
| answer = content.strip() |
|
|
| |
| prefix_patterns = [ |
| |
| r'^(?:the\s+)?(?:final\s+)?answer\s*(?:is|:)\s*', |
| r'^(?:the\s+)?result\s*(?:is|:)\s*', |
| r'^(?:therefore|thus|so|hence)[,:]?\s*', |
| r'^based\s+on\s+(?:the|my)\s+(?:analysis|research|calculations?)[,:]?\s*', |
| r'^after\s+(?:analyzing|reviewing|checking)[^,]*[,:]?\s*', |
| r'^according\s+to\s+[^,]*[,:]?\s*', |
| |
| r'^(?:最终)?答案[是为::]\s*', |
| r'^(?:结果|结论)[是为::]\s*', |
| r'^(?:因此|所以|综上)[,,::]?\s*', |
| r'^根据(?:以上)?(?:分析|信息|计算)[,,::]?\s*', |
| r'^经过(?:分析|计算|查询)[,,::]?\s*', |
| ] |
|
|
| for pattern in prefix_patterns: |
| answer = re.sub(pattern, '', answer, flags=re.IGNORECASE) |
|
|
| |
| suffix_patterns = [ |
| r'\s*(?:This|That|The|It)\s+(?:is|was|represents|refers\s+to).*$', |
| r'\s*[(\(].*[)\)]$', |
| r'\s*[。\.]$', |
| r'\s*\n\n.*$', |
| ] |
|
|
| for pattern in suffix_patterns: |
| answer = re.sub(pattern, '', answer, flags=re.IGNORECASE | re.DOTALL) |
|
|
| |
| json_patterns = [ |
| r'\{["\']?(?:final_?)?answer["\']?\s*:\s*["\']?([^"\'}\n]+)["\']?\}', |
| r'"answer"\s*:\s*"([^"]+)"', |
| ] |
| for pattern in json_patterns: |
| json_match = re.search(pattern, answer, re.IGNORECASE) |
| if json_match: |
| answer = json_match.group(1) |
| break |
|
|
| |
| answer = answer.strip() |
| answer = re.sub(r'\s+', ' ', answer) |
| answer = answer.strip('"\'') |
|
|
| |
| if re.match(r'^[\d,\.]+$', answer): |
| answer = answer.replace(',', '') |
|
|
| return answer |
|
|
|
|
| def post_process_answer(answer: str, expected_type: str = None) -> str: |
| """ |
| 根据预期类型后处理答案 |
| |
| Args: |
| answer: 原始答案 |
| expected_type: 预期类型 (number, date, boolean, list) |
| |
| Returns: |
| 处理后的答案 |
| """ |
| if expected_type == "number": |
| match = re.search(r'-?\d+\.?\d*', answer.replace(',', '')) |
| if match: |
| return match.group() |
|
|
| elif expected_type == "date": |
| |
| date_patterns = [ |
| (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"), |
| (r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(1)):02d}-{int(m.group(2)):02d}"), |
| ] |
| for pattern, formatter in date_patterns: |
| match = re.search(pattern, answer) |
| if match: |
| return formatter(match) |
|
|
| elif expected_type == "boolean": |
| lower = answer.lower().strip() |
| if lower in ['yes', 'true', '是', '对', 'correct']: |
| return "Yes" |
| elif lower in ['no', 'false', '否', '不', '错', 'incorrect']: |
| return "No" |
|
|
| elif expected_type == "list": |
| answer = re.sub(r'\s*[;;、]\s*', ', ', answer) |
|
|
| return answer |
|
|
|
|
| |
| |
| |
|
|
| class GaiaAgent: |
| """ |
| GAIA Agent 入口类 |
| |
| 使用方法: |
| agent = GaiaAgent() |
| answer = agent("Who founded Microsoft?") |
| """ |
|
|
| def __init__(self): |
| """初始化 Agent""" |
| self.graph = build_agent_graph() |
|
|
| def _needs_reformatting(self, answer: str) -> bool: |
| """检查答案是否需要重新格式化""" |
| if not answer or answer == "无法获取答案": |
| return False |
| indicators = [ |
| answer.startswith('http'), |
| 'URL:' in answer, |
| len(answer) > 300, |
| answer.count('\n') > 3, |
| answer.startswith('1.') and '2.' in answer, |
| answer.startswith('- '), |
| '...' in answer and len(answer) > 100, |
| ] |
| return any(indicators) |
|
|
| def _force_format_answer(self, result: dict) -> str: |
| """强制格式化答案""" |
| messages = result.get("messages", []) |
| format_prompt = ( |
| "根据上述对话收集的信息,输出最终答案。\n\n" |
| "【强制要求】只输出答案本身,不要解释、不要前缀。\n" |
| "- 数字:直接输出(如 42)\n" |
| "- 人名/地名:直接输出(如 Albert Einstein)\n" |
| "- 日期:YYYY-MM-DD\n" |
| "- 是/否:Yes 或 No\n\n" |
| "最终答案:" |
| ) |
| full_messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) |
| full_messages.append(HumanMessage(content=format_prompt)) |
| llm = get_llm() |
| try: |
| print("[Reformat] Forcing answer formatting...") |
| response = invoke_llm_with_retry(llm, full_messages) |
| formatted = extract_final_answer({"messages": [response]}) |
| print(f"[Reformat] Result: {formatted[:100]}...") |
| return formatted |
| except Exception as e: |
| print(f"[Reformat] Error: {e}") |
| return "无法获取答案" |
|
|
| def __call__(self, question: str, task_id: str = None) -> str: |
| """ |
| 执行问答 |
| |
| Args: |
| question: 用户问题 |
| task_id: 任务 ID(可选,用于下载附件) |
| |
| Returns: |
| 最终答案 |
| """ |
| |
| if task_id: |
| question_with_id = f"[Task ID: {task_id}]\n\n{question}" |
| else: |
| question_with_id = question |
|
|
| |
| try: |
| if rag_lookup_answer is not None: |
| hit = rag_lookup_answer(question, min_similarity=0.85) |
| if hit and hit.get("answer"): |
| print(f"[GaiaAgent] RAG short-circuit hit: similarity={hit.get('similarity', 0):.2f}") |
| if DEBUG: |
| print(f"[Final Answer] {hit['answer']}") |
| return str(hit["answer"]).strip() |
| except Exception as e: |
| if DEBUG: |
| print(f"[GaiaAgent] RAG short-circuit failed: {type(e).__name__}: {e}") |
| |
|
|
| |
| initial_state = { |
| "messages": [HumanMessage(content=question_with_id)], |
| "iteration_count": 0 |
| } |
|
|
| try: |
| |
| result = self.graph.invoke(initial_state) |
|
|
| |
| answer = extract_final_answer(result) |
|
|
| |
| if self._needs_reformatting(answer): |
| print(f"[GaiaAgent] Answer needs reformatting: {answer[:50]}...") |
| answer = self._force_format_answer(result) |
|
|
| if DEBUG: |
| print(f"[Final Answer] {answer}") |
|
|
| return answer if answer else "无法获取答案" |
|
|
| except Exception as e: |
| import traceback |
| error_msg = f"Agent 执行出错: {type(e).__name__}: {str(e)}" |
| print(f"[ERROR] {error_msg}") |
| print(traceback.format_exc()) |
| return error_msg |
|
|
| def run_with_history(self, messages: list) -> dict: |
| """ |
| 带历史消息执行 |
| |
| Args: |
| messages: 消息历史列表 |
| |
| Returns: |
| 完整结果字典 |
| """ |
| initial_state = { |
| "messages": messages, |
| "iteration_count": 0 |
| } |
|
|
| return self.graph.invoke(initial_state) |
|
|
|
|
| |
| |
| |
|
|
| def run_agent(question: str, task_id: str = None) -> str: |
| """ |
| 运行 Agent 的便捷函数 |
| |
| Args: |
| question: 用户问题 |
| task_id: 任务 ID(可选) |
| |
| Returns: |
| 最终答案 |
| """ |
| agent = GaiaAgent() |
| return agent(question, task_id) |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| |
| agent = GaiaAgent() |
|
|
| |
| print("Test 1: Calculation") |
| answer = agent("What is 15% of 200?") |
| print(f"Answer: {answer}\n") |
|
|
| |
| print("Test 2: Search") |
| answer = agent("Who founded Microsoft?") |
| print(f"Answer: {answer}\n") |
|
|