File size: 11,584 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
# Copyright 2026 The ODML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib

from absl import flags
from absl.testing import absltest

import litert_lm

FLAGS = flags.FLAGS


class LiteRtLmTestBase(absltest.TestCase):
  """Shared fixture for LiteRT-LM Python API tests.

  Provides the path to the bundled test model, a factory for CPU engines and
  a helper that pulls text pieces out of streamed message chunks.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Maximize native logging so failures in the runtime are easy to diagnose.
    litert_lm.set_min_log_severity(litert_lm.LogSeverity.VERBOSE)

  def setUp(self):
    super().setUp()
    srcdir = pathlib.Path(FLAGS.test_srcdir)
    model_file = srcdir / "litert_lm/runtime/testdata/test_lm.litertlm"
    self.model_path = str(model_file)

  def _create_engine(self, max_num_tokens=10):
    """Builds a CPU engine for the test model.

    Args:
      max_num_tokens: Token budget forwarded to `litert_lm.Engine`.

    Returns:
      A new `litert_lm.Engine`, created with `cache_dir=":nocache"`.
    """
    engine_options = dict(max_num_tokens=max_num_tokens, cache_dir=":nocache")
    return litert_lm.Engine(
        self.model_path, litert_lm.Backend.CPU, **engine_options
    )

  @staticmethod
  def _extract_text(stream):
    """Collects the text pieces carried by a stream of message chunks.

    Args:
      stream: Iterable of dict chunks. Each chunk may hold a "content" list
        of dict items.

    Returns:
      List of the "text" values (or "" when absent) of every item whose
      "type" is "text", in stream order.
    """
    return [
        piece.get("text", "")
        for message_chunk in stream
        for piece in message_chunk.get("content", [])
        if piece.get("type") == "text"
    ]


class EngineTest(LiteRtLmTestBase):
  """Tests for the Engine, Conversation, Session and Benchmark APIs.

  Several tests pin the exact output the bundled test model produces for the
  "Hello world!" prompt. If the model or its decoding setup changes, update
  _EXPECTED_RESPONSE and _EXPECTED_NUM_CHUNKS together.
  """

  # Exact text the test model produces for the "Hello world!" prompt.
  _EXPECTED_RESPONSE = "TarefaByte دارایेत्र investigaciónప్రదేశ"
  # Number of streamed pieces a complete, uncancelled decode of
  # _EXPECTED_RESPONSE yields, both from Conversation.send_message_async (text
  # pieces) and from Session.run_decode_async (response objects).
  # Cancellation tests assert strictly fewer than this.
  _EXPECTED_NUM_CHUNKS = 6

  def test_conversation_send_message(self):
    """send_message returns the full assistant reply as one message."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsNotNone(engine)
      self.assertIsNotNone(conversation)
      user_message = {"role": "user", "content": "Hello world!"}
      message = conversation.send_message(user_message)

      expected_message = {
          "role": "assistant",
          "content": [{"type": "text", "text": self._EXPECTED_RESPONSE}],
      }
      self.assertEqual(message, expected_message)

  def test_conversation_send_message_async(self):
    """send_message_async streams the same reply as several text pieces."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsNotNone(engine)
      self.assertIsNotNone(conversation)
      user_message = {"role": "user", "content": "Hello world!"}
      stream = conversation.send_message_async(user_message)
      text_pieces = self._extract_text(stream)

      self.assertEqual("".join(text_pieces), self._EXPECTED_RESPONSE)
      self.assertLen(text_pieces, self._EXPECTED_NUM_CHUNKS)

  def test_conversation_send_message_async_cancel(self):
    """cancel_process stops a send_message_async stream before completion."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      user_message = {"role": "user", "content": "Hello world!"}
      stream = conversation.send_message_async(user_message)

      text_pieces = []
      for chunk in stream:
        # Reuse the shared helper on just this chunk so the text-extraction
        # rules stay identical to the non-cancelling tests.
        text_pieces.extend(self._extract_text([chunk]))
        # Request cancellation as soon as the first chunk arrives. The call is
        # re-issued on every later chunk that still arrives.
        conversation.cancel_process()

      # Cancellation must take effect early: at least one piece arrives, but
      # fewer than a complete decode would produce.
      self.assertNotEmpty(text_pieces)
      self.assertLess(len(text_pieces), self._EXPECTED_NUM_CHUNKS)

  def test_benchmark_class(self):
    """Benchmark.run reports positive timings and token counts."""
    benchmark = litert_lm.Benchmark(
        self.model_path,
        litert_lm.Backend.CPU,
        prefill_tokens=10,
        decode_tokens=10,
        cache_dir=":nocache",
    )
    self.assertIsInstance(benchmark, litert_lm.AbstractBenchmark)
    result = benchmark.run()
    self.assertIsInstance(result, litert_lm.BenchmarkInfo)
    self.assertGreater(result.init_time_in_second, 0)
    self.assertGreater(result.time_to_first_token_in_second, 0)
    self.assertGreater(result.last_prefill_token_count, 0)
    self.assertGreater(result.last_prefill_tokens_per_second, 0)
    self.assertGreater(result.last_decode_token_count, 0)
    self.assertGreater(result.last_decode_tokens_per_second, 0)

  def test_engine_abc_inheritance(self):
    """Engine implements the AbstractEngine interface."""
    with self._create_engine() as engine:
      self.assertIsInstance(engine, litert_lm.AbstractEngine)

  def test_engine_tokenization_api(self):
    """tokenize yields int ids and detokenize maps them back to a string."""
    with self._create_engine() as engine:
      token_ids = engine.tokenize("Hello world!")
      self.assertNotEmpty(token_ids)
      self.assertTrue(all(isinstance(token_id, int) for token_id in token_ids))

      decoded = engine.detokenize(token_ids)
      self.assertIsInstance(decoded, str)
      self.assertNotEmpty(decoded)

  def test_engine_special_token_metadata(self):
    """bos_token_id is an int or None; eos_token_ids is a list of int lists."""
    with self._create_engine() as engine:
      bos_token_id = engine.bos_token_id
      if bos_token_id is not None:
        self.assertIsInstance(bos_token_id, int)

      eos_token_ids = engine.eos_token_ids
      self.assertIsInstance(eos_token_ids, list)
      for stop_token_ids in eos_token_ids:
        self.assertIsInstance(stop_token_ids, list)
        self.assertTrue(
            all(isinstance(token_id, int) for token_id in stop_token_ids)
        )

  def test_conversation_abc_inheritance(self):
    """Conversation implements the AbstractConversation interface."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      self.assertIsInstance(conversation, litert_lm.AbstractConversation)

  def test_create_conversation_with_messages(self):
    """Initial messages passed to create_conversation are exposed as-is."""
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    with (
        self._create_engine() as engine,
        engine.create_conversation(messages=messages) as conversation,
    ):
      self.assertEqual(conversation.messages, messages)

  def test_create_conversation_with_extra_context(self):
    """extra_context passed to create_conversation is exposed as-is."""
    extra_context = {"key": "value"}
    with (
        self._create_engine() as engine,
        engine.create_conversation(extra_context=extra_context) as conversation,
    ):
      self.assertEqual(conversation.extra_context, extra_context)

  def test_str_input_support(self):
    """send_message accepts a plain string in place of a message dict."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      # Test with str input
      message = conversation.send_message("Hello world!")
      self.assertEqual(message["role"], "assistant")

  def test_str_input_support_async(self):
    """send_message_async accepts a plain string in place of a message dict."""
    with (
        self._create_engine() as engine,
        engine.create_conversation() as conversation,
    ):
      # Test with str input (async)
      stream = conversation.send_message_async("Hello world!")
      text_pieces = self._extract_text(stream)
      self.assertNotEmpty(text_pieces)

  def test_tool_event_handler_storage(self):
    """A ToolEventHandler passed to create_conversation is stored as-is."""

    class MyHandler(litert_lm.ToolEventHandler):

      def approve_tool_call(self, tool_call):
        return True

      def process_tool_response(self, tool_response):
        return tool_response

    handler = MyHandler()
    with (
        self._create_engine() as engine,
        engine.create_conversation(tool_event_handler=handler) as conversation,
    ):
      self.assertEqual(conversation.tool_event_handler, handler)

  def test_create_session_with_apply_prompt_template(self):
    """Sessions can be created with and without the prompt template."""
    with self._create_engine() as engine:
      with engine.create_session(apply_prompt_template=True) as session:
        self.assertIsInstance(session, litert_lm.AbstractSession)
      with engine.create_session(apply_prompt_template=False) as session:
        self.assertIsInstance(session, litert_lm.AbstractSession)

  def test_session_api_run_decode(self):
    """run_decode after prefill returns the expected single response."""
    with (
        self._create_engine() as engine,
        engine.create_session() as session,
    ):
      self.assertIsInstance(session, litert_lm.AbstractSession)
      session.run_prefill(["Hello", " world!"])
      responses = session.run_decode()
      self.assertIsInstance(responses, litert_lm.Responses)
      self.assertLen(responses.texts, 1)
      self.assertEqual(responses.texts, [self._EXPECTED_RESPONSE])
      self.assertLen(responses.scores, 1)
      self.assertEmpty(responses.token_lengths)

  def test_session_api_run_text_scoring_with_token_lengths(self):
    """run_text_scoring reports token lengths when asked to store them."""
    with (
        self._create_engine() as engine,
        engine.create_session() as session,
    ):
      self.assertIsInstance(session, litert_lm.AbstractSession)
      session.run_prefill(["Hello", " world!"])
      scoring_responses = session.run_text_scoring(
          ["Hello"], store_token_lengths=True
      )
      self.assertIsInstance(scoring_responses, litert_lm.Responses)
      self.assertEmpty(scoring_responses.texts)
      self.assertLen(scoring_responses.scores, 1)
      self.assertLen(scoring_responses.token_lengths, 1)

  def test_session_api_run_text_scoring_no_token_lengths(self):
    """run_text_scoring omits token lengths when not asked to store them."""
    with (
        self._create_engine() as engine,
        engine.create_session() as session,
    ):
      self.assertIsInstance(session, litert_lm.AbstractSession)
      session.run_prefill(["Hello", " world!"])
      scoring_responses = session.run_text_scoring(
          ["Hello"], store_token_lengths=False
      )
      self.assertIsInstance(scoring_responses, litert_lm.Responses)
      self.assertEmpty(scoring_responses.texts)
      self.assertLen(scoring_responses.scores, 1)
      self.assertEmpty(scoring_responses.token_lengths)

  def test_session_api_run_decode_async(self):
    """run_decode_async streams responses that concatenate to the reply."""
    with (
        self._create_engine() as engine,
        engine.create_session() as session,
    ):
      self.assertIsInstance(session, litert_lm.AbstractSession)
      session.run_prefill(["Hello", " world!"])
      stream = session.run_decode_async()
      responses = list(stream)
      self.assertNotEmpty(responses)
      self.assertLen(responses, self._EXPECTED_NUM_CHUNKS)
      full_text = "".join("".join(r.texts) for r in responses)
      self.assertEqual(full_text, self._EXPECTED_RESPONSE)

  def test_session_api_cancel_process(self):
    """cancel_process stops a run_decode_async stream before completion."""
    with (
        self._create_engine() as engine,
        engine.create_session() as session,
    ):
      self.assertIsInstance(session, litert_lm.AbstractSession)
      session.run_prefill(["Hello world!"])
      stream = session.run_decode_async()

      responses = []
      for response in stream:
        responses.append(response)
        # Request cancellation as soon as the first response arrives. The call
        # is re-issued on every later response that still arrives.
        session.cancel_process()

      self.assertNotEmpty(responses)
      # A cancelled decode must yield fewer responses than a complete one.
      self.assertLess(len(responses), self._EXPECTED_NUM_CHUNKS)


class FunctionCallingTest(LiteRtLmTestBase):
  """Tests for passing plain Python callables as conversation tools."""

  @staticmethod
  def _weather_tools():
    """Returns a new single-element tool list holding a `get_weather` stub."""

    def get_weather(location: str):
      """Gets weather for a location."""
      return f"Weather in {location} is sunny."

    return [get_weather]

  def test_create_conversation_with_tools(self):
    tool_list = self._weather_tools()
    with (
        self._create_engine() as engine,
        engine.create_conversation(tools=tool_list) as conversation,
    ):
      # The conversation should expose exactly the tools it was given.
      self.assertEqual(conversation.tools, tool_list)

  def test_send_message_async_with_tools(self):
    tool_list = self._weather_tools()
    with (
        self._create_engine() as engine,
        engine.create_conversation(tools=tool_list) as conversation,
    ):
      prompt = {
          "role": "user",
          "content": "What's the weather in London?",
      }
      pieces = self._extract_text(conversation.send_message_async(prompt))
      self.assertNotEmpty(pieces)


if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to absltest
  # to run the test classes defined above.
  absltest.main()