| #include "safetensors_loader.h" |
|
|
| #include <fcntl.h> |
| #include <sys/mman.h> |
| #include <sys/stat.h> |
| #include <unistd.h> |
|
|
| #include <cstdio> |
| #include <cstring> |
| #include <fstream> |
| #include <sstream> |
|
|
| #include "json.hpp" |
|
|
| using json = nlohmann::json; |
|
|
| SafetensorsLoader::SafetensorsLoader() = default; |
|
|
| SafetensorsLoader::~SafetensorsLoader() { |
| for (auto& s : shards_) { |
| if (s.mmap_ptr) munmap(s.mmap_ptr, s.mmap_size); |
| if (s.fd >= 0) close(s.fd); |
| } |
| } |
|
|
| bool SafetensorsLoader::open(const std::string& dir) { |
| model_dir_ = dir; |
|
|
| |
| std::string idx_path = dir + "/model.safetensors.index.json"; |
| std::ifstream idx_file(idx_path); |
| if (!idx_file) { |
| |
| std::string single = dir + "/model.safetensors"; |
| std::ifstream f(single); |
| if (!f) { |
| fprintf(stderr, "SafetensorsLoader: neither index.json nor model.safetensors found in %s\n", dir.c_str()); |
| return false; |
| } |
| shards_.push_back({single}); |
| return parse_shard_header_(0); |
| } |
|
|
| json idx; |
| try { idx_file >> idx; } catch (std::exception& e) { |
| fprintf(stderr, "SafetensorsLoader: bad index.json: %s\n", e.what()); |
| return false; |
| } |
| if (!idx.contains("weight_map")) { |
| fprintf(stderr, "SafetensorsLoader: index.json missing weight_map\n"); |
| return false; |
| } |
|
|
| |
| std::map<std::string, int> shard_name_to_id; |
| for (auto& [name, file] : idx["weight_map"].items()) { |
| std::string shard_name = file.get<std::string>(); |
| if (shard_name_to_id.count(shard_name) == 0) { |
| int id = (int)shards_.size(); |
| shard_name_to_id[shard_name] = id; |
| shards_.push_back({dir + "/" + shard_name}); |
| } |
| } |
|
|
| |
| for (int i = 0; i < (int)shards_.size(); i++) { |
| if (!parse_shard_header_(i)) { |
| fprintf(stderr, "SafetensorsLoader: failed to parse shard %s\n", shards_[i].path.c_str()); |
| return false; |
| } |
| } |
|
|
| return true; |
| } |
|
|
| bool SafetensorsLoader::parse_shard_header_(int shard_id) { |
| ShardFile& sh = shards_[shard_id]; |
| std::ifstream f(sh.path, std::ios::binary); |
| if (!f) return false; |
|
|
| |
| uint64_t header_len = 0; |
| f.read((char*)&header_len, 8); |
| if (!f) return false; |
|
|
| std::string header(header_len, '\0'); |
| f.read(header.data(), header_len); |
| if (!f) return false; |
|
|
| sh.data_base = 8 + header_len; |
|
|
| json j; |
| try { j = json::parse(header); } catch (std::exception& e) { |
| fprintf(stderr, "SafetensorsLoader: bad shard header JSON in %s: %s\n", sh.path.c_str(), e.what()); |
| return false; |
| } |
|
|
| for (auto it = j.begin(); it != j.end(); ++it) { |
| const std::string& name = it.key(); |
| if (name == "__metadata__") continue; |
| const auto& entry = it.value(); |
|
|
| TensorMeta m; |
| m.name = name; |
| m.dtype = entry["dtype"].get<std::string>(); |
| for (auto& d : entry["shape"]) m.shape.push_back(d.get<int64_t>()); |
| const auto& offs = entry["data_offsets"]; |
| size_t begin = offs[0].get<size_t>(); |
| size_t end = offs[1].get<size_t>(); |
| m.offset = sh.data_base + begin; |
| m.nbytes = end - begin; |
| m.shard_id = shard_id; |
|
|
| tensors_[name] = std::move(m); |
| } |
|
|
| return true; |
| } |
|
|
| bool SafetensorsLoader::mmap_shard_(int shard_id) { |
| ShardFile& sh = shards_[shard_id]; |
| if (sh.mmap_ptr) return true; |
|
|
| sh.fd = ::open(sh.path.c_str(), O_RDONLY); |
| if (sh.fd < 0) { |
| perror("open"); |
| return false; |
| } |
| struct stat st; |
| if (fstat(sh.fd, &st) != 0) return false; |
| sh.mmap_size = st.st_size; |
|
|
| sh.mmap_ptr = mmap(nullptr, sh.mmap_size, PROT_READ, MAP_PRIVATE, sh.fd, 0); |
| if (sh.mmap_ptr == MAP_FAILED) { |
| perror("mmap"); |
| sh.mmap_ptr = nullptr; |
| return false; |
| } |
| return true; |
| } |
|
|
| const TensorMeta* SafetensorsLoader::get(const std::string& name) const { |
| auto it = tensors_.find(name); |
| if (it == tensors_.end()) return nullptr; |
| return &it->second; |
| } |
|
|
| const void* SafetensorsLoader::data_ptr(const TensorMeta& m) { |
| if (m.shard_id < 0 || (size_t)m.shard_id >= shards_.size()) return nullptr; |
| if (!mmap_shard_(m.shard_id)) return nullptr; |
| ShardFile& sh = shards_[m.shard_id]; |
| return (const char*)sh.mmap_ptr + m.offset; |
| } |
|
|
| const void* SafetensorsLoader::data_ptr(const std::string& name) { |
| const auto* m = get(name); |
| if (!m) return nullptr; |
| return data_ptr(*m); |
| } |
|
|
| std::vector<std::string> SafetensorsLoader::list_tensor_names() const { |
| std::vector<std::string> out; |
| out.reserve(tensors_.size()); |
| for (auto& [k, v] : tensors_) out.push_back(k); |
| return out; |
| } |
|
|
| size_t SafetensorsLoader::total_bytes() const { |
| size_t sum = 0; |
| for (auto& [k, v] : tensors_) sum += v.nbytes; |
| return sum; |
| } |
|
|