// llm_mutil_npu/src/safetensors_loader.cpp
// Commit 4b9fefd (xianglarry): Initial C++ aclnn EAGER inference for
// Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
#include "safetensors_loader.h"
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <sstream>
#include "json.hpp"
using json = nlohmann::json;
SafetensorsLoader::SafetensorsLoader() = default;

// Release the per-shard OS resources. Shards are mapped lazily, so either
// resource may be absent: a null mapping pointer or a negative descriptor
// means "never acquired" and is skipped.
SafetensorsLoader::~SafetensorsLoader() {
    for (ShardFile& shard : shards_) {
        if (shard.mmap_ptr != nullptr) {
            munmap(shard.mmap_ptr, shard.mmap_size);
        }
        if (!(shard.fd < 0)) {
            close(shard.fd);
        }
    }
}
bool SafetensorsLoader::open(const std::string& dir) {
model_dir_ = dir;
// 1. Parse index.json to discover shard files
std::string idx_path = dir + "/model.safetensors.index.json";
std::ifstream idx_file(idx_path);
if (!idx_file) {
// Fallback: single-file model
std::string single = dir + "/model.safetensors";
std::ifstream f(single);
if (!f) {
fprintf(stderr, "SafetensorsLoader: neither index.json nor model.safetensors found in %s\n", dir.c_str());
return false;
}
shards_.push_back({single});
return parse_shard_header_(0);
}
json idx;
try { idx_file >> idx; } catch (std::exception& e) {
fprintf(stderr, "SafetensorsLoader: bad index.json: %s\n", e.what());
return false;
}
if (!idx.contains("weight_map")) {
fprintf(stderr, "SafetensorsLoader: index.json missing weight_map\n");
return false;
}
// Collect unique shard filenames (preserving discovery order).
std::map<std::string, int> shard_name_to_id;
for (auto& [name, file] : idx["weight_map"].items()) {
std::string shard_name = file.get<std::string>();
if (shard_name_to_id.count(shard_name) == 0) {
int id = (int)shards_.size();
shard_name_to_id[shard_name] = id;
shards_.push_back({dir + "/" + shard_name});
}
}
// 2. Parse header of each shard to discover tensor offsets
for (int i = 0; i < (int)shards_.size(); i++) {
if (!parse_shard_header_(i)) {
fprintf(stderr, "SafetensorsLoader: failed to parse shard %s\n", shards_[i].path.c_str());
return false;
}
}
return true;
}
bool SafetensorsLoader::parse_shard_header_(int shard_id) {
ShardFile& sh = shards_[shard_id];
std::ifstream f(sh.path, std::ios::binary);
if (!f) return false;
// Read 8-byte little-endian header length
uint64_t header_len = 0;
f.read((char*)&header_len, 8);
if (!f) return false;
std::string header(header_len, '\0');
f.read(header.data(), header_len);
if (!f) return false;
sh.data_base = 8 + header_len;
json j;
try { j = json::parse(header); } catch (std::exception& e) {
fprintf(stderr, "SafetensorsLoader: bad shard header JSON in %s: %s\n", sh.path.c_str(), e.what());
return false;
}
for (auto it = j.begin(); it != j.end(); ++it) {
const std::string& name = it.key();
if (name == "__metadata__") continue;
const auto& entry = it.value();
TensorMeta m;
m.name = name;
m.dtype = entry["dtype"].get<std::string>();
for (auto& d : entry["shape"]) m.shape.push_back(d.get<int64_t>());
const auto& offs = entry["data_offsets"];
size_t begin = offs[0].get<size_t>();
size_t end = offs[1].get<size_t>();
m.offset = sh.data_base + begin;
m.nbytes = end - begin;
m.shard_id = shard_id;
tensors_[name] = std::move(m);
}
return true;
}
bool SafetensorsLoader::mmap_shard_(int shard_id) {
ShardFile& sh = shards_[shard_id];
if (sh.mmap_ptr) return true;
sh.fd = ::open(sh.path.c_str(), O_RDONLY);
if (sh.fd < 0) {
perror("open");
return false;
}
struct stat st;
if (fstat(sh.fd, &st) != 0) return false;
sh.mmap_size = st.st_size;
sh.mmap_ptr = mmap(nullptr, sh.mmap_size, PROT_READ, MAP_PRIVATE, sh.fd, 0);
if (sh.mmap_ptr == MAP_FAILED) {
perror("mmap");
sh.mmap_ptr = nullptr;
return false;
}
return true;
}
// Look up the metadata recorded for tensor `name`.
// @return pointer into tensors_, or nullptr if no parsed header declared it.
const TensorMeta* SafetensorsLoader::get(const std::string& name) const {
    const auto found = tensors_.find(name);
    return found != tensors_.end() ? &found->second : nullptr;
}
const void* SafetensorsLoader::data_ptr(const TensorMeta& m) {
if (m.shard_id < 0 || (size_t)m.shard_id >= shards_.size()) return nullptr;
if (!mmap_shard_(m.shard_id)) return nullptr;
ShardFile& sh = shards_[m.shard_id];
return (const char*)sh.mmap_ptr + m.offset;
}
const void* SafetensorsLoader::data_ptr(const std::string& name) {
const auto* m = get(name);
if (!m) return nullptr;
return data_ptr(*m);
}
// Snapshot the names of all indexed tensors, in tensors_'s iteration order.
std::vector<std::string> SafetensorsLoader::list_tensor_names() const {
    std::vector<std::string> names;
    names.reserve(tensors_.size());
    for (const auto& entry : tensors_) {
        names.emplace_back(entry.first);
    }
    return names;
}
// Sum of the byte sizes recorded for every indexed tensor (header
// data_offsets spans; header bytes themselves are not counted).
size_t SafetensorsLoader::total_bytes() const {
    size_t total = 0;
    for (const auto& entry : tensors_) {
        total += entry.second.nbytes;
    }
    return total;
}