KittenTTS-WebGPU / src /lib /npz-reader.ts
shreyask's picture
feat: KittenTTS WebGPU browser demo
9b1aef8 verified
/**
 * Voice data loader — loads KittenTTS voice embeddings.
 *
 * Downloads the voices .npz archive (a zip of .npy arrays) and decodes it
 * with a small built-in zip + npy parser that copies each array's payload
 * into a fresh, properly aligned buffer before creating typed-array views.
 */
/**
 * Decoded contents of one `.npy` entry of the voices archive, as placed in
 * the record returned by `loadVoices`.
 */
export interface VoiceInfo {
/** Array values as 32-bit floats (presumably flat, row-major — confirm against the consumer). */
data: Float32Array;
/** Two-element shape tuple describing how `data` is laid out. */
shape: [number, number]; // [numStyles, styleDim]
}
/**
 * Parse the preamble and dictionary header of a NumPy `.npy` file.
 *
 * Layout: 6 magic bytes (`\x93NUMPY`), a major and a minor version byte, a
 * little-endian header length (uint16 for major version 1, uint32 otherwise),
 * then a Python-dict literal describing the array.
 *
 * @param bytes - Raw `.npy` contents; may be a view into a larger buffer.
 * @returns `descr` (the dtype string), `shape` (the integers found in the
 *   shape tuple; empty for `()` or when no tuple is found) and `dataOffset`
 *   (byte offset, relative to `bytes`, at which the array payload starts).
 * @throws Error if the magic does not match, if the buffer ends before the
 *   declared header does, or if no `descr` entry can be found.
 */
function parseNpyHeader(bytes: Uint8Array) {
  // Magic: \x93NUMPY
  const magic = [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59];
  if (!magic.every((expected, i) => bytes[i] === expected)) {
    throw new Error("Not a valid .npy file");
  }

  const majorVersion = bytes[6];
  // The header length field starts at byte 8 and is 2 bytes wide in format
  // version 1 and 4 bytes wide otherwise.
  const lengthFieldSize = majorVersion === 1 ? 2 : 4;
  const headerOffset = 8 + lengthFieldSize;
  if (bytes.length < headerOffset) {
    throw new Error("Truncated .npy header");
  }

  const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
  const headerLen =
    majorVersion === 1 ? view.getUint16(8, true) : view.getUint32(8, true);

  const dataOffset = headerOffset + headerLen;
  if (dataOffset > bytes.length) {
    // The declared header would run past the end of the buffer.
    throw new Error("Truncated .npy header");
  }

  // subarray() is a view, so decoding the header does not copy it first.
  const headerStr = new TextDecoder().decode(
    bytes.subarray(headerOffset, dataOffset)
  );

  const descrMatch = headerStr.match(/'descr'\s*:\s*'([^']+)'/);
  const shapeMatch = headerStr.match(/'shape'\s*:\s*\(([^)]*)\)/);
  if (!descrMatch) {
    throw new Error("Could not parse dtype from .npy header: " + headerStr);
  }
  const descr = descrMatch[1];

  // "(8,)" splits into ["8", ""]; the empty piece parses to NaN and is dropped.
  const shape = shapeMatch
    ? shapeMatch[1]
        .split(",")
        .map((s) => parseInt(s.trim(), 10))
        .filter((n) => !isNaN(n))
    : [];

  return { descr, shape, dataOffset };
}
/**
 * Decode a `.npy` file into a `Float32Array`.
 *
 * Only the dtype descriptors `<f4` / `float32` and `<f8` / `float64` are
 * accepted; float64 values are narrowed to float32. The payload is read
 * through typed arrays, which use the host byte order, so little-endian data
 * is decoded correctly only on a little-endian host.
 *
 * @param bytes - Raw `.npy` contents; may be an unaligned view into a larger buffer.
 * @returns The decoded values together with the shape reported by the header.
 * @throws Error if the dtype is not one of the supported descriptors.
 * @throws RangeError if the payload length is not a whole number of elements.
 */
function npyToFloat32(bytes: Uint8Array): { data: Float32Array; shape: number[] } {
  const { descr, shape, dataOffset } = parseNpyHeader(bytes);

  const isFloat32 = descr === "<f4" || descr === "float32";
  const isFloat64 = descr === "<f8" || descr === "float64";
  if (!isFloat32 && !isFloat64) {
    throw new Error("Unsupported npy dtype: " + descr);
  }

  // View the payload without copying, then copy it exactly once into a fresh
  // ArrayBuffer so the typed-array view below starts at an aligned offset 0.
  const payload = bytes.subarray(dataOffset);
  const elementSize = isFloat32 ? 4 : 8;
  if (payload.length % elementSize !== 0) {
    throw new RangeError(
      `npy payload of ${payload.length} bytes is not a multiple of the ` +
        `${elementSize}-byte element size for dtype ${descr}`
    );
  }
  const aligned = new ArrayBuffer(payload.length);
  new Uint8Array(aligned).set(payload);

  if (isFloat32) {
    return { data: new Float32Array(aligned), shape };
  }

  // Constructing from another typed array converts element by element.
  return { data: new Float32Array(new Float64Array(aligned)), shape };
}
/**
 * Extract the entries of an in-memory zip archive.
 *
 * Sizes, offsets and compression methods are taken from the Central
 * Directory rather than from the local file headers, so entries written with
 * a data descriptor (flag bit 3, sizes absent from the local header) are
 * still read correctly. Stored (method 0) and deflated (method 8) entries are
 * supported; entries using any other method are skipped with a console
 * warning. Only the 32-bit size and offset fields are read; ZIP64 extension
 * records are not interpreted.
 *
 * @param buffer - Complete contents of the zip archive.
 * @returns Map from entry file name to its uncompressed bytes, in Central
 *   Directory order.
 * @throws Error if no End of Central Directory record is found, if an entry's
 *   local file header is missing or has the wrong signature, or if an entry's
 *   data would extend past the end of the buffer.
 */
async function extractZipEntries(
  buffer: ArrayBuffer
): Promise<Map<string, Uint8Array>> {
  const EOCD_SIGNATURE = 0x06054b50;
  const CENTRAL_SIGNATURE = 0x02014b50;
  const LOCAL_SIGNATURE = 0x04034b50;

  const bytes = new Uint8Array(buffer);
  const view = new DataView(buffer);
  const decoder = new TextDecoder();
  const entries = new Map<string, Uint8Array>();

  // Locate the End of Central Directory record by scanning backwards; its
  // fixed part is 22 bytes, so it cannot start any later than length - 22.
  let eocdOffset = -1;
  for (let i = bytes.length - 22; i >= 0; i--) {
    if (view.getUint32(i, true) === EOCD_SIGNATURE) {
      eocdOffset = i;
      break;
    }
  }
  if (eocdOffset === -1) {
    throw new Error("Could not find End of Central Directory");
  }

  const cdOffset = view.getUint32(eocdOffset + 16, true);
  const cdEntries = view.getUint16(eocdOffset + 10, true);

  // Parse Central Directory entries to get accurate sizes and offsets.
  interface CDEntry {
    fileName: string;
    compressedSize: number;
    uncompressedSize: number;
    localHeaderOffset: number;
    compressionMethod: number;
  }

  const cdList: CDEntry[] = [];
  let cdPos = cdOffset;
  for (let i = 0; i < cdEntries; i++) {
    // Best effort: stop at the first record that is not a directory header.
    if (view.getUint32(cdPos, true) !== CENTRAL_SIGNATURE) break;

    const fileNameLen = view.getUint16(cdPos + 28, true);
    const extraLen = view.getUint16(cdPos + 30, true);
    const commentLen = view.getUint16(cdPos + 32, true);

    cdList.push({
      fileName: decoder.decode(
        bytes.subarray(cdPos + 46, cdPos + 46 + fileNameLen)
      ),
      compressedSize: view.getUint32(cdPos + 20, true),
      uncompressedSize: view.getUint32(cdPos + 24, true),
      localHeaderOffset: view.getUint32(cdPos + 42, true),
      compressionMethod: view.getUint16(cdPos + 10, true),
    });

    cdPos += 46 + fileNameLen + extraLen + commentLen;
  }

  // Extract each entry: the local header only tells us where the data starts.
  for (const cd of cdList) {
    if (cd.compressionMethod !== 0 && cd.compressionMethod !== 8) {
      console.warn(`Skipping ${cd.fileName}: unsupported compression ${cd.compressionMethod}`);
      continue;
    }

    const lhOffset = cd.localHeaderOffset;
    if (
      lhOffset + 30 > bytes.length ||
      view.getUint32(lhOffset, true) !== LOCAL_SIGNATURE
    ) {
      throw new Error(`Invalid local file header for ${cd.fileName}`);
    }
    const lhFileNameLen = view.getUint16(lhOffset + 26, true);
    const lhExtraLen = view.getUint16(lhOffset + 28, true);
    const dataStart = lhOffset + 30 + lhFileNameLen + lhExtraLen;

    // Stored entries are sliced by their uncompressed size, deflated ones by
    // their compressed size.
    const storedSize =
      cd.compressionMethod === 0 ? cd.uncompressedSize : cd.compressedSize;
    if (dataStart + storedSize > bytes.length) {
      // slice() would silently clamp and hand back a short entry.
      throw new Error(`Truncated data for ${cd.fileName}`);
    }

    let fileData: Uint8Array;
    if (cd.compressionMethod === 0) {
      // Stored
      fileData = bytes.slice(dataStart, dataStart + storedSize);
    } else {
      // Deflate
      const compressed = bytes.slice(dataStart, dataStart + storedSize);
      const ds = new DecompressionStream("deflate-raw");
      const writer = ds.writable.getWriter();
      // Feed the stream without awaiting: waiting for the write before the
      // readable side is drained could stall on backpressure. A decoding
      // failure also errors the readable side and is reported by
      // reader.read() below; the no-op handlers only keep these two promises
      // from surfacing as unhandled rejections.
      writer.write(compressed).catch(() => undefined);
      writer.close().catch(() => undefined);

      const reader = ds.readable.getReader();
      const chunks: Uint8Array[] = [];
      let totalLen = 0;
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        chunks.push(value);
        totalLen += value.length;
      }

      // Concatenate the decompressed chunks into one contiguous array.
      fileData = new Uint8Array(totalLen);
      let pos = 0;
      for (const chunk of chunks) {
        fileData.set(chunk, pos);
        pos += chunk.length;
      }
    }

    entries.set(cd.fileName, fileData);
  }

  return entries;
}
/**
 * Normalise an `.npy` shape to the `[numStyles, styleDim]` pair used by
 * `VoiceInfo`.
 *
 * - rank 0 or 1: a single style spanning all `length` values;
 * - rank 2 or more: the first axis counts styles and the remaining axes are
 *   flattened into the style dimension.
 *
 * A zero first axis falls back to 1 and a zero style dimension falls back to
 * `length`.
 */
function toStyleShape(shape: number[], length: number): [number, number] {
  if (shape.length < 2) return [1, length];
  const numStyles = shape[0] || 1;
  const styleDim = shape.slice(1).reduce((acc, n) => acc * n, 1) || length;
  return [numStyles, styleDim];
}

/**
 * Load voice embeddings from a .npz file URL.
 *
 * Fetches the archive, extracts its zip entries and decodes every entry whose
 * name ends in `.npy`; other entries are ignored. Each voice is keyed by its
 * entry name with the `.npy` suffix removed.
 *
 * @param url - Location of the `.npz` archive.
 * @returns Record mapping voice name to its decoded data and shape.
 * @throws Error if the HTTP response is not OK. Errors raised while reading
 *   the archive or decoding an entry propagate unchanged.
 */
export async function loadVoices(
  url: string
): Promise<Record<string, VoiceInfo>> {
  const response = await fetch(url);
  if (!response.ok) throw new Error(`Failed to fetch voices: ${response.status}`);

  const arrayBuffer = await response.arrayBuffer();
  const entries = await extractZipEntries(arrayBuffer);

  const voices: Record<string, VoiceInfo> = {};
  for (const [fileName, fileData] of entries) {
    if (!fileName.endsWith(".npy")) continue;
    const voiceName = fileName.replace(/\.npy$/, "");
    const { data, shape } = npyToFloat32(fileData);
    // A 1-D array of N values is one style of dimension N, not N styles of
    // dimension N, so the shape is normalised rather than read positionally.
    voices[voiceName] = { data, shape: toStyleShape(shape, data.length) };
  }
  return voices;
}