#include "safetensors_loader.h"

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <map>
#include <string>

#include "json.hpp"

using json = nlohmann::json;

SafetensorsLoader::SafetensorsLoader() = default;

// Release every mmap'ed region and file descriptor that mmap_shard_()
// acquired lazily.  Shards that were never mapped are skipped: their
// mmap_ptr is null and their fd is negative (presumably the ShardFile
// defaults set in the header — confirm there).
SafetensorsLoader::~SafetensorsLoader() {
    for (auto& s : shards_) {
        if (s.mmap_ptr) munmap(s.mmap_ptr, s.mmap_size);
        if (s.fd >= 0) close(s.fd);
    }
}

// Open a safetensors model directory and index all tensors.
//
// Layout handled:
//   - sharded model:  <dir>/model.safetensors.index.json whose "weight_map"
//     maps tensor name -> shard filename; each referenced shard is parsed.
//   - single file:    <dir>/model.safetensors (fallback when no index).
//
// Only headers are read here; tensor data is mmap'ed on demand by
// data_ptr().  Returns false (with a message on stderr) on any failure.
bool SafetensorsLoader::open(const std::string& dir) {
    model_dir_ = dir;

    // 1. Parse index.json to discover shard files.
    std::string idx_path = dir + "/model.safetensors.index.json";
    std::ifstream idx_file(idx_path);
    if (!idx_file) {
        // Fallback: single-file model.
        std::string single = dir + "/model.safetensors";
        std::ifstream f(single);
        if (!f) {
            fprintf(stderr,
                    "SafetensorsLoader: neither index.json nor model.safetensors found in %s\n",
                    dir.c_str());
            return false;
        }
        shards_.push_back({single});
        return parse_shard_header_(0);
    }

    json idx;
    try {
        idx_file >> idx;
    } catch (const std::exception& e) {
        fprintf(stderr, "SafetensorsLoader: bad index.json: %s\n", e.what());
        return false;
    }
    if (!idx.contains("weight_map")) {
        fprintf(stderr, "SafetensorsLoader: index.json missing weight_map\n");
        return false;
    }

    // Collect unique shard filenames (preserving discovery order: the id is
    // assigned from shards_.size() at first sight of each filename).
    std::map<std::string, int> shard_name_to_id;
    for (auto& [name, file] : idx["weight_map"].items()) {
        std::string shard_name = file.get<std::string>();
        if (shard_name_to_id.count(shard_name) == 0) {
            int id = (int)shards_.size();
            shard_name_to_id[shard_name] = id;
            shards_.push_back({dir + "/" + shard_name});
        }
    }

    // 2. Parse header of each shard to discover tensor offsets.
    for (int i = 0; i < (int)shards_.size(); i++) {
        if (!parse_shard_header_(i)) {
            fprintf(stderr, "SafetensorsLoader: failed to parse shard %s\n",
                    shards_[i].path.c_str());
            return false;
        }
    }
    return true;
}

// Parse one shard's safetensors header and register its tensors.
//
// File format: 8-byte little-endian header length N, then N bytes of JSON
// mapping tensor name -> {dtype, shape, data_offsets:[begin,end)}, then the
// raw tensor data.  data_offsets are relative to the end of the header, so
// each tensor's absolute file offset is data_base + begin.
// NOTE: the raw read of header_len assumes a little-endian host.
bool SafetensorsLoader::parse_shard_header_(int shard_id) {
    ShardFile& sh = shards_[shard_id];
    std::ifstream f(sh.path, std::ios::binary);
    if (!f) return false;

    // Read 8-byte little-endian header length.
    uint64_t header_len = 0;
    f.read((char*)&header_len, 8);
    if (!f) return false;

    // Sanity-bound the header length so a corrupt/truncated file cannot
    // trigger a multi-gigabyte allocation.  Real headers are well under
    // 256 MiB.
    if (header_len == 0 || header_len > (256ull << 20)) {
        fprintf(stderr, "SafetensorsLoader: implausible header length %llu in %s\n",
                (unsigned long long)header_len, sh.path.c_str());
        return false;
    }

    std::string header(header_len, '\0');
    f.read(header.data(), header_len);
    if (!f) return false;
    sh.data_base = 8 + header_len;  // tensor data starts right after the JSON

    json j;
    try {
        j = json::parse(header);
    } catch (const std::exception& e) {
        fprintf(stderr, "SafetensorsLoader: bad shard header JSON in %s: %s\n",
                sh.path.c_str(), e.what());
        return false;
    }

    for (auto it = j.begin(); it != j.end(); ++it) {
        const std::string& name = it.key();
        if (name == "__metadata__") continue;  // format's optional metadata entry
        const auto& entry = it.value();

        TensorMeta m;
        m.name = name;
        m.dtype = entry["dtype"].get<std::string>();
        for (auto& d : entry["shape"]) m.shape.push_back(d.get<size_t>());
        const auto& offs = entry["data_offsets"];
        size_t begin = offs[0].get<size_t>();
        size_t end = offs[1].get<size_t>();
        m.offset = sh.data_base + begin;  // absolute offset within the shard file
        m.nbytes = end - begin;
        m.shard_id = shard_id;
        tensors_[name] = std::move(m);
    }
    return true;
}

// Lazily mmap a shard read-only.  Idempotent: returns true immediately if
// the shard is already mapped.  On failure the fd (if opened) stays stored
// in the ShardFile and is closed by the destructor.
bool SafetensorsLoader::mmap_shard_(int shard_id) {
    ShardFile& sh = shards_[shard_id];
    if (sh.mmap_ptr) return true;

    sh.fd = ::open(sh.path.c_str(), O_RDONLY);
    if (sh.fd < 0) {
        perror("open");
        return false;
    }
    struct stat st;
    if (fstat(sh.fd, &st) != 0) return false;
    sh.mmap_size = st.st_size;
    sh.mmap_ptr = mmap(nullptr, sh.mmap_size, PROT_READ, MAP_PRIVATE, sh.fd, 0);
    if (sh.mmap_ptr == MAP_FAILED) {
        perror("mmap");
        sh.mmap_ptr = nullptr;
        return false;
    }
    return true;
}

// Look up a tensor's metadata by name; nullptr if unknown.  The returned
// pointer stays valid for the lifetime of the loader (tensors_ is not
// mutated after open()).
const TensorMeta* SafetensorsLoader::get(const std::string& name) const {
    auto it = tensors_.find(name);
    if (it == tensors_.end()) return nullptr;
    return &it->second;
}
const void* SafetensorsLoader::data_ptr(const TensorMeta& m) { if (m.shard_id < 0 || (size_t)m.shard_id >= shards_.size()) return nullptr; if (!mmap_shard_(m.shard_id)) return nullptr; ShardFile& sh = shards_[m.shard_id]; return (const char*)sh.mmap_ptr + m.offset; } const void* SafetensorsLoader::data_ptr(const std::string& name) { const auto* m = get(name); if (!m) return nullptr; return data_ptr(*m); } std::vector SafetensorsLoader::list_tensor_names() const { std::vector out; out.reserve(tensors_.size()); for (auto& [k, v] : tensors_) out.push_back(k); return out; } size_t SafetensorsLoader::total_bytes() const { size_t sum = 0; for (auto& [k, v] : tensors_) sum += v.nbytes; return sum; }