# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
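"""Tests for the litert_lm Python API.

Covers Engine, Conversation, Session, Benchmark, and function-calling support.
"""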
import pathlib

from absl import flags
from absl.testing import absltest

import litert_lm

FLAGS = flags.FLAGS


class LiteRtLmTestBase(absltest.TestCase):
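  """Shared fixture: enables verbose logging and resolves the test model path."""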
@classmethod
def setUpClass(cls):
super().setUpClass()
litert_lm.set_min_log_severity(litert_lm.LogSeverity.VERBOSE)
def setUp(self):
super().setUp()
self.model_path = str(
pathlib.Path(FLAGS.test_srcdir)
/ "litert_lm/runtime/testdata/test_lm.litertlm"
)
def _create_engine(self, max_num_tokens=10):
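    """Creates a CPU-backend Engine for the test model with a small token budget."""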
return litert_lm.Engine(
self.model_path,
litert_lm.Backend.CPU,
max_num_tokens=max_num_tokens,
cache_dir=":nocache",
)
@staticmethod
def _extract_text(stream):
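    """Collects the text pieces from a stream of response chunks."""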
text_pieces = []
for chunk in stream:
content_list = chunk.get("content", [])
for item in content_list:
if item.get("type") == "text":
text_pieces.append(item.get("text", ""))
return text_pieces


class EngineTest(LiteRtLmTestBase):
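  """End-to-end tests for the Engine, Conversation, and Session APIs on CPU."""

  # Deterministic output of the bundled test model; the mixed-script text is
  # the expected model output, not an encoding artifact.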
_EXPECTED_RESPONSE = "TarefaByte دارایेत्र investigaciónప్రదేశ"
def test_conversation_send_message(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
self.assertIsNotNone(engine)
self.assertIsNotNone(conversation)
user_message = {"role": "user", "content": "Hello world!"}
message = conversation.send_message(user_message)
expected_message = {
"role": "assistant",
"content": [{"type": "text", "text": self._EXPECTED_RESPONSE}],
}
self.assertEqual(message, expected_message)
def test_conversation_send_message_async(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
self.assertIsNotNone(engine)
self.assertIsNotNone(conversation)
user_message = {"role": "user", "content": "Hello world!"}
stream = conversation.send_message_async(user_message)
text_pieces = self._extract_text(stream)
self.assertEqual("".join(text_pieces), self._EXPECTED_RESPONSE)
self.assertLen(text_pieces, 6)
def test_conversation_send_message_async_cancel(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
user_message = {"role": "user", "content": "Hello world!"}
stream = conversation.send_message_async(user_message)
text_pieces = []
for chunk in stream:
content_list = chunk.get("content", [])
for item in content_list:
if item.get("type") == "text":
text_pieces.append(item.get("text", ""))
# Cancel the process after receiving the first chunk.
conversation.cancel_process()
      # Cancellation is requested after the first chunk, so we expect fewer
      # pieces than a full decode produces.
self.assertNotEmpty(text_pieces)
self.assertLess(len(text_pieces), 6) # Cancelled before completion
def test_benchmark_class(self):
benchmark = litert_lm.Benchmark(
self.model_path,
litert_lm.Backend.CPU,
prefill_tokens=10,
decode_tokens=10,
cache_dir=":nocache",
)
self.assertIsInstance(benchmark, litert_lm.AbstractBenchmark)
result = benchmark.run()
self.assertIsInstance(result, litert_lm.BenchmarkInfo)
self.assertGreater(result.init_time_in_second, 0)
self.assertGreater(result.time_to_first_token_in_second, 0)
self.assertGreater(result.last_prefill_token_count, 0)
self.assertGreater(result.last_prefill_tokens_per_second, 0)
self.assertGreater(result.last_decode_token_count, 0)
self.assertGreater(result.last_decode_tokens_per_second, 0)
def test_engine_abc_inheritance(self):
with self._create_engine() as engine:
self.assertIsInstance(engine, litert_lm.AbstractEngine)
def test_engine_tokenization_api(self):
with self._create_engine() as engine:
token_ids = engine.tokenize("Hello world!")
self.assertNotEmpty(token_ids)
self.assertTrue(all(isinstance(token_id, int) for token_id in token_ids))
decoded = engine.detokenize(token_ids)
self.assertIsInstance(decoded, str)
self.assertNotEmpty(decoded)
def test_engine_special_token_metadata(self):
with self._create_engine() as engine:
bos_token_id = engine.bos_token_id
if bos_token_id is not None:
self.assertIsInstance(bos_token_id, int)
eos_token_ids = engine.eos_token_ids
self.assertIsInstance(eos_token_ids, list)
for stop_token_ids in eos_token_ids:
self.assertIsInstance(stop_token_ids, list)
self.assertTrue(
all(isinstance(token_id, int) for token_id in stop_token_ids)
)
def test_conversation_abc_inheritance(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
self.assertIsInstance(conversation, litert_lm.AbstractConversation)
def test_create_conversation_with_messages(self):
messages = [{"role": "system", "content": "You are a helpful assistant."}]
with (
self._create_engine() as engine,
engine.create_conversation(messages=messages) as conversation,
):
self.assertEqual(conversation.messages, messages)
def test_create_conversation_with_extra_context(self):
extra_context = {"key": "value"}
with (
self._create_engine() as engine,
engine.create_conversation(extra_context=extra_context) as conversation,
):
self.assertEqual(conversation.extra_context, extra_context)
def test_str_input_support(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
# Test with str input
message = conversation.send_message("Hello world!")
self.assertEqual(message["role"], "assistant")
def test_str_input_support_async(self):
with (
self._create_engine() as engine,
engine.create_conversation() as conversation,
):
# Test with str input (async)
stream = conversation.send_message_async("Hello world!")
text_pieces = self._extract_text(stream)
self.assertNotEmpty(text_pieces)
def test_tool_event_handler_storage(self):
class MyHandler(litert_lm.ToolEventHandler):
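      """Trivial handler: approves every tool call and passes responses through."""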
def approve_tool_call(self, tool_call):
return True
def process_tool_response(self, tool_response):
return tool_response
handler = MyHandler()
with (
self._create_engine() as engine,
engine.create_conversation(tool_event_handler=handler) as conversation,
):
self.assertEqual(conversation.tool_event_handler, handler)
def test_create_session_with_apply_prompt_template(self):
with self._create_engine() as engine:
with engine.create_session(apply_prompt_template=True) as session:
self.assertIsInstance(session, litert_lm.AbstractSession)
with engine.create_session(apply_prompt_template=False) as session:
self.assertIsInstance(session, litert_lm.AbstractSession)
def test_session_api_run_decode(self):
with (
self._create_engine() as engine,
engine.create_session() as session,
):
self.assertIsInstance(session, litert_lm.AbstractSession)
session.run_prefill(["Hello", " world!"])
responses = session.run_decode()
self.assertIsInstance(responses, litert_lm.Responses)
self.assertLen(responses.texts, 1)
self.assertEqual(responses.texts, [self._EXPECTED_RESPONSE])
self.assertLen(responses.scores, 1)
self.assertEmpty(responses.token_lengths)
def test_session_api_run_text_scoring_with_token_lengths(self):
with (
self._create_engine() as engine,
engine.create_session() as session,
):
self.assertIsInstance(session, litert_lm.AbstractSession)
session.run_prefill(["Hello", " world!"])
scoring_responses = session.run_text_scoring(
["Hello"], store_token_lengths=True
)
self.assertIsInstance(scoring_responses, litert_lm.Responses)
self.assertEmpty(scoring_responses.texts)
self.assertLen(scoring_responses.scores, 1)
self.assertLen(scoring_responses.token_lengths, 1)
def test_session_api_run_text_scoring_no_token_lengths(self):
with (
self._create_engine() as engine,
engine.create_session() as session,
):
self.assertIsInstance(session, litert_lm.AbstractSession)
session.run_prefill(["Hello", " world!"])
scoring_responses = session.run_text_scoring(
["Hello"], store_token_lengths=False
)
self.assertIsInstance(scoring_responses, litert_lm.Responses)
self.assertEmpty(scoring_responses.texts)
self.assertLen(scoring_responses.scores, 1)
self.assertEmpty(scoring_responses.token_lengths)
def test_session_api_run_decode_async(self):
with (
self._create_engine() as engine,
engine.create_session() as session,
):
self.assertIsInstance(session, litert_lm.AbstractSession)
session.run_prefill(["Hello", " world!"])
stream = session.run_decode_async()
responses = list(stream)
self.assertNotEmpty(responses)
self.assertLen(responses, 6)
full_text = "".join(["".join(r.texts) for r in responses])
self.assertEqual(full_text, self._EXPECTED_RESPONSE)
def test_session_api_cancel_process(self):
with (
self._create_engine() as engine,
engine.create_session() as session,
):
self.assertIsInstance(session, litert_lm.AbstractSession)
session.run_prefill(["Hello world!"])
stream = session.run_decode_async()
responses = []
for response in stream:
responses.append(response)
session.cancel_process()
self.assertNotEmpty(responses)
# We expect fewer responses than a full decode (which is 6 chunks).
self.assertLess(len(responses), 6)


class FunctionCallingTest(LiteRtLmTestBase):
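  """Tests for conversations that register plain Python callables as tools."""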
def test_create_conversation_with_tools(self):
def get_weather(location: str):
"""Gets weather for a location."""
return f"Weather in {location} is sunny."
tools = [get_weather]
with (
self._create_engine() as engine,
engine.create_conversation(tools=tools) as conversation,
):
self.assertEqual(conversation.tools, tools)
def test_send_message_async_with_tools(self):
def get_weather(location: str):
"""Gets weather for a location."""
return f"Weather in {location} is sunny."
tools = [get_weather]
with (
self._create_engine() as engine,
engine.create_conversation(tools=tools) as conversation,
):
user_message = {
"role": "user",
"content": "What's the weather in London?",
}
stream = conversation.send_message_async(user_message)
text_pieces = self._extract_text(stream)
self.assertNotEmpty(text_pieces)


if __name__ == "__main__":
absltest.main()