// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "runtime/util/lora_data.h" #include #include #include #include #include #include #include #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/strings/str_cat.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers #include "litert/cc/litert_buffer_ref.h" // from @litert #include "runtime/util/lora_util.h" #include "runtime/util/scoped_file.h" #include "runtime/util/status_macros.h" #include "tflite/model_builder.h" // from @litert #include "tflite/schema/schema_generated.h" // from @litert namespace litert::lm { namespace { constexpr absl::string_view kLoRARank = "lora_rank"; // The maximum size of the metadata buffer. // This is the max length we need to mmap to build the flatbuffer model. constexpr int kMetadataMaxSize = 1024 * 1024; // 1MB absl::StatusOr> CreateFlatBufferModelFromBuffer(const void* buffer_addr, size_t buffer_size) { const bool obfuscated = !tflite::ModelBufferHasIdentifier(buffer_addr); if (obfuscated) { return absl::UnimplementedError( "Input is not valid flatbuffer model. Deobfuscation is not supported " "yet."); } std::unique_ptr model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( reinterpret_cast(buffer_addr), buffer_size); RET_CHECK(model) << "Error building tflite model."; return model; } // LoRA data based on FlatBufferModel. class FlatBufferLoraData : public LoraData { public: ~FlatBufferLoraData() override = default; absl::StatusOr GetLoRARank() override { const tflite::Metadata* metadata = GetMetadata(kLoRARank); if (metadata == nullptr) { return absl::NotFoundError("No LoRA metadata found."); } return static_cast(metadata->buffer()); } absl::StatusOr>> ReadTensor( absl::string_view name) override { const tflite::Buffer* buffer = GetBuffer(name); if (buffer == nullptr) { return absl::NotFoundError( absl::StrCat("No buffer found for tensor: ", name)); } return ReadData(buffer->offset(), buffer->size()); } bool HasTensor(absl::string_view name) const override { return GetBuffer(name) != nullptr; } std::vector GetAllTensorNames() const override { std::vector tensor_names; const tflite::Model* tflite_model = GetFlatBufferModel()->GetModel(); for (const tflite::SubGraph* subgraph : *tflite_model->subgraphs()) { for (const tflite::Tensor* tfl_tensor : *subgraph->tensors()) { tensor_names.push_back(tfl_tensor->name()->c_str()); } } return tensor_names; } protected: // Returns the FlatBufferModel object reference. // FlatBufferModel is owned by derived classes to be destroyed in correct // order, thus it is accessed by base class with a reference here. virtual const tflite::FlatBufferModel* GetFlatBufferModel() const = 0; // Reads data stored at the given offset and size. virtual absl::StatusOr>> ReadData( uint64_t offset, uint64_t size) = 0; private: // Get metadata from the flatbuffer model. const tflite::Metadata* GetMetadata(absl::string_view name) { const tflite::Model* tflite_model = GetFlatBufferModel()->GetModel(); if (tflite_model->metadata() == nullptr) { return nullptr; } for (const tflite::Metadata* metadata : *tflite_model->metadata()) { if (name == metadata->name()->c_str()) { return metadata; } } return nullptr; } const tflite::Buffer* GetBuffer(absl::string_view name) const { const tflite::Model* tflite_model = GetFlatBufferModel()->GetModel(); const flatbuffers::Vector>& buffers = *tflite_model->buffers(); for (const tflite::SubGraph* subgraph : *tflite_model->subgraphs()) { for (const tflite::Tensor* tfl_tensor : *subgraph->tensors()) { if (name != tfl_tensor->name()->c_str()) { continue; } if (tfl_tensor->buffer() >= buffers.size()) { continue; } return buffers.Get(tfl_tensor->buffer()); } } return nullptr; } }; // FlatBufferModel based LoRA data backed by a file. class FileLoraData : public FlatBufferLoraData { public: // Constructor for FileLoraData. // // @param file A shared_ptr to the ScopedFile object representing the LoRA // data file. // @param region A unique_ptr to the MemoryMappedFileWithAutoAlignment object // representing the memory mapped region of the FlatBufferModel metadata. // @param model A unique_ptr to the FlatBufferModel object representing the // LoRA data metadata. explicit FileLoraData( std::shared_ptr file, std::unique_ptr region, std::unique_ptr model, const std::string& key) : file_(std::move(file)), region_(std::move(region)), model_(std::move(model)), key_(key) {} ~FileLoraData() override = default; private: const tflite::FlatBufferModel* GetFlatBufferModel() const override { return model_.get(); } absl::StatusOr>> ReadData( uint64_t offset, uint64_t size) override { ASSIGN_OR_RETURN(auto mapped_region, MemoryMappedFileWithAutoAlignment::Create( file_->file(), /*offset=*/offset, /*size=*/size, key_)); return std::make_unique>(std::move(mapped_region)); } private: std::shared_ptr file_; std::unique_ptr region_; std::unique_ptr model_; const std::string key_; }; // FlatBufferModel based LoRA data backed by a BufferRef. class BufferLoraData : public FlatBufferLoraData { public: // Constructor for BufferLoraData. // // @param data A BufferRef object representing the LoRA data. // @param model A unique_ptr to the FlatBufferModel object representing the // LoRA data. explicit BufferLoraData(BufferRef data, std::unique_ptr model) : data_(std::move(data)), model_(std::move(model)) {} ~BufferLoraData() override = default; private: const tflite::FlatBufferModel* GetFlatBufferModel() const override { return model_.get(); } absl::StatusOr>> ReadData( uint64_t offset, uint64_t size) override { return std::make_unique>( data_.Data(), /*end_offset=*/offset + size, /*start_offset=*/offset); } private: BufferRef data_; std::unique_ptr model_; }; } // namespace // static absl::StatusOr> LoraData::CreateFromFilePath( absl::string_view file_path) { ASSIGN_OR_RETURN(auto file, ScopedFile::Open(file_path)); return CreateFromScopedFile(std::make_shared(std::move(file))); } // static absl::StatusOr> LoraData::CreateFromScopedFile( std::shared_ptr file) { static std::atomic next_key{0}; const std::string key{absl::StrCat("FileLoraData_", next_key.fetch_add(1))}; ASSIGN_OR_RETURN(auto mapped_file, MemoryMappedFileWithAutoAlignment::Create( file->file(), /*offset=*/0, /*size=*/kMetadataMaxSize, key)); ASSIGN_OR_RETURN(auto model, CreateFlatBufferModelFromBuffer( mapped_file->data(), mapped_file->length())); RET_CHECK(model) << "Error building tflite model."; return std::make_unique(std::move(file), std::move(mapped_file), std::move(model), key); } // static absl::StatusOr> LoraData::CreateFromBuffer( BufferRef buffer) { ASSIGN_OR_RETURN(auto model, CreateFlatBufferModelFromBuffer(buffer.Data(), buffer.Size())); RET_CHECK(model) << "Error building tflite model."; return std::make_unique(std::move(buffer), std::move(model)); } } // namespace litert::lm