Spaces:
Running
Running
| // Copyright 2025 The Google AI Edge Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace litert::lm { | |
| namespace { | |
| RE2 TextAndToolCodeRegex(absl::string_view code_fence_start, | |
| absl::string_view code_fence_end, | |
| bool escape_fence_strings) { | |
| // Construct the regex pattern: (non-greedy text before) <start> (non-greedy | |
| // code) <end>. | |
| std::string pattern; | |
| if (escape_fence_strings) { | |
| // QuoteMeta escapes any special regex characters in the fence strings. | |
| pattern = absl::StrCat("(?ms)(.*?)", RE2::QuoteMeta(code_fence_start), | |
| "(.*?)", RE2::QuoteMeta(code_fence_end)); | |
| } else { | |
| pattern = | |
| absl::StrCat("(?ms)(.*?)", code_fence_start, "(.*?)", code_fence_end); | |
| } | |
| return RE2(pattern); | |
| } | |
| std::string FilterLines(absl::string_view input, const RE2& regex) { | |
| std::vector<absl::string_view> lines = absl::StrSplit(input, '\n'); | |
| std::string captured_part; | |
| std::vector<std::string> captured_lines; | |
| for (absl::string_view line : lines) { | |
| if (RE2::PartialMatch(line, regex, &captured_part)) { | |
| captured_lines.push_back(captured_part); | |
| } else { | |
| captured_lines.push_back(std::string(line)); | |
| } | |
| } | |
| return absl::StrJoin(captured_lines, "\n"); | |
| } | |
| } // namespace | |
| SyntaxType GetSyntaxType(absl::string_view syntax_type) { | |
| static const absl::NoDestructor< | |
| absl::flat_hash_map<absl::string_view, SyntaxType>> | |
| kStringToSyntaxType({ | |
| {"python", SyntaxType::kPython}, | |
| {"json", SyntaxType::kJson}, | |
| {"fc", SyntaxType::kFc}, | |
| }); | |
| auto it = kStringToSyntaxType->find(syntax_type); | |
| if (it == kStringToSyntaxType->end()) { | |
| return SyntaxType::kUnknown; | |
| } | |
| return it->second; | |
| } | |
| absl::StatusOr<nlohmann::ordered_json> ParseTextAndToolCalls( | |
| absl::string_view response_str, absl::string_view code_fence_start, | |
| absl::string_view code_fence_end, SyntaxType syntax_type, | |
| bool escape_fence_strings, absl::string_view tool_code_regex) { | |
| nlohmann::ordered_json result = nlohmann::json::object(); | |
| // If the response is empty, return a content array with a single empty text | |
| // element to ensure the output format is consistent. | |
| if (response_str.empty()) { | |
| result["content"].push_back({{"type", "text"}, {"text", ""}}); | |
| return result; | |
| } | |
| RE2 regex = TextAndToolCodeRegex(code_fence_start, code_fence_end, | |
| escape_fence_strings); | |
| if (!regex.ok()) { | |
| return absl::InvalidArgumentError(absl::StrCat( | |
| "Invalid regex: ", regex.pattern(), " error: ", regex.error())); | |
| } | |
| std::string text; | |
| std::string code_block; | |
| absl::string_view original_response_str = response_str; | |
| while (RE2::Consume(&response_str, regex, &text, &code_block)) { | |
| // Append text to the content array. | |
| if (!text.empty()) { | |
| result["content"].push_back({{"type", "text"}, {"text", text}}); | |
| } | |
| // Before parsing the code block, apply tool_code_regex to each line. | |
| if (!tool_code_regex.empty()) { | |
| RE2 regex(tool_code_regex); | |
| if (!regex.ok()) { | |
| return absl::InvalidArgumentError( | |
| absl::StrCat("Invalid tool_code_regex: ", tool_code_regex)); | |
| } | |
| code_block = FilterLines(code_block, regex); | |
| } | |
| // Parse tool calls from the code block. | |
| if (!code_block.empty()) { | |
| absl::StatusOr<nlohmann::ordered_json> tool_calls; | |
| if (syntax_type == SyntaxType::kPython) { | |
| tool_calls = ParsePythonExpression(code_block); | |
| } else if (syntax_type == SyntaxType::kJson) { | |
| tool_calls = ParseJsonExpression(code_block); | |
| } else if (syntax_type == SyntaxType::kFc) { | |
| tool_calls = ParseFcExpression(code_block); | |
| } else { | |
| return absl::InvalidArgumentError(absl::StrCat( | |
| "Unsupported syntax type: ", static_cast<int>(syntax_type))); | |
| } | |
| if (!tool_calls.ok()) { | |
| return absl::InvalidArgumentError(absl::StrCat( | |
| "Failed to parse tool calls from response: ", original_response_str, | |
| "code block: ", code_block, | |
| " with error: ", tool_calls.status().message())); | |
| } | |
| for (const auto& tool_call : *tool_calls) { | |
| result["tool_calls"].push_back( | |
| {{"type", "function"}, {"function", tool_call}}); | |
| } | |
| } | |
| text.clear(); | |
| code_block.clear(); | |
| } | |
| // Append the remaining text to the content array. | |
| if (!response_str.empty()) { | |
| result["content"].push_back({{"type", "text"}, {"text", response_str}}); | |
| } | |
| return result; | |
| } | |
| } // namespace litert::lm | |