LiteRT-LM / python /litert_lm /tools.py
SeaWolf-AI's picture
Upload full LiteRT-LM codebase
5f923cd verified
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool interfaces for LiteRT LM."""
from __future__ import annotations
import abc
import collections.abc
import inspect
import re
import typing
from typing import Any
def _parse_param_descriptions(docstring: str) -> dict[str, str]:
    """Extracts per-argument descriptions from a Google-style docstring.

    Only the ``Args:`` section is read. Each entry has the form
    ``name: text`` or ``name (type): text`` and may continue on following
    lines that are indented deeper than the entry itself.

    The section ends at ``Returns:``, ``Raises:`` or ``Yields:``, or at any
    non-blank line that is not indented deeper than the ``Args:`` header
    (for example another section header such as ``Example:``).

    Args:
        docstring: The docstring text to scan. May be empty.

    Returns:
        A dict mapping argument names to their descriptions, with
        continuation lines joined by single spaces. Empty if there is no
        ``Args:`` section.
    """
    descriptions: dict[str, str] = {}
    if not docstring:
        return descriptions

    # An indented `name: text` or `name (type): text` entry.
    arg_pattern = re.compile(r"\s+([\w.]+)(?:\s*\(.*?\))?:\s*(.*)")
    section_enders = ("Returns:", "Raises:", "Yields:")

    in_args = False
    header_indent = 0
    current_arg = None
    current_indent = 0

    for line in docstring.split("\n"):
        stripped = line.strip()
        indent = len(line) - len(line.lstrip())

        if stripped == "Args:":
            in_args = True
            header_indent = indent
            current_arg = None
            continue
        if not in_args:
            continue
        if not stripped:
            # A blank line ends the description of the current argument.
            current_arg = None
            continue
        if stripped in section_enders or indent <= header_indent:
            # Dedenting back to the header level means a new section began,
            # even if its name is not one of the well-known ones.
            break

        match = arg_pattern.match(line)
        # A line indented deeper than the argument being described is
        # continuation text, even when it happens to look like `word: text`.
        if match and (current_arg is None or indent <= current_indent):
            current_arg = match.group(1)
            current_indent = indent
            descriptions[current_arg] = match.group(2).strip()
        elif current_arg is not None:
            descriptions[current_arg] += " " + stripped
    return descriptions
class Tool(abc.ABC):
    """Abstract interface for something that can be described and executed.

    Concrete tools must implement both `get_tool_description` and `execute`.
    """

    @abc.abstractmethod
    def get_tool_description(self) -> dict[str, Any]:
        """Describes this tool.

        Returns:
            A JSON-compatible dict representing the tool in OpenAPI schema
            form.
        """

    @abc.abstractmethod
    def execute(self, param: collections.abc.Mapping[str, Any]) -> Any:
        """Runs the tool with the given parameters.

        Args:
            param: Mapping holding the parameters for the tool.

        Returns:
            Whatever the tool execution produces.
        """
def _py_type_to_openapi(py_type: Any) -> dict[str, Any]:
    """Converts a Python type annotation to an OpenAPI schema fragment.

    Supported mappings:
      * ``int`` / ``float`` / ``bool`` / ``str`` map to the matching scalar.
      * ``list``, ``Sequence`` and ``Iterable`` (bare or parameterized) map
        to ``array``; the first type argument, if any, becomes ``items``.
      * ``Optional[X]`` / ``X | None`` map to the schema of ``X``.

    Anything else falls back to ``{"type": "string"}``.

    Args:
        py_type: The annotation to convert. May be any object, including
            ``inspect.Parameter.empty`` for unannotated parameters.

    Returns:
        A dict holding the OpenAPI ``type`` and, for parameterized arrays,
        the ``items`` schema.
    """
    # Local import: only needed to recognise PEP 604 unions (`X | None`).
    import types

    scalars = (
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
        (str, "string"),
    )
    # Identity checks on purpose: `bool` is a subclass of `int`, and
    # annotations are not guaranteed to be hashable.
    for scalar_type, openapi_name in scalars:
        if py_type is scalar_type:
            return {"type": openapi_name}

    origin = typing.get_origin(py_type)

    # `typing.Union` covers Optional[X]; `types.UnionType` covers `X | None`
    # on Python versions that have it.
    union_origins = (typing.Union, getattr(types, "UnionType", None))
    if origin is not None and origin in union_origins:
        members = [
            arg for arg in typing.get_args(py_type) if arg is not type(None)
        ]
        if len(members) == 1:
            return _py_type_to_openapi(members[0])
        # Unions of several real types keep the generic fallback.
        return {"type": "string"}

    array_types = (list, collections.abc.Sequence, collections.abc.Iterable)
    # Bare `list` has no origin, so it is compared directly to keep it
    # consistent with bare `typing.List`.
    is_bare_array = any(py_type is array_type for array_type in array_types)
    if is_bare_array or origin in array_types:
        args = typing.get_args(py_type)
        if args:
            return {"type": "array", "items": _py_type_to_openapi(args[0])}
        return {"type": "array"}

    # Fallback to string
    return {"type": "string"}
class _FunctionTool(Tool):
    """A Tool implementation that wraps a Python function.

    The tool description is derived from the function's signature and
    docstring; execution forwards the parameters as keyword arguments.
    """

    def __init__(self, func: collections.abc.Callable[..., Any]):
        """Stores the callable to wrap.

        Args:
            func: The function this tool describes and executes.
        """
        self._func = func

    def _resolve_type_hints(self) -> dict[str, Any]:
        """Evaluates the wrapped function's string annotations, best effort.

        Returns:
            The result of `typing.get_type_hints`, or an empty dict when the
            annotations cannot be evaluated.
        """
        try:
            return typing.get_type_hints(self._func)
        except Exception:  # pylint: disable=broad-except
            # Evaluating annotations runs arbitrary expressions (unresolved
            # forward references, invalid syntax, ...). The description must
            # still be produced, so fall back to the raw annotations.
            return {}

    def get_tool_description(self) -> dict[str, Any]:
        """Returns the OpenAPI schema for the function.

        Every named parameter becomes a property whose type comes from its
        annotation and whose description comes from the docstring's Args
        section. Parameters without a default are listed as required.
        `*args` and `**kwargs` are left out because `execute` passes
        arguments by name and can never fill them.

        Returns:
            A dict with ``"type": "function"`` and a ``"function"`` entry
            holding the function name, the first docstring line as the
            description, and the parameters schema.
        """
        sig = inspect.signature(self._func)
        doc = inspect.getdoc(self._func) or ""
        param_descriptions = _parse_param_descriptions(doc)

        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        properties: dict[str, Any] = {}
        required: list[str] = []
        resolved_hints = None  # Computed lazily, only if needed.

        for name, param in sig.parameters.items():
            if param.kind in variadic_kinds:
                continue

            annotation = param.annotation
            if isinstance(annotation, str):
                # Under `from __future__ import annotations`, or with quoted
                # hints, the signature carries source text rather than types.
                if resolved_hints is None:
                    resolved_hints = self._resolve_type_hints()
                annotation = resolved_hints.get(name, annotation)

            schema = _py_type_to_openapi(annotation)
            if name in param_descriptions:
                schema["description"] = param_descriptions[name]
            properties[name] = schema

            if param.default is inspect.Parameter.empty:
                required.append(name)

        return {
            "type": "function",
            "function": {
                "name": self._func.__name__,
                "description": doc.split("\n")[0] if doc else "",
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    def execute(self, param: collections.abc.Mapping[str, Any]) -> Any:
        """Calls the wrapped function with `param` unpacked as keywords.

        Args:
            param: Mapping from parameter name to argument value.

        Returns:
            Whatever the wrapped function returns.
        """
        return self._func(**param)
def tool_from_function(func: collections.abc.Callable[..., Any]) -> Tool:
    """Wraps a plain Python function so it can be used as a Tool.

    Args:
        func: The function to expose as a tool.

    Returns:
        A Tool backed by `func`.
    """
    wrapped: Tool = _FunctionTool(func)
    return wrapped