Spaces:
Running
Running
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
| namespace litert::lm { | |
| // SessionBasic is a basic implementation of Engine::Session. The underlying | |
| // prefill/decode pipelines use the LLM Executor's basic Decode function which | |
| // does the sampling logics inside. | |
| class SessionBasic : public Engine::Session { | |
| public: | |
| // Creates a SessionBasic object. | |
| // - executor: The initialized LLM Executor to call. | |
| // - tokenizer: The tokenizer to encode/decode the text into token ids. | |
| // - vision_executor: The vision executor to encode the image input. | |
| // - audio_executor: The audio executor to encode the audio input. | |
| // - stop_token_ids: The token ids to stop the decoding process. | |
| // - sampler_params: The sampler parameters used for decoding. Note that if | |
| // the sampler_params.type is TYPE_UNSPECIFIED, the sampling logic will be | |
| // handled by the LLM Executor. | |
| static absl::StatusOr<std::unique_ptr<SessionBasic>> Create( | |
| LlmExecutor* absl_nonnull executor, Tokenizer* absl_nonnull tokenizer, | |
| VisionExecutor* vision_executor, AudioExecutor* audio_executor, | |
| const SessionConfig& session_config, | |
| std::optional<BenchmarkInfo> benchmark_info, | |
| ThreadPool* absl_nonnull worker_thread_pool); | |
| virtual ~SessionBasic(); | |
| absl::StatusOr<Responses> GenerateContent( | |
| const std::vector<InputData>& contents) override; | |
| absl::Status GenerateContentStream( | |
| const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) override; | |
| absl::Status GenerateContentStream( | |
| const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| const DecodeConfig& decode_config) override; | |
| // Scores the target text after the prefill process is done. This function | |
| // will only run the decode process to fetch the decode output logits, which | |
| // is used to calculate the target text's score and update the model memory | |
| // using the target_text tokens. | |
| // This function should be called after the prefill process is done. | |
| // - target_text: The target text to score. | |
| // - store_token_lengths: Whether to store the token lengths of the target | |
| // texts in `Responses`. | |
| // - return: This function returns the score associated with the target | |
| // text after the model has been prefilled. The returned score is the sum of | |
| // the negative log probability of seeing the target text during decode. | |
| absl::StatusOr<Responses> RunTextScoring( | |
| const std::vector<absl::string_view>& target_text, | |
| bool store_token_lengths) override; | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>> | |
| RunTextScoringAsync( | |
| const std::vector<absl::string_view>& target_text, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| bool store_token_lengths) override; | |
| absl::Status RunPrefill(const std::vector<InputData>& contents) override; | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>> | |
| RunPrefillAsync( | |
| const std::vector<InputData>& contents, | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) override; | |
| absl::StatusOr<Responses> RunDecode() override; | |
| absl::StatusOr<Responses> RunDecode( | |
| const DecodeConfig& decode_config) override; | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>> | |
| RunDecodeAsync( | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback) override; | |
| absl::StatusOr<std::unique_ptr<Engine::Session::TaskController>> | |
| RunDecodeAsync(absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| const DecodeConfig& decode_config) override; | |
| absl::StatusOr<BenchmarkInfo> GetBenchmarkInfo() override; | |
| absl::StatusOr<BenchmarkInfo*> GetMutableBenchmarkInfo() override; | |
| // TODO(b/450903294): Add rollback history support for Session and | |
| // Conversation. | |
| void CancelProcess() override { | |
| ABSL_LOG(INFO) << "SessionBasic::CancelProcess"; | |
| cancelled_.store(true); | |
| } | |
| absl::Status WaitUntilDone() override { | |
| return worker_thread_pool_.WaitUntilDone(Engine::kDefaultTimeout); | |
| } | |
| const SessionConfig& GetSessionConfig() const override { | |
| return session_config_; | |
| } | |
| // Util function for creating the combined ExecutorInputs from the | |
| // preprocessed contents. | |
| // TODO - b/436674053: Modularize the preprocessing logic into a separate | |
| // preprocessor class. | |
| absl::StatusOr<ExecutorInputs> ProcessAndCombineContents( | |
| const std::vector<InputData>& preprocessed_contents); | |
| // Save the current step with the name `label`. You can later rewind to this | |
| // checkpoint using `RewindToCheckpoint(label)`. If the checkpoint name | |
| // already exists, the step number will be overwritten. | |
| absl::Status SaveCheckpoint(absl::string_view label) override; | |
| // Rewinds the session to the given checkpoint. Checkpoints after the | |
| // restored step will be removed. Returns an error if the checkpoint name | |
| // does not exist. | |
| absl::Status RewindToCheckpoint(absl::string_view label) override; | |
| // Get the current step of the session. | |
| absl::StatusOr<int> GetCurrentStep() const override; | |
| private: | |
| explicit SessionBasic(LlmExecutor* absl_nonnull executor, | |
| Tokenizer* absl_nonnull tokenizer, | |
| VisionExecutor* vision_executor, | |
| AudioExecutor* audio_executor, | |
| std::unique_ptr<Sampler> sampler, | |
| const SessionConfig& session_config, | |
| std::optional<BenchmarkInfo> benchmark_info, | |
| ThreadPool* absl_nonnull worker_thread_pool, | |
| const StopTokenDetector& stop_token_detector) | |
| : executor_(*executor), | |
| tokenizer_(*tokenizer), | |
| vision_executor_(vision_executor), | |
| audio_executor_(audio_executor), | |
| sampler_(std::move(sampler)), | |
| session_config_(session_config), | |
| benchmark_info_(benchmark_info), | |
| worker_thread_pool_(*worker_thread_pool), | |
| stop_token_detector_(stop_token_detector) {} | |
| // The internal function to prefill the input prompt. It is for convenience to | |
| // wrap it with lambda function for scheduling. | |
| absl::Status PrefillInternal( | |
| const std::vector<InputData>& preprocessed_contents, | |
| bool wait_for_completion); | |
| // The internal functions to decode the input prompt. It is for convenience to | |
| // wrap it with lambda function for scheduling. | |
| absl::StatusOr<Responses> DecodeInternal(const DecodeConfig& decode_config); | |
| absl::Status DecodeInternalStreaming( | |
| absl::AnyInvocable<void(absl::StatusOr<Responses>)> callback, | |
| const DecodeConfig& decode_config); | |
| // The executor used for run the LLM for prefill/decode. | |
| LlmExecutor& executor_; | |
| // The tokenizer used for converting between text to token ids. | |
| Tokenizer& tokenizer_; | |
| // The vision executor used for run the LLM for prefill/decode. | |
| VisionExecutor* vision_executor_; | |
| // The audio executor used for run the LLM for prefill/decode. | |
| AudioExecutor* audio_executor_; | |
| // The session config used for the session. | |
| std::unique_ptr<Sampler> sampler_; | |
| // The session config used for the session. | |
| SessionConfig session_config_; | |
| // The last token id of the prefill ids. It is used for the first decode | |
| // process to determine the token id to start from. | |
| int last_prefill_token_id_; | |
| // The benchmark info used for the session. | |
| std::optional<BenchmarkInfo> benchmark_info_; | |
| // The thread pool used for the session. | |
| ThreadPool& worker_thread_pool_; | |
| // The stop token detector used for the session. | |
| StopTokenDetector stop_token_detector_; | |
| // An atomic boolean to indicate whether the session is cancelled. | |
| std::atomic<bool> cancelled_{false}; | |
| // The state of the session. | |
| // * `kFresh` means the session is just created and | |
| // hasn't been prefilled yet. | |
| // * `kPrefilled` means the session has been prefilled | |
| // but not decoded yet. | |
| // * `kDecoded` means the session has been decoded. | |
| // | |
| // A session is considered fresh only if it has not been prefilled or decoded | |
| // yet. | |
| // A session could transition between kPrefilled and kDecoded if | |
| // `RunPrefill` or `RunDecode` is called multiple times. | |
| enum class SessionState : int { kFresh, kPrefilled, kDecoded }; | |
| SessionState session_state_ = SessionState::kFresh; | |
| // The set of executors that are already existed in the system. This is used | |
| // to avoid creating multiple sessions for the same executor. | |
| static absl::flat_hash_set<LlmExecutor*>* occupied_executors_ | |
| ABSL_GUARDED_BY(occupied_executors_mu_); | |
| static absl::Mutex occupied_executors_mu_; | |
| // The map of checkpoint name to step. | |
| absl::flat_hash_map<std::string, int> checkpoint_map_; | |
| }; | |
| } // namespace litert::lm | |