| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import random |
| import re |
| import string |
| import ast |
| import json |
| from collections.abc import Sequence |
| from typing import Union, Tuple, List, Optional |
|
|
| from vllm.entrypoints.openai.protocol import ( |
| ChatCompletionRequest, |
| DeltaMessage, |
| DeltaFunctionCall, |
| DeltaToolCall, |
| ExtractedToolCallInformation, |
| ToolCall, |
| FunctionCall, |
| ) |
| from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| ToolParser |
| ) |
| from vllm.logger import init_logger |
|
|
| import pyjson5 |
|
|
class ToolCallID:
    """Wrapper around a tool-call identifier string.

    A well-formed id is exactly ``_LENGTH`` characters long and consists only
    of ASCII lowercase letters and digits. Well-formedness is only enforced
    when ``validation=True`` is passed; otherwise any value is stored as-is.
    """

    # Number of characters in a well-formed id. Subclasses may override it;
    # both ``random`` and ``_validate`` follow the override.
    _LENGTH = 10

    def __init__(self, id_val: str, validation: bool = False):
        """Store ``id_val``, optionally checking that it is well-formed.

        Args:
            id_val: The identifier text.
            validation: When true, reject ids that are not well-formed.

        Raises:
            AssertionError: If ``validation`` is true and ``id_val`` is not
                well-formed.
        """
        self._id = id_val
        if validation:
            self._validate()

    @classmethod
    def random(cls, validation=False) -> 'ToolCallID':
        """Build an id from ``_LENGTH`` randomly chosen lowercase letters/digits.

        Uses the ``random`` module, so the ids are not suitable as secrets.
        """
        chars = string.ascii_lowercase + string.digits
        generated = ''.join(random.choice(chars) for _ in range(cls._LENGTH))
        return cls(generated, validation=validation)

    def _validate(self):
        """Raise ``AssertionError`` unless the stored id is well-formed.

        The checks are explicit ``raise`` statements rather than ``assert``
        so that they still run under ``python -O``, which strips ``assert``.
        ``AssertionError`` is kept so callers observe the same exception type
        as before.
        """
        if len(self._id) != self._LENGTH:
            raise AssertionError(
                f"tool call id must be {self._LENGTH} characters long, "
                f"got {len(self._id)}"
            )
        # The length is already known to be right, so only the alphabet is
        # left to check; fullmatch avoids '$' matching before a trailing '\n'.
        if re.fullmatch(r'[a-z0-9]*', self._id) is None:
            raise AssertionError(
                "tool call id may only contain lowercase ASCII letters and "
                f"digits, got {self._id!r}"
            )

    def to_string(self) -> str:
        """Return the identifier text."""
        return self._id

    def __str__(self) -> str:
        return self.to_string()
|
|
|
|
# Module-level logger, created by vllm.logger.init_logger with this module's name.
| logger = init_logger(__name__) |
|
|
|
|
class SolarOpenToolParser(ToolParser):
    """Extract assistant content and tool calls from marker-delimited text.

    The parser works purely on the literal marker strings listed in
    ``_SPECIAL_MARKERS``. One tool call is laid out as::

        <|tool_call:begin|>ID<|tool_call:name|>NAME<|tool_call:args|>ARGS<|tool_call:end|>

    Assistant content is the text that precedes the first ``<|flush|>`` or
    ``<|end|>`` marker.
    """

    _FLUSH_TAG = "<|flush|>"
    _END_TAG = "<|end|>"
    _BEGIN_TAG = "<|begin|>"
    _TOOL_CALLS_TAG = "<|tool_calls|>"
    _CALL_BEGIN_TAG = "<|tool_call:begin|>"
    _CALL_NAME_TAG = "<|tool_call:name|>"
    _CALL_ARGS_TAG = "<|tool_call:args|>"
    _CALL_END_TAG = "<|tool_call:end|>"
    _CALLS_TAG = "<|calls|>"

    # Any of these appearing in the stream ends plain-content streaming.
    _SPECIAL_MARKERS = (
        _FLUSH_TAG,
        _END_TAG,
        _BEGIN_TAG,
        _TOOL_CALLS_TAG,
        _CALL_BEGIN_TAG,
        _CALL_NAME_TAG,
        _CALL_ARGS_TAG,
        _CALL_END_TAG,
        _CALLS_TAG,
    )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into content and tool calls.

        ``request`` is accepted for interface compatibility and is not used.
        Empty content is reported as ``None``.
        """
        content, tool_calls = self._parse_text(model_output)
        return ExtractedToolCallInformation(
            tools_called=bool(tool_calls),
            tool_calls=tool_calls,
            content=content or None,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn one streamed text delta into a ``DeltaMessage``.

        Only the three text arguments are inspected; the token-id sequences
        and ``request`` are unused.

        Returns:
            * a content delta while no special marker has appeared in either
              ``previous_text`` or ``delta_text``;
            * a tool-call delta carrying id and name when ``delta_text``
              contains the args marker (the call header is then complete);
            * a tool-call delta carrying ``delta_text`` as an argument
              fragment while an args section is open in both the previous and
              the current text;
            * ``None`` when the delta produces nothing to emit.
        """
        if not delta_text:
            return None

        markers = self._SPECIAL_MARKERS
        if not any(tag in previous_text for tag in markers) and not any(
            tag in delta_text for tag in markers
        ):
            return DeltaMessage(content=delta_text, tool_calls=[])

        args_tag = self._CALL_ARGS_TAG
        tool_call_deltas: list[DeltaToolCall] = []

        # The args marker closes the "begin ... name ..." header, so this is
        # the moment the id and function name of the new call are known.
        if args_tag in delta_text:
            header = self._header_delta(previous_text, current_text)
            if header is not None:
                tool_call_deltas.append(header)

        # Argument text is forwarded only when the args section was already
        # open before this delta and is still open after it, and the delta is
        # not itself one of the structural markers.
        structural = (self._CALL_BEGIN_TAG, args_tag, self._CALL_END_TAG)
        if (
            self._args_section_open(current_text)
            and self._args_section_open(previous_text)
            and delta_text not in structural
        ):
            # The open call is the last one whose args marker was already seen.
            index = max(previous_text.count(args_tag) - 1, 0)
            tool_call_deltas.append(
                DeltaToolCall(
                    id=None,
                    type=None,
                    index=index,
                    function=DeltaFunctionCall(name=None, arguments=delta_text),
                )
            )

        if not tool_call_deltas:
            return None
        return DeltaMessage(content=None, tool_calls=tool_call_deltas)

    def _header_delta(
        self, previous_text: str, current_text: str
    ) -> Optional[DeltaToolCall]:
        """Build the id/name delta for the most recent call header, if well-formed.

        Looks backwards from the last args marker in ``current_text`` for the
        nearest name marker and then the nearest begin marker. Returns
        ``None`` unless all three are present in that order.
        """
        begin_tag = self._CALL_BEGIN_TAG
        name_tag = self._CALL_NAME_TAG
        args_tag = self._CALL_ARGS_TAG

        args_at = current_text.rfind(args_tag)
        name_at = current_text.rfind(name_tag, 0, args_at if args_at != -1 else None)
        begin_at = current_text.rfind(begin_tag, 0, name_at if name_at != -1 else None)
        if -1 in (begin_at, name_at, args_at) or not begin_at < name_at < args_at:
            return None

        return DeltaToolCall(
            id=current_text[begin_at + len(begin_tag):name_at],
            type="function",
            # Number of args markers seen before this delta == 0-based index
            # of the call that is starting now.
            index=previous_text.count(args_tag),
            function=DeltaFunctionCall(
                name=current_text[name_at + len(name_tag):args_at],
                arguments="",
            ),
        )

    def _args_section_open(self, text: str) -> bool:
        """Return True when the last args marker in ``text`` has no end marker after it."""
        args_at = text.rfind(self._CALL_ARGS_TAG)
        end_at = text.rfind(self._CALL_END_TAG)
        return args_at != -1 and (end_at == -1 or args_at > end_at)

    def _parse_text(self, text: str) -> Tuple[Optional[str], List[ToolCall]]:
        """Parse the completed segments from the given text.

        Returns ``(content, tool_calls)``. ``content`` is the leading text up
        to the first '<|flush|>' or '<|end|>' marker (``None`` if neither is
        present). ``tool_calls`` holds every fully delimited tool call found
        before the first '<|calls|>' marker, or in the whole text when that
        marker is absent.
        """
        return self._parse_content(text), self._parse_tool_calls(text)

    def _parse_content(self, text: str) -> Optional[str]:
        """Extract assistant content from the text.

        Rule: take the leading content before the first '<|flush|>' or
        '<|end|>' marker. If neither marker exists, return None.
        """
        positions = [
            pos
            for tag in (self._FLUSH_TAG, self._END_TAG)
            if (pos := text.find(tag)) != -1
        ]
        if not positions:
            return None
        return text[:min(positions)]

    def _parse_tool_call_args(self, text: str) -> str:
        """Normalise a raw argument payload into a string.

        Decoders are tried in order: strict JSON, JSON5 (``pyjson5``), then a
        Python literal (``ast.literal_eval``). A decoded non-string value is
        re-serialised with ``json.dumps``; a decoded string is returned as
        decoded. If nothing can decode the payload, or the decoded value
        cannot be serialised as JSON, the raw text is returned unchanged.
        """
        try:
            args = json.loads(text)
        except json.JSONDecodeError:
            try:
                args = pyjson5.decode(text)
            except pyjson5.Json5DecoderException:
                try:
                    args = ast.literal_eval(text)
                except Exception:
                    # Deliberate best effort: literal_eval raises several
                    # unrelated exception types on malformed input.
                    return text

        if isinstance(args, str):
            return args
        try:
            return json.dumps(args)
        except (TypeError, ValueError):
            # literal_eval accepts values JSON cannot represent (sets, bytes,
            # complex numbers, tuple-keyed dicts); pass the raw text through
            # instead of failing the whole parse.
            return text

    def _parse_tool_calls(self, text: str) -> List[ToolCall]:
        """Collect every fully delimited tool call in ``text``.

        Scanning starts at the beginning of the text and stops at the first
        '<|calls|>' marker (or the end of the text when it is absent). It
        also stops at the first call that lacks any of its four markers, so
        an incomplete trailing call is ignored. The id and name are taken
        verbatim, without stripping whitespace.
        """
        begin_tag = self._CALL_BEGIN_TAG
        name_tag = self._CALL_NAME_TAG
        args_tag = self._CALL_ARGS_TAG
        end_tag = self._CALL_END_TAG

        limit = text.find(self._CALLS_TAG)
        if limit == -1:
            limit = len(text)

        tool_calls: list[ToolCall] = []
        cursor = 0
        while True:
            begin_at = text.find(begin_tag, cursor, limit)
            if begin_at == -1:
                break
            id_start = begin_at + len(begin_tag)

            name_at = text.find(name_tag, id_start, limit)
            if name_at == -1:
                break
            name_start = name_at + len(name_tag)

            args_at = text.find(args_tag, name_start, limit)
            if args_at == -1:
                break
            args_start = args_at + len(args_tag)

            end_at = text.find(end_tag, args_start, limit)
            if end_at == -1:
                break

            tool_calls.append(
                ToolCall(
                    id=text[id_start:name_at],
                    function=FunctionCall(
                        name=text[name_start:args_at],
                        arguments=self._parse_tool_call_args(text[args_start:end_at]),
                    ),
                )
            )
            cursor = end_at + len(end_tag)

        return tool_calls
|
|