Spaces:
Running
Running
| # Copyright 2026 The ODML Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Utility functions for litert-lm models.""" | |
| import dataclasses | |
| import glob | |
| import importlib.util | |
| import inspect | |
| import json | |
| import os | |
| import pathlib | |
| import traceback | |
| import click | |
| import prompt_toolkit | |
| from prompt_toolkit import key_binding | |
| import litert_lm | |
# Optional ADB (on-device Android) support. These imports pull in extra
# dependencies that may not be installed; degrade gracefully and let callers
# check _HAS_ADB before using the adb_* modules.
try:
  # pylint: disable=g-import-not-at-top
  from litert_lm.adb import adb_benchmark
  from litert_lm.adb import adb_engine
  _HAS_ADB = True
except ImportError:
  # ADB features are unavailable; on-device code paths must raise/skip.
  _HAS_ADB = False
def load_preset(preset: str):
  """Loads a preset file and returns the tools, messages and extra_context.

  The preset is an arbitrary Python file that may define `tools` (a list of
  callables), `system_instruction` (a string), and `extra_context`. When
  `tools` is absent, every function defined directly in the preset module is
  collected instead.

  Args:
    preset: Path to the preset Python file.

  Returns:
    A (tools, messages, extra_context) tuple; (None, None, None) when the
    file is missing or cannot be loaded.
  """
  click.echo(click.style(f"Loading preset from {preset}:", dim=True))
  if not os.path.exists(preset):
    click.echo(click.style(f"Preset file not found: {preset}", fg="red"))
    return None, None, None

  spec = importlib.util.spec_from_file_location("user_tools", preset)
  if spec is None or spec.loader is None:
    click.echo(click.style(f"Failed to load tools from {preset}", fg="red"))
    return None, None, None
  preset_module = importlib.util.module_from_spec(spec)
  # NOTE: this executes arbitrary user-supplied Python by design.
  spec.loader.exec_module(preset_module)

  tools = getattr(preset_module, "tools", None)
  if tools is None:
    # Fall back to every function defined directly in the preset module
    # (the module was registered under the name "user_tools" above).
    tools = [
        fn
        for _, fn in inspect.getmembers(preset_module, inspect.isfunction)
        if fn.__module__ == "user_tools"
    ]

  messages = None
  system_instruction = getattr(preset_module, "system_instruction", None)
  if system_instruction:
    click.echo(
        click.style(f"- System instruction: {system_instruction}", dim=True)
    )
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": system_instruction}],
    }]

  click.echo(click.style("- Tools:", dim=True))
  for tool in tools:
    tool_label = getattr(tool, "__name__", str(tool))
    click.echo(click.style(f"  - {tool_label}", dim=True))

  extra_context = getattr(preset_module, "extra_context", None)
  if extra_context:
    click.echo(click.style(f"- Extra context: {extra_context}", dim=True))
  return tools, messages, extra_context
class LoggingToolEventHandler(litert_lm.ToolEventHandler):
  """Log tool call and tool response events."""

  def __init__(self, model):
    # Keep a handle on the owning Model so we can reset its channel state.
    self.model = model

  def approve_tool_call(self, tool_call):
    """Logs a tool call; always approves it."""
    # If channel output is mid-line, terminate it before logging.
    if self.model.active_channel is not None:
      click.echo("\n", nl=False)
      self.model.active_channel = None
    call_json = json.dumps(tool_call["function"])
    click.echo(click.style(f"[tool_call] {call_json}", fg="green"))
    return True

  def process_tool_response(self, tool_response):
    """Logs a tool response; returns it unchanged."""
    response_json = json.dumps(tool_response)
    click.echo(click.style(f"[tool_response] {response_json}", fg="green"))
    return tool_response
def _parse_backend(backend: str) -> litert_lm.Backend:
  """Parses the backend string and returns the corresponding Backend enum.

  Args:
    backend: Case-insensitive backend name. "gpu" selects the GPU backend;
      any other value falls back to CPU.

  Returns:
    The matching litert_lm.Backend value.
  """
  return (
      litert_lm.Backend.GPU
      if backend.lower() == "gpu"
      else litert_lm.Backend.CPU
  )
| class Model: | |
| """Represents a LiteRT-LM model. | |
| Attributes: | |
| model_id: The ID of the model. | |
| model_path: The local path to the model file. | |
| active_channel: The name of the currently active channel, or None if default | |
| text is being printed. | |
| """ | |
| model_id: str | |
| model_path: str | |
| active_channel: str | None = None | |
| def exists(self) -> bool: | |
| """Returns True if the model file exists locally.""" | |
| return os.path.isfile(self.model_path) | |
| def to_str(self) -> str: | |
| """Returns a string representation of the model.""" | |
| return self.model_id | |
  def run_interactive(
      self,
      is_android: bool = False,
      backend: str = "cpu",
      preset: str | None = None,
      prompt: str | None = None,
      enable_speculative_decoding: bool | None = None,
      no_template: bool = False,
  ):
    """Runs the model interactively or with a single prompt.

    Args:
      is_android: Whether to run the model on an Android device via ADB.
      backend: The backend to use (cpu or gpu).
      preset: Path to a Python file containing tool functions and system
        instructions.
      prompt: A single prompt to run once and exit.
      enable_speculative_decoding: Whether to enable speculative decoding. If
        None, use the model's default.
      no_template: Interact with the model directly without applying prompt
        templates or stripping stop tokens.
    """
    # The model file must already be available locally; bail out with an
    # error message rather than raising.
    if not self.exists():
      click.echo(
          click.style(
              f"Could not find {self.to_str()} locally in {self.model_path}.",
              fg="red",
          )
      )
      return
    try:
      backend_val = _parse_backend(backend)
      if is_android:
        # On-device path requires the optional litert_lm.adb dependencies.
        if not _HAS_ADB:
          raise ImportError("litert_lm.adb dependencies are not available.")
        engine_cm = adb_engine.AdbEngine(self.model_path, backend=backend_val)
      else:
        engine_cm = litert_lm.Engine(
            self.model_path,
            backend=backend_val,
            enable_speculative_decoding=enable_speculative_decoding,
        )
      with engine_cm as engine:
        if no_template:
          # Raw session: no prompt template applied, stop tokens kept.
          runner_cm = engine.create_session(apply_prompt_template=False)
        else:
          tools = None
          messages = None
          extra_context = None
          if preset:
            tools, messages, extra_context = load_preset(preset)
            # load_preset returns (None, None, None) when loading fails;
            # abort instead of starting a conversation with nothing.
            if tools is None and messages is None and extra_context is None:
              return
          # Only attach the logging handler when the preset provides tools.
          handler = LoggingToolEventHandler(self) if tools else None
          runner_cm = engine.create_conversation(
              tools=tools,
              messages=messages,
              tool_event_handler=handler,
              extra_context=extra_context,
          )
        with runner_cm as runner:
          if prompt:
            # Single-shot mode: run once, then exit.
            if isinstance(runner, litert_lm.AbstractSession):
              self._execute_raw_prompt(runner, prompt)
            elif isinstance(runner, litert_lm.AbstractConversation):
              self._execute_prompt(runner, prompt)
            return
          click.echo(
              click.style(
                  "[enter] submit | [ctrl+j] newline | [ctrl+c] clear/exit",
                  fg="cyan",
              )
          )
          click.echo()
          # Persist prompt history across runs in ~/.litert-lm/history.
          history_path = os.path.join(
              os.path.expanduser("~"), ".litert-lm", "history"
          )
          os.makedirs(os.path.dirname(history_path), exist_ok=True)
          prompt_session = prompt_toolkit.PromptSession(
              history=prompt_toolkit.history.FileHistory(history_path),
              key_bindings=self._create_keybindings(),
          )
          # Interactive REPL: loop until EOF (exit) is signaled.
          while True:
            try:
              user_prompt = prompt_session.prompt(
                  prompt_toolkit.ANSI(click.style("> ", fg="green", bold=True)),
                  multiline=True,
                  # Start the new line in the beginning of line. This makes
                  # copying respecting the text.
                  prompt_continuation=lambda width, line_number, is_soft_wrap: (
                      ""
                  ),
              )
              if not user_prompt:
                continue
              # Dispatch on runner type: raw session vs. conversation.
              if isinstance(runner, litert_lm.AbstractSession):
                self._execute_raw_prompt(
                    runner,
                    user_prompt,
                )
              elif isinstance(runner, litert_lm.AbstractConversation):
                self._execute_prompt(
                    runner,
                    user_prompt,
                )
            except EOFError:
              # Raised by the ctrl+c-on-empty-buffer binding: exit the loop.
              break
            except KeyboardInterrupt:
              # Catch Ctrl+C at the input prompt
              click.echo()
              continue
            except Exception:  # pylint: disable=broad-exception-caught
              # Keep the REPL alive after an inference failure.
              click.echo(click.style("Error during inference", fg="red"))
              traceback.print_exc()
    except Exception:  # pylint: disable=broad-exception-caught
      # Top-level guard: report setup/engine failures without a stack crash.
      click.echo(click.style("An error occurred", fg="red"))
      traceback.print_exc()
| def _execute_prompt( | |
| self, conversation: litert_lm.AbstractConversation, prompt: str | |
| ): | |
| """Executes a single prompt and prints the result.""" | |
| self.active_channel = None | |
| stream = conversation.send_message_async(prompt) | |
| try: | |
| for chunk in stream: | |
| # Handle regular content | |
| content_list = chunk.get("content", []) | |
| for item in content_list: | |
| if item.get("type") == "text": | |
| if self.active_channel is not None: | |
| click.echo() | |
| self.active_channel = None | |
| click.echo(click.style(item.get("text", ""), fg="yellow"), nl=False) | |
| # Handle channels | |
| channels = chunk.get("channels", {}) | |
| for channel_name, channel_content in channels.items(): | |
| if self.active_channel != channel_name: | |
| if self.active_channel is not None: | |
| click.echo() | |
| click.echo(click.style(f"[{channel_name}] ", fg="blue"), nl=False) | |
| self.active_channel = channel_name | |
| click.echo(click.style(channel_content, fg="yellow"), nl=False) | |
| if self.active_channel is not None: | |
| click.echo() | |
| else: | |
| click.echo() | |
| except KeyboardInterrupt: | |
| conversation.cancel_process() | |
| # Empty the iterator queue. | |
| # This ensures we don't throw away StopIteration. | |
| for _ in stream: | |
| pass | |
| click.echo(click.style("\n[Generation cancelled]", dim=True)) | |
| def _execute_raw_prompt( | |
| self, session: litert_lm.AbstractSession, prompt: str | |
| ): | |
| """Executes a single raw prompt and prints the result.""" | |
| session.run_prefill([prompt]) | |
| stream = session.run_decode_async() | |
| try: | |
| for chunk in stream: | |
| if chunk.texts: | |
| click.echo(click.style(chunk.texts[0], fg="yellow"), nl=False) | |
| click.echo() | |
| except KeyboardInterrupt: | |
| # Empty the iterator queue. | |
| for _ in stream: | |
| pass | |
| click.echo(click.style("\n[Generation cancelled]", dim=True)) | |
| def _create_keybindings(self) -> key_binding.KeyBindings: | |
| """Creates keybindings for the interactive prompt.""" | |
| kb = key_binding.KeyBindings() | |
| # Key binding for sending the prompt. | |
| def _handle_enter(event): | |
| buffer = event.current_buffer | |
| if buffer.text.strip(): | |
| buffer.validate_and_handle() | |
| # Key binding for new line. Note that terminal cannot take | |
| # "shift+enter", and "ctrl+enter" | |
| # standard terminal convention. | |
| # alt+enter and esc+enter | |
| def _handle_newline(event): | |
| event.current_buffer.insert_text("\n") | |
| # Key binding for clearing input or exiting. | |
| def _handle_clear_or_exit(event): | |
| buffer = event.current_buffer | |
| if buffer.text: | |
| buffer.text = "" | |
| else: | |
| event.app.exit(exception=EOFError) | |
| return kb | |
| def benchmark( | |
| self, | |
| prefill_tokens: int = 256, | |
| decode_tokens: int = 256, | |
| is_android: bool = False, | |
| backend: str = "cpu", | |
| enable_speculative_decoding: bool | None = None, | |
| ): | |
| """Benchmarks the model. | |
| Args: | |
| prefill_tokens: The number of tokens to prefill. | |
| decode_tokens: The number of tokens to decode. | |
| is_android: Whether to run the benchmark on an Android device via ADB. | |
| backend: The backend to use (cpu or gpu). | |
| enable_speculative_decoding: Whether to enable speculative decoding. If | |
| None, use the model's default. | |
| """ | |
| if not self.exists(): | |
| click.echo( | |
| click.style( | |
| f"Could not find {self.to_str()} locally in {self.model_path}.", | |
| fg="red", | |
| ) | |
| ) | |
| return | |
| try: | |
| backend_val = _parse_backend(backend) | |
| if is_android: | |
| if not _HAS_ADB: | |
| raise ImportError("litert_lm.adb dependencies are not available.") | |
| benchmark_obj = adb_benchmark.AdbBenchmark( | |
| self.model_path, | |
| backend=backend_val, | |
| prefill_tokens=prefill_tokens, | |
| decode_tokens=decode_tokens, | |
| cache_dir=":nocache", | |
| ) | |
| else: | |
| benchmark_obj = litert_lm.Benchmark( | |
| self.model_path, | |
| backend=backend_val, | |
| prefill_tokens=prefill_tokens, | |
| decode_tokens=decode_tokens, | |
| cache_dir=":nocache", | |
| enable_speculative_decoding=enable_speculative_decoding, | |
| ) | |
| click.echo(f"Benchmarking model: {self.to_str()} ({self.model_path})") | |
| click.echo(f"Number of tokens in prefill: {prefill_tokens}") | |
| click.echo(f"Number of tokens in decode : {decode_tokens}") | |
| click.echo(f"Backend : {backend}") | |
| spec_dec_str = "auto" | |
| if enable_speculative_decoding is True: | |
| spec_dec_str = "true" | |
| elif enable_speculative_decoding is False: | |
| spec_dec_str = "false" | |
| print(f"Speculative decoding : {spec_dec_str}") | |
| if is_android: | |
| click.echo("Target : Android") | |
| result = benchmark_obj.run() | |
| click.echo("----- Results -----") | |
| click.echo( | |
| f"Prefill speed: {result.last_prefill_tokens_per_second:.2f}" | |
| " tokens/s" | |
| ) | |
| click.echo( | |
| f"Decode speed: {result.last_decode_tokens_per_second:.2f}" | |
| " tokens/s" | |
| ) | |
| click.echo(f"Init time: {result.init_time_in_second:.4f} s") | |
| click.echo( | |
| f"Time to first token: {result.time_to_first_token_in_second:.4f} s" | |
| ) | |
| except Exception: # pylint: disable=broad-exception-caught | |
| click.echo(click.style("An error occurred during benchmarking", fg="red")) | |
| traceback.print_exc() | |
| def get_all_models(cls): | |
| """Returns a list of all locally available models.""" | |
| model_paths = glob.glob( | |
| "*/model.litertlm", | |
| root_dir=get_converted_models_base_dir(), | |
| recursive=True, | |
| ) | |
| return [ | |
| Model.from_model_id( | |
| path.removesuffix("/model.litertlm").replace("--", "/") | |
| ) | |
| for path in model_paths | |
| ] | |
| def from_model_reference(cls, model_reference): | |
| """Creates a Model instance from a model reference.""" | |
| if os.path.exists(model_reference): | |
| return cls.from_model_path(model_reference) | |
| else: | |
| # assume the reference is model_id | |
| return cls.from_model_id(model_reference) | |
| def from_model_path(cls, model_path): | |
| """Creates a Model instance from a model path.""" | |
| return cls( | |
| model_id=os.path.basename(model_path), | |
| model_path=os.path.abspath(model_path), | |
| ) | |
| def from_model_id(cls, model_id): | |
| """Creates a Model instance from a model ID.""" | |
| return cls( | |
| model_id=model_id, | |
| model_path=os.path.join( | |
| get_converted_models_base_dir(), | |
| model_id.replace("/", "--"), | |
| "model.litertlm", | |
| ), | |
| ) | |
| # Just to use the huggingface convention. Likely to change. | |
def model_id_dir_name(model_id):
  """Converts a model ID to a directory name.

  Model IDs may contain "/" (Hugging Face convention), which is not valid in
  a single path component, so each "/" is mapped to "--".
  """
  return "--".join(model_id.split("/"))
| # ~/.litert-lm/models | |
def get_converted_models_base_dir():
  """Gets the base directory for all converted models (~/.litert-lm/models)."""
  home_dir = os.path.expanduser("~")
  return os.path.join(home_dir, ".litert-lm", "models")
| # ~/.litert-lm/models/<model_id> | |
def get_model_dir(model_id):
  """Gets the model directory for a given model ID.

  Equivalent to ~/.litert-lm/models/<id-dir>.
  """
  base_dir = get_converted_models_base_dir()
  return os.path.join(base_dir, model_id_dir_name(model_id))