// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_
#include <memory>
#include <vector>
#include "absl/functional/any_invocable.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/time/time.h" // from @com_google_absl
#include "runtime/components/tokenizer.h"
#include "runtime/engine/engine_settings.h"
#include "runtime/engine/io_types.h"
namespace litert::lm {
// Engine is the interface for the LLM runtime. It is responsible for
// - Initializing the LLM model and related resources, e.g. tokenizer,
// embedder, etc.
// - Providing the APIs to create the Session.
//
// The Session is responsible for hosting the internal state (e.g. conversation
// history) of each separate interaction with LLM. It is created by the Engine
// and is responsible for:
// - Generating content from the input prompt/query.
// - Running the prefill and decode processes.
//
// Example usage:
// // Create the model assets.
// auto model_assets = ModelAssets::Create(model_path);
// CHECK_OK(model_assets);
//
// // Create the engine.
// auto engine = Engine::CreateEngine(EngineSettings::CreateDefault(
// model_assets, litert::lm::Backend::CPU));
// CHECK_OK(engine);
//
// // Create the session.
// auto session = engine->CreateSession(SessionConfig::CreateDefault());
// CHECK_OK(session);
//
// // Run generate content.
// auto responses = (*session)->GenerateContent({InputText("What's the tallest
// building in the world?")});
// CHECK_OK(responses);
//
// // Print the response.
// std::cout << *responses << std::endl;
class Engine {
 public:
  virtual ~Engine() = default;

  // A Session holds the internal state (e.g. the conversation history) of one
  // independent interaction with the LLM.
  class Session {
   public:
    // Handle used to control the execution of an asynchronous task.
    class TaskController {
     public:
      TaskController() = default;

      // Not copyable: this keeps a user from accidentally duplicating the
      // controller and cancelling the same process more than once.
      TaskController(const TaskController&) = delete;
      TaskController& operator=(const TaskController&) = delete;

      // Movable.
      TaskController(TaskController&&) = default;
      TaskController& operator=(TaskController&&) = default;

      virtual ~TaskController() = default;

      // Blocks until all tasks are done or `timeout` is reached. Reaching the
      // timeout is reported as an error.
      // The default implementation returns an Unimplemented error.
      virtual absl::Status WaitUntilDone(absl::Duration timeout) {
        return absl::UnimplementedError("Not implemented.");
      }

      // Cancels the ongoing inference process. Calling this after the
      // inference process has finished is a no-op.
      // The default implementation returns an Unimplemented error.
      virtual absl::Status Cancel() {
        return absl::UnimplementedError("Not implemented.");
      }
    };

    virtual ~Session() = default;

    // High-level, blocking API that generates content from the input
    // prompt/query. Prefill and decode are handled internally; usage is
    // similar to the Gemini Text Generation API
    // (https://ai.google.dev/gemini-api/docs/text-generation).
    // - contents: The input data for generation.
    virtual absl::StatusOr<Responses> GenerateContent(
        const std::vector<InputData>& contents) = 0;

    // Non-blocking variant: returns right away and streams the result through
    // `callback`.
    //
    // - contents: The input data for generation.
    // - callback: Callback to receive streamed results.
    // Note:
    //   - On successful completion, the callback is called with empty
    //     responses to signal that generation is done.
    //   - If an error occurs while streaming, the callback is called with the
    //     error status and no further results are sent.
    //   - If the generation is cancelled, the callback is called with a
    //     Cancellation error.
    virtual absl::Status GenerateContentStream(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) = 0;

    // Same as above, but with a custom decode config.
    // - decode_config: Configuration for the model decode process.
    virtual absl::Status GenerateContentStream(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        const DecodeConfig& decode_config) = 0;

    // Scores the target text. Must be called after the prefill process is
    // done. Only the decode process is run, to fetch the decode output logits;
    // these are used to calculate the target text's score and to update the
    // model memory using the target_text tokens.
    // - target_text: The target text to score.
    // - store_token_lengths: Whether to store the token lengths of the target
    //   texts in `Responses`.
    // - returns: The score associated with the target text after the model
    //   has been prefilled, i.e. the sum of the negative log probability of
    //   seeing the target text during decode.
    virtual absl::StatusOr<Responses> RunTextScoring(
        const std::vector<absl::string_view>& target_text,
        bool store_token_lengths) = 0;

    // Non-blocking counterpart of RunTextScoring: returns right away and
    // signals the processing status through `callback`.
    // - target_text: The target text to score.
    // - callback: Callback to receive the scoring results.
    // - store_token_lengths: Whether to store the token lengths of the target
    //   texts in `Responses`.
    // The default implementation returns an Unimplemented error.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunTextScoringAsync(
        const std::vector<absl::string_view>& target_text,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        bool store_token_lengths) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Feeds the input prompt/query to the model to start the prefilling
    // process. A prompt/query may be broken into several chunks, with this
    // function called once per chunk.
    //
    // Blocking: returns once the prefill process is done.
    virtual absl::Status RunPrefill(const std::vector<InputData>& contents) = 0;

    // Non-blocking prefill: returns right away and signals the processing
    // status through `callback`.
    // The default implementation returns an Unimplemented error.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunPrefillAsync(
        const std::vector<InputData>& contents,
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Starts the decoding process, in which the model predicts the response
    // from the input prompt/query previously added via the RunPrefill*
    // functions.
    // Blocking: returns once the decoding process is done.
    virtual absl::StatusOr<Responses> RunDecode() = 0;

    // Same as above, but with a custom decode config.
    // - decode_config: Configuration for the model decode process.
    virtual absl::StatusOr<Responses> RunDecode(
        const DecodeConfig& decode_config) = 0;

    // Starts the decoding process, in which the model predicts the response
    // from the input prompt/query previously added via the RunPrefill*
    // functions.
    // Non-blocking: returns right away and streams the result through
    // `callback`.
    // - callback: Callback to receive streamed results.
    // The default implementation returns an Unimplemented error.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunDecodeAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Same as above, but with a custom decode config.
    // - decode_config: Configuration for the model decode process.
    virtual absl::StatusOr<std::unique_ptr<TaskController>> RunDecodeAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback,
        const DecodeConfig& decode_config) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Returns the benchmark info of the session, or an error if benchmarking
    // is not enabled.
    virtual absl::StatusOr<BenchmarkInfo> GetBenchmarkInfo() = 0;

    // Returns the mutable benchmark info of the session, or an error if
    // benchmarking is not enabled.
    virtual absl::StatusOr<BenchmarkInfo*> GetMutableBenchmarkInfo() = 0;

    // Cancels the ongoing inference process, which will then return with a
    // kCancelled error. The session can still be used afterwards.
    // NOTE: the default implementation does not cancel anything; it logs at
    // FATAL severity, which in Abseil logging terminates the program.
    virtual void CancelProcess() {
      ABSL_LOG(FATAL) << "CancelProcess is not implemented.";
    }

    // Blocks until all tasks are done or the default timeout is reached.
    virtual absl::Status WaitUntilDone() = 0;

    // Clones the session. The clone carries all the settings and context of
    // the original session up to the point at which Clone() is called.
    // The default implementation returns an Unimplemented error.
    //
    // Example usage (pseudocode; arguments and error handling abbreviated):
    //   auto session1 = engine->CreateSession(...);
    //   session1->RunPrefill("What is the tallest building ");
    //   auto session2 = session1->Clone();
    //   session1->RunPrefill("in the world?");
    //   session1->RunDecode();
    //   session2->RunPrefill("in France?");
    //   session2->RunDecode();
    virtual absl::StatusOr<std::unique_ptr<Session>> Clone() {
      return absl::UnimplementedError("Not implemented.");
    }

    // Clones the session asynchronously. The clone carries all the settings
    // and context of the original session up to the point at which
    // CloneAsync() is called.
    // - callback: Callback taking absl::StatusOr<Responses>; presumably used
    //   to signal the status of the clone operation (TODO: confirm against an
    //   implementation).
    // The default implementation returns an Unimplemented error.
    //
    // Example usage (pseudocode; arguments and error handling abbreviated):
    //   auto session1 = engine->CreateSession(...);
    //   session1->RunPrefillAsync("What is the tallest building ", ...);
    //   auto session2 = session1->CloneAsync(...);
    //   session1->RunPrefillAsync("in the world?", ...);
    //   session1->RunDecodeAsync(...);
    //   session2->RunPrefillAsync("in France?", ...);
    //   session2->RunDecodeAsync(...);
    virtual absl::StatusOr<std::unique_ptr<Session>> CloneAsync(
        absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) {
      return absl::UnimplementedError("Not implemented.");
    }

    // Saves the current step under the name `label` so that it can later be
    // restored with `RewindToCheckpoint(label)`. If a checkpoint with this
    // name already exists, its step number is overwritten.
    // The default implementation returns an Unimplemented error.
    virtual absl::Status SaveCheckpoint(absl::string_view label) {
      return absl::UnimplementedError("SaveCheckpoint not implemented.");
    }

    // Rewinds the session to the checkpoint named `label`. Checkpoints taken
    // after the restored step are removed. Returns an error if no checkpoint
    // with that name exists.
    // The default implementation returns an Unimplemented error.
    virtual absl::Status RewindToCheckpoint(absl::string_view label) {
      return absl::UnimplementedError("RewindToCheckpoint not implemented.");
    }

    // Returns the current step of the session.
    // The default implementation returns an Unimplemented error.
    virtual absl::StatusOr<int> GetCurrentStep() const {
      return absl::UnimplementedError("GetCurrentStep not implemented.");
    }

    // Returns a reference to the session config of this session.
    virtual const SessionConfig& GetSessionConfig() const = 0;
  };

  // Creates a Session configured by `session_config`.
  virtual absl::StatusOr<std::unique_ptr<Session>> CreateSession(
      const SessionConfig& session_config) = 0;

  // Blocks until the engine is done with all its tasks. Reaching `timeout` is
  // reported as an error.
  // The default implementation returns an Unimplemented error.
  virtual absl::Status WaitUntilDone(absl::Duration timeout) {
    return absl::UnimplementedError("Not implemented.");
  }

  // Returns the EngineSettings currently used by the engine.
  virtual const EngineSettings& GetEngineSettings() const = 0;

  // Returns a reference to the tokenizer of the engine.
  virtual const Tokenizer& GetTokenizer() const = 0;

  // Returns the audio model properties. Only available if the engine was
  // created with the audio modality enabled.
  virtual absl::StatusOr<AudioExecutorProperties> GetAudioExecutorProperties()
      const = 0;

  // Returns the vision model properties. Only available if the engine was
  // created with the vision modality enabled.
  virtual absl::StatusOr<VisionExecutorProperties> GetVisionExecutorProperties()
      const = 0;

  // Default timeout duration for the engine/session processes.
  static constexpr absl::Duration kDefaultTimeout = absl::Minutes(10);
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_ENGINE_ENGINE_H_