| import type { Tool } from "$lib/types/Tool"; |
| import { extractJson } from "./utils"; |
| import { externalToToolCall } from "../textGeneration/tools"; |
| import { logger } from "../logger"; |
| import type { Endpoint, EndpointMessage } from "../endpoints/endpoints"; |
|
|
/**
 * Options accepted by {@link getToolOutput}.
 */
interface GetToolOutputOptions {
	/** Endpoint used to generate the model response. */
	endpoint: Endpoint;
	/** The only tool offered to the model. */
	tool: Tool;
	/** Conversation history forwarded to the endpoint. */
	messages: EndpointMessage[];
	/** Optional system prompt; an "only use this tool" instruction is appended to it. */
	preprompt?: string;
	/**
	 * Generation settings forwarded to the endpoint unchanged.
	 * `max_new_tokens` is the only typed key; any other key is accepted as `unknown`.
	 */
	generateSettings?: {
		[key: string]: unknown;
		max_new_tokens?: number;
	};
}
|
|
/**
 * Asks `endpoint` to call `tool` and returns the first parameter value of that call.
 *
 * The model is offered only `tool`, and an "Only use tool <name>." instruction is
 * appended to the preprompt. Tool calls are collected from native `token.toolCalls`
 * and from JSON embedded in `generated_text`. Streaming stops as soon as at least
 * one call has been collected.
 *
 * @param options.messages - Conversation forwarded to the endpoint.
 * @param options.preprompt - Optional system prompt; when absent or empty, only the tool instruction is sent.
 * @param options.tool - The single tool the model may call.
 * @param options.endpoint - Endpoint that produces the token stream.
 * @param options.generateSettings - Forwarded as-is; defaults to `{ max_new_tokens: 64 }` only when omitted.
 * @returns The first value in the matching call's `parameters` if it is a string, otherwise `undefined`.
 *   Errors are logged as warnings and also yield `undefined`.
 *
 * NOTE(review): `T` is not checked at runtime. Only string values are ever returned,
 * so any `T` other than `string` is an unchecked cast.
 */
export async function getToolOutput<T = string>({
	messages,
	preprompt,
	tool,
	endpoint,
	generateSettings = { max_new_tokens: 64 },
}: GetToolOutputOptions): Promise<T | undefined> {
	try {
		const toolInstruction = `Only use tool ${tool.name}.`;
		// Plain string concatenation would send the literal text "undefined" when no preprompt is given.
		const fullPreprompt = preprompt ? `${preprompt}\n\n ${toolInstruction}` : toolInstruction;

		const stream = await endpoint({
			messages,
			preprompt: fullPreprompt,
			tools: [tool],
			generateSettings,
		});

		const calls = [];

		for await (const output of stream) {
			if (output.token.toolCalls) {
				calls.push(...output.token.toolCalls);
			}
			if (output.generated_text) {
				// Fallback for models that emit tool calls as JSON text instead of native tool calls.
				const rawCalls = await extractJson(output.generated_text);
				const parsedCalls = rawCalls
					.map((rawCall) => externalToToolCall(rawCall, [tool]))
					.filter((parsed) => parsed !== undefined);
				calls.push(...parsedCalls);
			}

			// One call is enough; breaking out of `for await` also closes the stream.
			if (calls.length > 0) {
				break;
			}
		}

		const toolCall = calls.find((call) => call.name === tool.name);
		if (!toolCall?.parameters) {
			return undefined;
		}

		// Relies on the insertion order of the parameters object.
		const firstParamValue = Object.values(toolCall.parameters)[0];
		return typeof firstParamValue === "string" ? (firstParamValue as T) : undefined;
	} catch (error) {
		logger.warn(error, "Error getting tool output");
		return undefined;
	}
}
|
|