File size: 14,218 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_

#include <memory>
#include <vector>

#include "absl/functional/any_invocable.h"  // from @com_google_absl
#include "absl/log/absl_log.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/time/time.h"  // from @com_google_absl
#include "runtime/components/tokenizer.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"

namespace litert::lm {

// Engine is the interface for the LLM runtime. It is responsible for
// - Initializing the LLM model and related resources, e.g. tokenizer,
//   embedder, etc.
// - Providing the APIs to create the Session.
//
// The Session is responsible for hosting the internal state (e.g. conversation
// history) of each separate interaction with LLM. It is created by the Engine
// and is responsible for:
// - Generating content from the input prompt/query.
// - Running the prefill and decode processes.
//
// Example usage:
//   // Create the model assets.
//   auto model_assets = ModelAssets::Create(model_path);
//   CHECK_OK(model_assets);
//
//   // Create the engine.
//   auto engine = Engine::CreateEngine(EngineSettings::CreateDefault(
//       model_assets, litert::lm::Backend::CPU));
//   CHECK_OK(engine);
//
//   // Create the session.
//   auto session = engine->CreateSession(SessionConfig::CreateDefault());
//   CHECK_OK(session);
//
//   // Run generate content.
//   auto responses = (*session)->GenerateContent({InputText("What's the tallest
//   building in the world?")});
//   CHECK_OK(responses);
//
//   // Print the response.
//   std::cout << *responses << std::endl;
class Engine {
 public:
  virtual ~Engine() = default;

  // Session is responsible for hosting the internal state (e.g. conversation
  // history) of each separate interaction with LLM.
  class Session {
   public:
    // The TaskController is responsible for controlling the async task
    // execution.
    class TaskController {
     public:
      TaskController() = default;

      // The TaskController is not copyable. This is to prevent the user from
      // accidentally copying the TaskController and calling the
      // CancelProcess function multiple times.
      TaskController(const TaskController&) = delete;
      TaskController& operator=(const TaskController&) = delete;

      // The TaskController is movable.
      TaskController(TaskController&&) = default;
      TaskController& operator=(TaskController&&) = default;

      // The TaskController destructor.
      virtual ~TaskController() = default;

      // Waits until all the tasks are done or the timeout is reached. The
      // function will return error if the timeout is reached.
      virtual absl::Status WaitUntilDone(absl::Duration timeout) {
        return absl::UnimplementedError("Not implemented.");
      }

      // Cancels the ongoing inference process. Note that if this function is
      // called after the inference process is done, the function will be a
      // no-op.
      virtual absl::Status Cancel() {
        return absl::UnimplementedError("Not implemented.");
      }
    };

    virtual ~Session() = default;

    // High-level API to generate content from the input prompt/query. This
    // function will handle the prefill and decode processes internally and
    // the usage is similar to the Gemini Text Generation API
    // (https://ai.google.dev/gemini-api/docs/text-generation).
    // - contents: The input data for generation.
    virtual absl::StatusOr<Responses> GenerateContent(
        const std::vector<InputData>& contents) = 0;

    // This is a non-blocking call and the function will return right away. The
    // result will be streamed through the callback.
    //
    // - contents: The input data for generation.
    // - callback: Callback to receive streamed results.
    //   Note:
    //     - If the generation is done successfully, the callback will be
    //       called with empty responses to signal the completion.
    //     - If there is an error during the streaming process, the callback
    //       will be called with the error status and no further results will be
    //       sent.
    //     - If the generation is cancelled, the callback will be called
    //       with a Cancellation error.
    virtual absl::Status GenerateContentStream(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) = 0;

    // Same as above, but with a custom decode config.
    // - decode_config: configuration for the model decode process.
    virtual absl::Status GenerateContentStream(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        const DecodeConfig& decode_config) = 0;

    // Scores the target text after the prefill process is done. This function
    // will only run the decode process to fetch the decode output logits, which
    // is used to calculate the target text's score and update the model memory
    // using the target_text tokens.
    // This function should be called after the prefill process is done.
    // - target_text: The target text to score.
    // - store_token_lengths: Whether to store the token lengths of the target
    //   texts in `Responses`.
    // - returns: This function returns the score associated with the target
    // text after the model has been prefilled. The returned score is the sum of
    // the negative log probability of seeing the target text during decode.
    virtual absl::StatusOr<Responses> RunTextScoring(
        const std::vector<absl::string_view>& target_text,
        bool store_token_lengths) = 0;

    // Similar to the above RunTextScoring function, but this is a non-blocking
    // call and the function will return right away. The processing status will
    // be signaled through the callback.
    // - target_text: The target text to score.
    // - callback: Callback to receive the scoring results.
    // - store_token_lengths: Whether to store the token lengths of the target
    //   texts in `Responses`.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunTextScoringAsync(
        const std::vector<absl::string_view>& target_text,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        bool store_token_lengths) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Adds the input prompt/query to the model for starting the prefilling
    // process. Note that the user can break down their prompt/query into
    // multiple chunks and call this function multiple times.
    //
    // This is a blocking call and the function will return when the prefill
    // process is done.
    virtual absl::Status RunPrefill(const std::vector<InputData>& contents) = 0;

    // This is a non-blocking call and the function will return right away. The
    // processing status will be signaled through the callback.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunPrefillAsync(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Starts the decoding process for the model to predict the response based
    // on the input prompt/query added after using RunPrefill* functions.
    // This is a blocking call and the function will return when the decoding
    // process is done.
    virtual absl::StatusOr<Responses> RunDecode() = 0;

    // Same as above, but with a custom decode config.
    // - decode_config: configuration for the model decode process.
    virtual absl::StatusOr<Responses> RunDecode(
        const DecodeConfig& decode_config) = 0;

    // Starts the decoding process for the model to predict the response based
    // on the input prompt/query added after using RunPrefill* functions.
    // This is a non-blocking call and the function will return right away. The
    // result will be streamed through the callback.
    // - callback: Callback to receive streamed results.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunDecodeAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Same as above, but with a custom decode config.
    // - decode_config: configuration for the model decode process.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunDecodeAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        const DecodeConfig& decode_config) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Returns the benchmark info for the session. Returns error if the
    // benchmark is not enabled.
    virtual absl::StatusOr<BenchmarkInfo> GetBenchmarkInfo() = 0;

    // Returns the mutable benchmark info for the session. Returns error if the
    // benchmark is not enabled.
    virtual absl::StatusOr<BenchmarkInfo*> GetMutableBenchmarkInfo() = 0;

    // Cancels the ongoing inference process. Note that if this function is
    // called, the inference process will return with a kCancelled error. The
    // session can still be used afterwards.
    virtual void CancelProcess() {
      ABSL_LOG(FATAL) << "CancelProcess is not implemented.";
    }

    // Waits until all the tasks are done or the default timeout is reached.
    virtual absl::Status WaitUntilDone() = 0;

    // Clones the session.
    // The cloned session has all the settings and context
    // of the original session up to the point that the clone function is
    // called.
    //
    // Example usage:
    //   Session session1 = engine->CreateSession(...);
    //   session1->Prefill("What is the tallest building ");
    //   Session session2 = session1->Clone();
    //   session1->Prefill("in the world?");
    //   session1->Decode();
    //   session2->Prefill("in France?");
    //   session2->Decode();
    virtual absl::StatusOr<std::unique_ptr<Session>> Clone() {
      return absl::UnimplementedError("Not implemented.");
    }

    // Clones the session asynchronously.
    // The cloned session has all the settings and context
    // of the original session up to the point that the clone function is
    // called.
    // - callback: Callback to receive the result status of the asynchronous
    //   clone operation.
    //
    // Example usage:
    //   Session session1 = engine->CreateSession(...);
    //   session1->RunPrefillAsync("What is the tallest building ", ...);
    //   Session session2 = session1->CloneAsync(...);
    //   session1->RunPrefillAsync("in the world?", ...);
    //   session1->RunDecodeAsync(...);
    //   session2->RunPrefillAsync("in France?", ...);
    //   session2->RunDecodeAsync(...);
    virtual absl::StatusOr<std::unique_ptr<Session>> CloneAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Save the current step with the name `label`. You can later rewind to this
    // checkpoint using `RewindToCheckpoint(label)`. If the checkpoint name
    // already exists, the step number will be overwritten.
    virtual absl::Status SaveCheckpoint(absl::string_view label) {
      return absl::UnimplementedError("SaveCheckpoint not implemented.");
    }

    // Rewinds the session to the given checkpoint. Checkpoints after the
    // restored step will be removed. Returns an error if the checkpoint name
    // does not exist.
    virtual absl::Status RewindToCheckpoint(absl::string_view label) {
      return absl::UnimplementedError("RewindToCheckpoint not implemented.");
    }

    // Get the current step of the session.
    virtual absl::StatusOr<int> GetCurrentStep() const {
      return absl::UnimplementedError("GetCurrentStep not implemented.");
    }

    // Get the reference to the session config for the session.
    virtual const SessionConfig& GetSessionConfig() const = 0;
  };

  // Method to create the Session.
  virtual absl::StatusOr<std::unique_ptr<Session>> CreateSession(
      const SessionConfig& session_config) = 0;

  // Waits until the engine is done with all the tasks. The function will
  // return error if the timeout is reached.
  virtual absl::Status WaitUntilDone(absl::Duration timeout) {
    return absl::UnimplementedError("Not implemented.");
  }

  // Returns the EngineSettings currently used by the engine.
  virtual const EngineSettings& GetEngineSettings() const = 0;

  // Get the reference to the tokenizer for the engine.
  virtual const Tokenizer& GetTokenizer() const = 0;

  // Get the audio model properties for the session. This is only available
  // if the engine is created with audio modality enabled.
  virtual absl::StatusOr<AudioExecutorProperties> GetAudioExecutorProperties()
      const = 0;

  // Get the vision model properties for the session. This is only available
  // if the engine is created with vision modality enabled.
  virtual absl::StatusOr<VisionExecutorProperties> GetVisionExecutorProperties()
      const = 0;

  // Default timeout duration for the engine/session processes.
  static constexpr absl::Duration kDefaultTimeout = absl::Minutes(10);
};

}  // namespace litert::lm

#endif  // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_