| |
| |
|
|
"""
SSE Tool Handler - processes the SSE stream of tool calls.
Handles tool calls natively on top of Z.AI's edit_index / edit_content mechanism.
"""
|
|
| import json |
| import re |
| import time |
| from typing import Dict, Any, Optional, Generator, List |
|
|
| from app.utils.helpers import debug_log |
|
|
|
|
class SSEToolHandler:
    """Convert Z.AI tool-call SSE events into OpenAI-style ``chat.completion.chunk`` lines.

    Z.AI delivers tool calls as a series of edits (``edit_index`` +
    ``edit_content``) against a growing content buffer.  Each edit is applied
    to ``content_buffer``; the whole buffer is then re-scanned for
    ``<glm_block>`` elements whose JSON payload carries the tool id, name and
    arguments.  Both public ``process_*`` methods are generators that yield
    ready-to-send ``data: ...`` strings.
    """

    # A <glm_block> element.  Attributes on the opening tag are tolerated, and a
    # trailing block that is still being streamed (no closing tag yet) is
    # matched up to the end of the buffer.
    _BLOCK_PATTERN = re.compile(r'<glm_block\b[^>]*>(.*?)(?:</glm_block>|$)', re.DOTALL)

    # Fallback extractors for blocks that are not (yet) valid JSON.
    _PARTIAL_ID_PATTERN = re.compile(r'"id":\s*"([^"]+)"')
    _PARTIAL_NAME_PATTERN = re.compile(r'"name":\s*"([^"]+)"')
    # The (possibly unterminated) JSON-string value of "arguments".  Backslash
    # escapes such as \" are consumed as a unit so an escaped quote inside the
    # serialized arguments does not end the capture early.
    _PARTIAL_ARGS_PATTERN = re.compile(r'"arguments":\s*"((?:[^"\\]|\\.)*)')

    # Substrings of an "other"-phase edit that _is_tool_call_finished treats as
    # the end of the tool call.
    _END_MARKERS = (
        "null,",
        '"status": "completed"',
        '"is_error": false',
    )

    def __init__(self, chat_id: str, model: str):
        """Create a handler for one chat completion.

        Args:
            chat_id: Emitted as the ``id`` field of every chunk.
            model: Emitted as the ``model`` field of every chunk.
        """
        self.chat_id = chat_id
        self.model = model

        # State flags.
        self.has_tool_call = False
        self.tool_call_usage = None
        self.content_index = 0
        self.has_thinking = False

        # Bytes reconstructed from the edit_index / edit_content stream.
        self.content_buffer = bytearray()
        self.last_edit_index = 0

        # Tool tracking: active tools keyed by tool id, finished tools in order.
        self.active_tools = {}
        self.completed_tools = []
        # Not used within this class.
        self.tool_blocks_cache = {}

    def process_tool_call_phase(self, data: Dict[str, Any], is_stream: bool = True) -> Generator[str, None, None]:
        """Handle an event from the ``tool_call`` phase.

        Applies the event's edit to the content buffer, then yields any chunks
        produced by re-parsing the buffer.

        Args:
            data: Event payload; ``edit_content`` and ``edit_index`` are read.
            is_stream: When False, tool state is still tracked but no tool
                start/update chunks are yielded.
        """
        if not self.has_tool_call:
            self.has_tool_call = True
            debug_log("🔧 进入工具调用阶段")

        edit_content = data.get("edit_content", "")
        edit_index = data.get("edit_index", 0)

        if not edit_content:
            return

        self._apply_edit_to_buffer(edit_index, edit_content)
        yield from self._process_tool_calls_from_buffer(is_stream)

    def _apply_edit_to_buffer(self, edit_index: int, edit_content: str):
        """Overwrite ``content_buffer`` at ``edit_index`` with ``edit_content``.

        The content is UTF-8 encoded and written starting at byte offset
        ``edit_index``.  The buffer is NUL-padded first if it is too short.
        Bytes after the written range are left untouched.

        NOTE(review): ``edit_index`` is treated as a UTF-8 *byte* offset.  If the
        upstream index counts characters instead, non-ASCII text before the
        edit point shifts the write position — confirm against Z.AI.
        """
        # A null index is treated as 0.  A negative one would make the slice
        # assignment below address the buffer from its end.
        edit_index = max(int(edit_index or 0), 0)
        edit_bytes = edit_content.encode('utf-8')
        end_index = edit_index + len(edit_bytes)

        # Pad so the target slice exists; the NULs are stripped again on decode.
        missing = end_index - len(self.content_buffer)
        if missing > 0:
            self.content_buffer.extend(b'\x00' * missing)

        self.content_buffer[edit_index:end_index] = edit_bytes

    def _process_tool_calls_from_buffer(self, is_stream: bool) -> Generator[str, None, None]:
        """Decode the buffer and yield chunks for every tool block found in it.

        Failures are logged and swallowed: the buffer may be mid-stream and
        become parseable after later edits.
        """
        try:
            # Strip NUL padding; bytes that are not valid UTF-8 (for example a
            # split multi-byte character) are dropped.
            content_str = self.content_buffer.decode('utf-8', errors='ignore').replace('\x00', '')
            yield from self._extract_and_process_tools(content_str, is_stream)
        except Exception as e:
            debug_log(f"📦 内容解析暂时失败,等待更多数据: {e}")

    def _extract_and_process_tools(self, content_str: str, is_stream: bool) -> Generator[str, None, None]:
        """Yield chunks for each ``<glm_block>`` found in ``content_str``, in order."""
        for block_content in self._BLOCK_PATTERN.findall(content_str):
            yield from self._process_single_tool_block(block_content, is_stream)

    def _process_single_tool_block(self, block_content: str, is_stream: bool) -> Generator[str, None, None]:
        """Parse one block's JSON payload and forward it to the tool tracker.

        The tool id, name and arguments are read from ``data.metadata``.
        Blocks that are not valid JSON fall back to regex-based partial
        extraction.  Any other error is logged and ignored.
        """
        try:
            tool_data = json.loads(self._fix_json_structure(block_content))
            metadata = tool_data.get("data", {}).get("metadata", {})

            tool_id = metadata.get("id", "")
            tool_name = metadata.get("name", "")
            arguments_raw = metadata.get("arguments", "{}")

            if not tool_id or not tool_name:
                return

            debug_log(f"🎯 解析完整工具块: {tool_name}(id={tool_id}), 参数: {arguments_raw}")
            yield from self._handle_tool_update(tool_id, tool_name, arguments_raw, is_stream)

        except json.JSONDecodeError as e:
            debug_log(f"📦 JSON解析失败: {e}, 尝试部分解析")
            yield from self._handle_partial_tool_block(block_content, is_stream)
        except Exception as e:
            debug_log(f"📦 工具块处理失败: {e}")

    def _fix_json_structure(self, content: str) -> str:
        """Drop surplus closing braces so an over-closed block can be parsed.

        When ``}`` outnumbers ``{``, the excess ``}`` characters are removed
        starting from the end of the string.  Braces inside string literals are
        counted too.  Under-closed input is returned unchanged.
        """
        if not content:
            return content

        surplus = content.count('}') - content.count('{')
        for _ in range(max(surplus, 0)):
            # A surplus guarantees at least one '}' remains, so rfind != -1.
            pos = content.rfind('}')
            content = content[:pos] + content[pos + 1:]
        return content

    def _is_completed_tool(self, tool_id: str) -> bool:
        """Return True if ``tool_id`` was already finalized by ``_complete_active_tools``."""
        return any(tool.get("id") == tool_id for tool in self.completed_tools)

    def _handle_tool_update(self, tool_id: str, tool_name: str, arguments_raw: Any, is_stream: bool) -> Generator[str, None, None]:
        """Register a newly seen tool, or update a known one, from a parsed block.

        A start chunk is emitted once the arguments pass
        ``_is_arguments_complete``.  Later significant changes are emitted as
        argument-update chunks.

        Args:
            tool_id: Tool-call id from the block metadata.
            tool_name: Function name from the block metadata.
            arguments_raw: The metadata ``arguments`` value.  This is usually a
                JSON string, but an already-decoded object or ``None`` is also
                accepted.
            is_stream: Whether to yield chunks.
        """
        if self._is_completed_tool(tool_id):
            # The buffer is never trimmed, so blocks of finished tools are
            # re-parsed on every later edit; do not register or announce them again.
            return

        try:
            if isinstance(arguments_raw, str):
                cleaned_args = self._clean_arguments_string(arguments_raw)
                arguments = json.loads(cleaned_args) if cleaned_args.strip() else {}
            else:
                arguments = arguments_raw
        except json.JSONDecodeError:
            debug_log(f"📦 参数解析失败,暂不处理: {arguments_raw}")
            return

        if arguments is None:
            # "arguments": null -> no arguments yet.
            arguments = {}
        if not isinstance(arguments, dict):
            debug_log(f"📦 参数解析失败,暂不处理: {arguments_raw}")
            return

        # String form of the raw arguments, because the length/suffix heuristics
        # below cannot work on non-string values.
        raw_text = arguments_raw if isinstance(arguments_raw, str) else json.dumps(arguments_raw, ensure_ascii=False)
        is_args_complete = self._is_arguments_complete(arguments, raw_text)

        current_tool = self.active_tools.get(tool_id)
        if current_tool is None:
            debug_log(f"🎯 发现新工具: {tool_name}(id={tool_id}), 参数完整性: {is_args_complete}")
            current_tool = {
                "id": tool_id,
                "name": tool_name,
                "arguments": arguments,
                "arguments_raw": raw_text,
                "status": "active",
                "sent_start": False,
                "last_sent_args": {},
                "args_complete": is_args_complete,
                "pending_send": True,
            }
            self.active_tools[tool_id] = current_tool

            # Announce immediately only if the arguments look final; otherwise
            # wait for a later update or for _complete_active_tools.
            if is_stream and is_args_complete:
                yield self._create_tool_start_chunk(tool_id, tool_name, arguments)
                current_tool["sent_start"] = True
                current_tool["last_sent_args"] = arguments.copy()
                current_tool["pending_send"] = False
                debug_log(f"📤 发送完整工具开始: {tool_name}(id={tool_id})")
            return

        # Tools first registered by _handle_partial_tool_block may lack some
        # keys, so read them defensively.
        if not self._is_significant_improvement(current_tool.get("arguments") or {}, arguments,
                                                current_tool.get("arguments_raw", ""), raw_text):
            return

        debug_log(f"🔄 工具参数有实质性改进: {tool_name}(id={tool_id})")
        current_tool["arguments"] = arguments
        current_tool["arguments_raw"] = raw_text
        current_tool["args_complete"] = is_args_complete

        if not (is_stream and is_args_complete):
            return

        if not current_tool.get("sent_start", False):
            # The start was deferred earlier because the arguments looked incomplete.
            yield self._create_tool_start_chunk(tool_id, tool_name, arguments)
            current_tool["sent_start"] = True
            current_tool["last_sent_args"] = arguments.copy()
            current_tool["pending_send"] = False
            debug_log(f"📤 发送延迟的工具开始: {tool_name}(id={tool_id})")
        elif self._should_send_argument_update(current_tool.get("last_sent_args", {}), arguments):
            yield self._create_tool_arguments_chunk(tool_id, arguments)
            current_tool["last_sent_args"] = arguments.copy()
            debug_log(f"📤 发送参数更新: {tool_name}(id={tool_id})")

    def _is_arguments_complete(self, arguments: Dict[str, Any], arguments_raw: Any) -> bool:
        """Heuristically decide whether streamed arguments look final.

        The following are rejected:
        - empty arguments;
        - raw text not ending in ``}`` or ``"``;
        - URL-like values that are very short or end in ``.go``/``.goo``;
        - any string value ending in ``.``, ``/``, ``:`` or ``=``.

        These rules can reject legitimate final values, which is why
        ``_complete_active_tools`` does not rely on them.
        """
        if not arguments:
            return False

        if not isinstance(arguments_raw, str):
            # An already-decoded object (or None) is judged by its JSON text.
            arguments_raw = json.dumps(arguments_raw, ensure_ascii=False)

        raw_stripped = arguments_raw.strip()
        if not raw_stripped:
            return False
        if not raw_stripped.endswith(('}', '"')):
            return False

        for value in arguments.values():
            if not isinstance(value, str):
                continue
            # A URL that has been cut off mid-host, e.g. "https://www.goo".
            if 'http' in value.lower() and (len(value) < 10 or value.endswith(('.go', '.goo'))):
                return False
            # A value that stops at a separator is probably still growing.
            if value and value[-1] in './:=':
                return False

        return True

    def _is_significant_improvement(self, old_args: Dict[str, Any], new_args: Dict[str, Any],
                                    old_raw: str, new_raw: str) -> bool:
        """Return True if ``new_args`` is worth replacing ``old_args`` with.

        The new arguments qualify when any of these hold:
        - they have more keys than the old ones;
        - a string value grew by more than 5 characters;
        - a value that looked truncated (``.go``, ``.goo``, ``.com/``,
          ``http``) grew at all;
        - the raw text grew by more than 10 characters.
        """
        if not new_args:
            return False
        old_args = old_args or {}

        if len(new_args) > len(old_args):
            return True

        for key, new_value in new_args.items():
            old_value = old_args.get(key, "")
            if not (isinstance(new_value, str) and isinstance(old_value, str)):
                continue
            if len(new_value) > len(old_value) + 5:
                return True
            if old_value.endswith(('.go', '.goo', '.com/', 'http')) and len(new_value) > len(old_value):
                return True

        return len(new_raw or "") > len(old_raw or "") + 10

    def _should_send_argument_update(self, last_sent: Dict[str, Any], new_args: Dict[str, Any]) -> bool:
        """Decide whether an argument-update chunk should follow the last one sent.

        No update is sent when the arguments are unchanged, or when they became
        empty after something was sent.  An update is sent when keys were added,
        or when a string value grew by more than 5 characters or changed to
        another non-empty value.
        """
        if last_sent == new_args:
            return False
        if not new_args and last_sent:
            return False
        if len(new_args) > len(last_sent):
            return True

        for key, new_value in new_args.items():
            last_value = last_sent.get(key, "")
            if isinstance(new_value, str) and isinstance(last_value, str):
                if len(new_value) > len(last_value) + 5:
                    return True
                if new_value != last_value and new_value:
                    return True

        return False

    def _handle_partial_tool_block(self, block_content: str, is_stream: bool) -> Generator[str, None, None]:
        """Extract what we can from a block that is not valid JSON yet.

        The tool id, name and the partial ``arguments`` string are pulled out
        with regexes.  For a tool seen for the first time, a start chunk with
        the best-effort arguments is emitted right away (when streaming).  For
        a known tool, only its stored arguments are refreshed.
        """
        try:
            id_match = self._PARTIAL_ID_PATTERN.search(block_content)
            name_match = self._PARTIAL_NAME_PATTERN.search(block_content)
            if not (id_match and name_match):
                return

            tool_id = id_match.group(1)
            tool_name = name_match.group(1)
            if self._is_completed_tool(tool_id):
                return

            args_match = self._PARTIAL_ARGS_PATTERN.search(block_content)
            partial_args = args_match.group(1) if args_match else ""

            debug_log(f"📦 部分工具块: {tool_name}(id={tool_id}), 部分参数: {partial_args[:50]}")

            tool = self.active_tools.get(tool_id)
            if tool is None:
                partial_args_dict = self._parse_partial_arguments(partial_args)
                tool = {
                    "id": tool_id,
                    "name": tool_name,
                    "arguments": partial_args_dict,
                    # Empty, so that the first fully parsed block for this tool
                    # counts as a significant improvement in _handle_tool_update.
                    "arguments_raw": "",
                    "status": "partial",
                    "sent_start": False,
                    "last_sent_args": {},
                    "args_complete": False,
                    "pending_send": False,
                    "partial_args": partial_args,
                }
                self.active_tools[tool_id] = tool

                if is_stream:
                    yield self._create_tool_start_chunk(tool_id, tool_name, partial_args_dict)
                    tool["sent_start"] = True
                    tool["last_sent_args"] = partial_args_dict.copy()
            else:
                tool["partial_args"] = partial_args
                new_partial_dict = self._parse_partial_arguments(partial_args)
                if new_partial_dict != tool["arguments"]:
                    tool["arguments"] = new_partial_dict

        except Exception as e:
            debug_log(f"📦 部分块解析失败: {e}")

    def _clean_arguments_string(self, arguments_raw: str) -> str:
        """Normalize a raw arguments string into compact JSON text.

        Empty input and ``null`` become ``"{}"``.  Escaped quotes are
        unescaped when the text looks like an escaped object, and surrounding
        quotes are stripped.  Truncated JSON is closed by
        ``_fix_incomplete_json``.  If the result still does not parse, the
        best-effort text is returned so the caller's ``json.loads`` reports the
        failure.
        """
        if not arguments_raw:
            return "{}"

        cleaned = arguments_raw.strip()
        if cleaned.lower() == "null":
            return "{}"

        if cleaned.startswith('{\\"') and cleaned.endswith('\\"}'):
            # {\"k\": \"v\"} -> {"k": "v"}
            cleaned = cleaned.replace('\\"', '"')
        elif cleaned.startswith('"{\\"') and cleaned.endswith('\\"}'):
            # Drop the first and last characters, then unescape quotes.
            cleaned = cleaned[1:-1].replace('\\"', '"')
        elif cleaned.startswith('"') and cleaned.endswith('"'):
            # Strip one pair of surrounding quotes.
            cleaned = cleaned[1:-1]

        cleaned = self._fix_incomplete_json(cleaned)

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            debug_log(f"📦 JSON标准化失败,保持原样: {cleaned[:50]}...")
            return cleaned

        if parsed is None:
            return "{}"
        return json.dumps(parsed, ensure_ascii=False, separators=(',', ':'))

    def _fix_incomplete_json(self, json_str: str) -> str:
        """Best-effort closing of truncated JSON object text.

        Steps, in order:
        1. Prepend ``{`` if the text does not start with one.
        2. Append ``"`` when the raw count of ``"`` is odd (escaped quotes
           are not distinguished).
        3. Append ``}`` if the text does not end with one.
        """
        if not json_str:
            return "{}"

        if not json_str.startswith('{'):
            json_str = '{' + json_str
        if json_str.count('"') % 2 != 0:
            json_str += '"'
        if not json_str.endswith('}'):
            json_str += '}'

        return json_str

    def _parse_partial_arguments(self, arguments_raw: str) -> Dict[str, Any]:
        """Parse possibly truncated/escaped argument text into a dict.

        The parser tries, in order:
        1. ``_clean_arguments_string`` followed by ``json.loads``;
        2. a manual unescape-and-close pass;
        3. regex key/value extraction.

        A non-dict result from the first attempt yields ``{}``, and unexpected
        errors also yield ``{}``.
        """
        if not arguments_raw or not arguments_raw.strip() or arguments_raw.strip().lower() == "null":
            return {}

        try:
            result = json.loads(self._clean_arguments_string(arguments_raw))
            return result if isinstance(result, dict) else {}
        except json.JSONDecodeError:
            pass

        try:
            fixed_args = arguments_raw.strip()
            if '\\' in fixed_args:
                fixed_args = fixed_args.replace('\\"', '"')
            if not fixed_args.startswith('{'):
                fixed_args = '{' + fixed_args
            if not fixed_args.endswith('}'):
                # Close a dangling string literal before closing the object.
                quote_count = fixed_args.count('"') - fixed_args.count('\\"')
                if quote_count % 2 != 0:
                    fixed_args += '"'
                fixed_args += '}'
            return json.loads(fixed_args)
        except json.JSONDecodeError:
            return self._extract_key_value_pairs(arguments_raw)
        except Exception:
            return {}

    def _extract_key_value_pairs(self, text: str) -> Dict[str, Any]:
        """Last-resort extraction of ``"key": value`` pairs from arbitrary text.

        Quoted string values, non-negative integers and ``true``/``false`` are
        recognized.  When a key matches several patterns, later patterns win.
        """
        if not isinstance(text, str):
            return {}

        result: Dict[str, Any] = {}
        for key, value in re.findall(r'"([^"]+)":\s*"([^"]*)"', text):
            result[key] = value
        for key, value in re.findall(r'"([^"]+)":\s*(\d+)', text):
            result[key] = int(value)
        for key, value in re.findall(r'"([^"]+)":\s*(true|false)', text):
            result[key] = value == 'true'
        return result

    def _complete_active_tools(self, is_stream: bool) -> Generator[str, None, None]:
        """Finalize every active tool and, when streaming, emit the finish chunk.

        Tools whose start chunk was still pending are sent now.  This happens
        even if their arguments never passed ``_is_arguments_complete``: the
        tool call has ended, so no further data can arrive, and those
        heuristics reject legitimate final values such as ``{}`` for a
        parameterless call.  All active tools are then moved to
        ``completed_tools``.
        """
        tools_to_send = []

        # Iterate over a snapshot: the loop body yields to the consumer.
        for tool_id, tool in list(self.active_tools.items()):
            if is_stream and tool.get("pending_send", False) and not tool.get("sent_start", False):
                if not tool.get("args_complete", False):
                    debug_log(f"⚠️ 参数未通过完整性检查,完成时仍然发送: {tool['name']}(id={tool_id})")
                debug_log(f"📤 完成时发送待发送工具: {tool['name']}(id={tool_id})")
                arguments = tool.get("arguments") or {}
                yield self._create_tool_start_chunk(tool_id, tool["name"], arguments)
                tool["sent_start"] = True
                tool["pending_send"] = False
                tool["last_sent_args"] = dict(arguments)
                tools_to_send.append(tool)

            tool["status"] = "completed"
            self.completed_tools.append(tool)
            debug_log(f"✅ 完成工具调用: {tool['name']}(id={tool_id})")

        self.active_tools.clear()

        if is_stream and (self.completed_tools or tools_to_send):
            yield self._create_tool_finish_chunk()

    def process_other_phase(self, data: Dict[str, Any], is_stream: bool = True) -> Generator[str, None, None]:
        """Handle an event from the ``other`` phase.

        The method:
        1. records ``usage`` while a tool call is in progress;
        2. applies the edit and re-parses the buffer;
        3. when an end marker is detected, completes the active tools, yields
           ``data: [DONE]`` (when streaming) and clears ``has_tool_call``.
        """
        edit_content = data.get("edit_content", "")
        edit_index = data.get("edit_index", 0)
        usage = data.get("usage")

        if self.has_tool_call and usage:
            self.tool_call_usage = usage
            debug_log(f"💾 保存工具调用usage: {usage}")

        if edit_content:
            self._apply_edit_to_buffer(edit_index, edit_content)
            yield from self._process_tool_calls_from_buffer(is_stream)

        if self.has_tool_call and self._is_tool_call_finished(edit_content):
            debug_log("🏁 检测到工具调用结束")
            yield from self._complete_active_tools(is_stream)

            if is_stream:
                debug_log("🏁 发送工具调用完成信号")
                # NOTE(review): unlike the other chunks this has no trailing
                # "\n\n"; confirm the caller terminates the SSE event itself.
                yield "data: [DONE]"

            self.has_tool_call = False

    def _is_tool_call_finished(self, edit_content: str) -> bool:
        """Return True if this edit (or the buffer) signals the end of the tool call.

        The edit is checked for any of ``_END_MARKERS``.  Failing that, while
        tools are active, the whole buffer is checked for a completed status.
        """
        if not edit_content:
            return False

        for marker in self._END_MARKERS:
            if marker in edit_content:
                debug_log(f"🔍 检测到结束标记: {marker}")
                return True

        # content_buffer is a bytearray, so the needle must be bytes; a str
        # needle raises TypeError.
        return bool(self.active_tools) and b'"status": "completed"' in self.content_buffer

    def _reset_all_state(self):
        """Reset every piece of per-completion state set up in ``__init__``."""
        self.has_tool_call = False
        self.tool_call_usage = None
        self.content_index = 0
        self.has_thinking = False
        self.content_buffer = bytearray()
        self.last_edit_index = 0
        self.active_tools.clear()
        self.completed_tools.clear()
        self.tool_blocks_cache.clear()

    def _create_tool_start_chunk(self, tool_id: str, tool_name: str, initial_args: Optional[Dict[str, Any]] = None) -> str:
        """Build the SSE line announcing a tool call with its name and arguments.

        The arguments are serialized as a JSON string.

        NOTE(review): OpenAI-style streaming clients usually expect an
        ``index`` on each ``tool_calls`` entry and concatenate ``arguments``
        fragments across chunks.  This handler sends full JSON instead —
        confirm the downstream consumer expects that.
        """
        args_str = json.dumps(initial_args or {}, ensure_ascii=False)
        chunk = {
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": tool_id,
                                "type": "function",
                                "function": {"name": tool_name, "arguments": args_str},
                            }
                        ],
                    },
                    "finish_reason": None,
                    "index": self.content_index,
                    "logprobs": None,
                }
            ],
            "created": int(time.time()),
            "id": self.chat_id,
            "model": self.model,
            "object": "chat.completion.chunk",
            "system_fingerprint": "fp_zai_001",
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    def _create_tool_arguments_chunk(self, tool_id: str, arguments: Dict) -> str:
        """Build the SSE line carrying an argument update (no function name).

        The full, current arguments are sent as one JSON string rather than
        as an incremental fragment.
        """
        chunk = {
            "choices": [
                {
                    "delta": {
                        "tool_calls": [
                            {
                                "id": tool_id,
                                "function": {"arguments": json.dumps(arguments, ensure_ascii=False)},
                            }
                        ],
                    },
                    "finish_reason": None,
                    "index": self.content_index,
                    "logprobs": None,
                }
            ],
            "created": int(time.time()),
            "id": self.chat_id,
            "model": self.model,
            "object": "chat.completion.chunk",
            "system_fingerprint": "fp_zai_001",
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    def _create_tool_finish_chunk(self) -> str:
        """Build the SSE line with ``finish_reason="tool_calls"`` and any saved usage."""
        chunk = {
            "choices": [
                {
                    "delta": {"role": "assistant", "content": None, "tool_calls": []},
                    "finish_reason": "tool_calls",
                    # Same choice index as the start/argument chunks.
                    "index": self.content_index,
                    "logprobs": None,
                }
            ],
            "created": int(time.time()),
            "id": self.chat_id,
            "usage": self.tool_call_usage or None,
            "model": self.model,
            "object": "chat.completion.chunk",
            "system_fingerprint": "fp_zai_001",
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"