// KittenTTS-WebGPU / src/worker.ts
// Uploaded by shreyask using huggingface_hub (commit f8290dd, verified).
/**
* Web Worker — KittenTTS inference via ONNX Runtime Web (WebGPU/WASM).
*
* Models: https://huggingface.co/KittenML
* Phonemizer: https://github.com/xenova/phonemizer.js (Xenova)
* ONNX Runtime Web: https://onnxruntime.ai
*/
import { tokenize } from "./lib/text-cleaner";
import { loadVoices, type VoiceInfo } from "./lib/npz-reader";
// Dynamic imports — resolved at runtime to avoid Vite dev server transform issues.
// Both bindings stay unassigned until loadModel() has imported "phonemizer" and
// "onnxruntime-web"; nothing may call them before that.
let phonemize: (text: string, lang: string) => Promise<string[]>;
let ort: any;
// Host prefix used by resolveUrl() to build download URLs.
const HF_BASE = "https://huggingface.co";
// Sample rate (Hz) reported to the main thread alongside generated audio.
const SAMPLE_RATE = 24000;
// Only nano (fp32) confirmed working on WebGPU; micro/mini are int8 quantized.
// loadModel() treats these as substrings of the model name: a name containing
// ANY one of them is allowed to use WebGPU.
// NOTE(review): a name containing "nano" but not "fp32" also passes — confirm that is intended.
const WEBGPU_SAFE_MODELS = ["Nano", "nano", "fp32"];
/**
 * Shape of the TTS config JSON that loadModel() fetches from the model repo
 * (`kitten_config.json`, falling back to `config.json`). The parsed JSON is
 * cast to this type without runtime validation.
 */
interface ModelConfig {
// Reported to the main thread as `modelName` in the "ready" message.
name: string;
// Not read anywhere in this file.
version: string;
// Not read anywhere in this file.
type: string;
// Model identifier; matched against WEBGPU_SAFE_MODELS to decide whether WebGPU may be used.
model: string;
// Repo-relative path of the ONNX file (ignored for "onnx-community/" repos, which use onnx/model.onnx).
model_file: string;
// Repo-relative path of the voices archive passed to loadVoices().
voices: string;
// Per-voice-id multiplier applied to the requested speed in generateChunk().
speed_priors: Record<string, number>;
// Maps a user-facing voice name to the voice id used to look up voice data.
voice_aliases: Record<string, string>;
}
// ONNX Runtime inference session; null until loadModel() has created it.
let session: any = null;
// Voice data returned by loadVoices(), looked up by voice id in generateChunk().
let voices: Record<string, VoiceInfo> = {};
// Parsed model config; null until loadModel() has fetched it.
let config: ModelConfig | null = null;
// Execution provider selected by loadModel(); "wasm" until a model has been loaded.
let currentDevice: "webgpu" | "wasm" = "wasm";
/**
 * Builds the download URL of a file on the `main` revision of a repo.
 *
 * @param repoId   Repository id, e.g. `owner/name`.
 * @param filename Path of the file inside the repository.
 * @returns `${HF_BASE}/<repoId>/resolve/main/<filename>`.
 */
function resolveUrl(repoId: string, filename: string): string {
  const segments = [HF_BASE, repoId, "resolve", "main", filename];
  return segments.join("/");
}
/**
 * Reports whether a WebGPU adapter can be obtained in the current context.
 *
 * @returns `true` only when `navigator.gpu` exists and `requestAdapter()`
 *          resolves to a truthy adapter; any error is treated as "no WebGPU".
 */
async function detectWebGPU(): Promise<boolean> {
  try {
    if ("gpu" in navigator) {
      const gpuAdapter = await (navigator as any).gpu.requestAdapter();
      return Boolean(gpuAdapter);
    }
  } catch {
    // Deliberately swallowed: a failing probe simply means WebGPU is unavailable.
  }
  return false;
}
/**
 * Loads everything needed for synthesis from a Hugging Face repo and posts
 * progress to the main thread.
 *
 * Steps: probe WebGPU, import the runtime libraries, fetch the TTS config,
 * pick the execution provider, download the ONNX model and voices in
 * parallel, create the inference session, then post a "ready" message.
 *
 * Messages posted: "status" (progress text), "device" (chosen provider) and
 * "ready" (voice names, device, model name).
 *
 * @param repoId Repository id, e.g. `onnx-community/...` or `KittenML/...`.
 * @throws Error when neither config file nor the model file can be fetched;
 *         errors from the voices loader or session creation propagate as-is.
 */
async function loadModel(repoId: string) {
  self.postMessage({ type: "status", message: "Detecting hardware..." });
  const hasWebGPU = await detectWebGPU();

  // Load runtime dependencies
  self.postMessage({ type: "status", message: "Loading runtime..." });
  const [ortModule, phonemizerModule] = await Promise.all([
    import("onnxruntime-web"),
    import("phonemizer"),
  ]);
  ort = ortModule;
  phonemize = phonemizerModule.phonemize;

  // Load config (onnx-community repos use kitten_config.json for the TTS config)
  self.postMessage({ type: "status", message: "Loading config..." });
  let configResp = await fetch(resolveUrl(repoId, "kitten_config.json"));
  if (!configResp.ok) {
    // Fallback to config.json for original KittenML repos
    configResp = await fetch(resolveUrl(repoId, "config.json"));
  }
  if (!configResp.ok) {
    // Without this check an HTTP error page would be parsed as the config.
    throw new Error(`Failed to fetch model config for "${repoId}": ${configResp.status}`);
  }
  const loadedConfig = (await configResp.json()) as ModelConfig;
  config = loadedConfig;

  // Only use WebGPU for models confirmed to work (nano-fp32)
  const modelName = loadedConfig.model || repoId.split("/").pop() || "";
  const isSafe = WEBGPU_SAFE_MODELS.some((m) => modelName.includes(m));
  currentDevice = hasWebGPU && isSafe ? "webgpu" : "wasm";
  if (hasWebGPU && !isSafe) {
    console.log(`[KittenTTS] Using WASM for "${modelName}" (WebGPU only confirmed for nano-fp32)`);
  }
  self.postMessage({ type: "device", device: currentDevice });

  // Load voices (.npz) and ONNX model in parallel
  self.postMessage({ type: "status", message: "Downloading model & voices..." });

  // onnx-community repos store the model at onnx/model.onnx
  const isOnnxCommunity = repoId.startsWith("onnx-community/");
  const modelFile = isOnnxCommunity ? "onnx/model.onnx" : loadedConfig.model_file;
  const modelUrl = resolveUrl(repoId, modelFile);

  const modelPromise = (async () => {
    const resp = await fetch(modelUrl);
    if (!resp.ok) throw new Error(`Failed to fetch model: ${resp.status}`);

    // No readable stream available: fall back to a buffered read without progress.
    if (!resp.body) return resp.arrayBuffer();

    const contentLength = parseInt(resp.headers.get("content-length") || "0", 10);
    const reader = resp.body.getReader();
    const chunks: Uint8Array[] = [];
    let loaded = 0;
    let lastPct = -1;
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      loaded += value.length;
      if (contentLength > 0) {
        // content-length may describe the encoded (compressed) size while the
        // stream yields decoded bytes, so clamp instead of reporting >100%.
        const pct = Math.min(100, Math.round((loaded / contentLength) * 100));
        // Post only when the percentage changes to avoid one message per chunk.
        if (pct !== lastPct) {
          lastPct = pct;
          const mb = (loaded / 1024 / 1024).toFixed(1);
          self.postMessage({
            type: "status",
            message: `Downloading model... ${pct}% (${mb} MB)`,
          });
        }
      }
    }

    // Concatenate the streamed chunks into one contiguous buffer.
    const modelData = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
      modelData.set(chunk, offset);
      offset += chunk.length;
    }
    return modelData.buffer;
  })();

  const voicesUrl = resolveUrl(repoId, loadedConfig.voices);
  const voicesPromise = loadVoices(voicesUrl);
  const [modelBuffer, loadedVoices] = await Promise.all([modelPromise, voicesPromise]);
  voices = loadedVoices;

  // Create ONNX inference session
  self.postMessage({
    type: "status",
    message: `Initializing ${currentDevice.toUpperCase()} session...`,
  });
  const sessionOptions: any = {
    executionProviders: currentDevice === "webgpu" ? ["webgpu"] : ["wasm"],
  };
  if (currentDevice === "wasm") {
    ort.env.wasm.numThreads = 1;
  }
  session = await ort.InferenceSession.create(modelBuffer, sessionOptions);

  // Prefer alias names, but fall back to the raw voice ids when the config has
  // no aliases (missing OR empty), so the UI never receives an empty list
  // while voices are actually loaded.
  const aliasNames = Object.keys(loadedConfig.voice_aliases ?? {});
  const voiceNames = aliasNames.length > 0 ? aliasNames : Object.keys(voices);

  self.postMessage({
    type: "ready",
    voices: voiceNames,
    device: currentDevice,
    modelName: loadedConfig.name,
  });
}
/**
 * Trims the text and makes sure it ends with a punctuation mark.
 *
 * @param text Raw text fragment.
 * @returns The trimmed text, with "." appended unless it is empty or already
 *          ends in one of `. ! ? , ; :`.
 */
function ensurePunctuation(text: string): string {
  const trimmed = text.trim();
  if (trimmed.length === 0) return trimmed;
  const lastChar = trimmed.charAt(trimmed.length - 1);
  return ".!?,;:".includes(lastChar) ? trimmed : `${trimmed}.`;
}
/**
 * Splits text into chunks for synthesis.
 *
 * The text is first cut at runs of `.`, `!` or `?` (the punctuation stays with
 * the sentence). A sentence longer than `maxLen` is further split on
 * whitespace into groups of words. Every emitted chunk is passed through
 * ensurePunctuation().
 *
 * @param text   Input text.
 * @param maxLen Length limit applied per sentence / word group (default 400).
 * @returns The chunks in reading order; empty when the text is blank.
 */
function chunkText(text: string, maxLen = 400): string[] {
  const sentencePattern = /[^.!?]*[.!?]+|[^.!?]+$/g;
  const pieces: string[] = text.match(sentencePattern) ?? [text];
  const result: string[] = [];
  const emit = (fragment: string): void => {
    result.push(ensurePunctuation(fragment));
  };

  pieces.forEach((piece) => {
    const sentence = piece.trim();
    if (sentence.length === 0) return;

    if (sentence.length <= maxLen) {
      emit(sentence);
      return;
    }

    // Over-long sentence: pack words greedily into groups.
    let pending = "";
    for (const word of sentence.split(/\s+/)) {
      const fits = pending.length + word.length + 1 <= maxLen;
      if (fits) {
        pending = pending ? `${pending} ${word}` : word;
      } else {
        if (pending) emit(pending);
        pending = word;
      }
    }
    if (pending) emit(pending);
  });

  return result;
}
/**
 * Splits a string into word tokens and single-symbol tokens, dropping whitespace.
 *
 * A word token is a run of Unicode letters, numbers or `_`; every other
 * non-whitespace character becomes its own token. Unicode property escapes
 * are used because JS `\w` only covers ASCII, whereas Python's `\w` also
 * matches Unicode word characters such as IPA symbols.
 *
 * @param text Text to tokenize.
 * @returns Tokens in order of appearance; empty when nothing matches.
 */
function basicTokenize(text: string): string[] {
  const tokenPattern = /[\p{L}\p{N}_]+|[^\p{L}\p{N}_\s]/gu;
  return Array.from(text.matchAll(tokenPattern), (hit) => hit[0]);
}
/**
 * Synthesizes audio for one text chunk with the loaded model.
 *
 * Pipeline: resolve the voice (alias -> voice id), scale the speed by the
 * voice's speed prior, phonemize the text while keeping punctuation, tokenize
 * the phonemes, pick a style vector by text length, run the ONNX session and
 * return the first output as audio samples.
 *
 * @param text     Chunk of text to speak.
 * @param voiceKey Voice alias or voice id.
 * @param speed    Requested speed; multiplied by the voice's speed prior if one is configured.
 * @returns Audio samples from the model's first output. Outputs longer than one
 *          second have their last 5000 samples trimmed.
 * @throws Error when no model is loaded or the voice cannot be found.
 */
async function generateChunk(
  text: string,
  voiceKey: string,
  speed: number
): Promise<Float32Array> {
  if (!session || !config) throw new Error("Model not loaded");

  // An alias maps to a voice id; anything else is treated as a voice id directly.
  const voiceId = config.voice_aliases?.[voiceKey] || voiceKey;
  const voiceData = voices[voiceId];
  if (!voiceData) throw new Error(`Voice "${voiceKey}" not found`);

  const speedPrior = config.speed_priors?.[voiceId];
  const effectiveSpeed = speedPrior ? speed * speedPrior : speed;

  // Phonemize text preserving punctuation (matching Python's preserve_punctuation=True).
  // Split on punctuation, phonemize non-punctuation segments, rejoin with punctuation.
  // The class covers straight AND typographic double quotes (\u201C \u201D) so curly
  // quotes are kept as punctuation instead of being handed to the phonemizer.
  const PUNCT_RE = /(\s*[;:,.!?¡¿—…"«»\u201C\u201D()\[\]{}]+\s*)+/g;
  const sections: { match: boolean; text: string }[] = [];
  let lastIdx = 0;
  for (const m of text.matchAll(PUNCT_RE)) {
    const start = m.index!;
    if (lastIdx < start) {
      sections.push({ match: false, text: text.slice(lastIdx, start) });
    }
    sections.push({ match: true, text: m[0] });
    lastIdx = start + m[0].length;
  }
  if (lastIdx < text.length) {
    sections.push({ match: false, text: text.slice(lastIdx) });
  }

  // Phonemize only non-punctuation sections
  const phonemeParts = await Promise.all(
    sections.map(async (s) => {
      if (s.match) return s.text; // keep punctuation as-is
      const result = await phonemize(s.text, "en-us");
      return result.join(" ");
    })
  );
  const phonemesRaw = phonemeParts.join("");
  const phonemeTokens = basicTokenize(phonemesRaw);
  const phonemesJoined = phonemeTokens.join(" ");
  const inputIds = tokenize(phonemesJoined);

  // Select voice style reference based on text length (matches Python logic):
  // row index = text length, clamped to the last available row.
  const refId = Math.min(text.length, voiceData.shape[0] - 1);
  const styleDim = voiceData.shape[1];
  const refStyle = voiceData.data.slice(refId * styleDim, (refId + 1) * styleDim);

  // Create ONNX tensors
  const inputIdsTensor = new ort.Tensor(
    "int64",
    BigInt64Array.from(inputIds.map(BigInt)),
    [1, inputIds.length]
  );
  const styleTensor = new ort.Tensor("float32", refStyle, [1, styleDim]);
  const speedTensor = new ort.Tensor("float32", new Float32Array([effectiveSpeed]), [1]);

  // Run inference
  const results = await session.run({
    input_ids: inputIdsTensor,
    style: styleTensor,
    speed: speedTensor,
  });

  // Get output audio
  const outputKey = session.outputNames[0];
  const audioData = results[outputKey].data as Float32Array;

  // Check for NaN — if detected, the model doesn't work on this backend.
  // Only the first sample is probed; the audio is still returned.
  const hasNaN = audioData.length > 0 && isNaN(audioData[0]);
  if (hasNaN) {
    console.warn(`[KittenTTS] Model produced NaN audio — this model may not be compatible with ${currentDevice.toUpperCase()}`);
  }

  // Python trims audio[..., :-5000] but this can cut real audio on short clips.
  // Only trim if audio is longer than one second (SAMPLE_RATE samples).
  if (audioData.length > SAMPLE_RATE) {
    return audioData.slice(0, audioData.length - 5000);
  }
  return audioData;
}
/**
 * Synthesizes the whole text and posts the result to the main thread.
 *
 * The text is split with chunkText(); chunks are synthesized one after another
 * with generateChunk() and concatenated. Posts a "status" message up front, a
 * "progress" message before each chunk, and finally an "audio" message whose
 * buffer is transferred. Any failure is reported as an "error" message instead
 * of being thrown.
 *
 * @param text  Text to speak.
 * @param voice Voice alias or voice id.
 * @param speed Requested speed.
 */
async function generate(text: string, voice: string, speed: number) {
  try {
    const pieces = chunkText(text);
    const pieceCount = pieces.length;
    const plural = pieceCount > 1 ? "s" : "";
    self.postMessage({
      type: "status",
      message: `Generating (${pieceCount} chunk${plural})...`,
    });

    // Chunks are rendered sequentially, not in parallel.
    const rendered: Float32Array[] = [];
    let sampleCount = 0;
    for (const [index, piece] of pieces.entries()) {
      self.postMessage({
        type: "progress",
        current: index + 1,
        total: pieceCount,
      });
      const samples = await generateChunk(piece, voice, speed);
      rendered.push(samples);
      sampleCount += samples.length;
    }

    // Stitch the per-chunk audio into one contiguous buffer.
    const merged = new Float32Array(sampleCount);
    let writePos = 0;
    rendered.forEach((part) => {
      merged.set(part, writePos);
      writePos += part.length;
    });

    self.postMessage(
      {
        type: "audio",
        audio: merged.buffer,
        sampleRate: SAMPLE_RATE,
      },
      { transfer: [merged.buffer] }
    );
  } catch (err: any) {
    self.postMessage({ type: "error", error: err.message || String(err) });
  }
}
// Message handler: dispatches on `action`; the remaining fields are the payload.
// Unknown actions are ignored.
self.addEventListener("message", async (event) => {
  const { action, ...payload } = event.data;

  if (action === "load") {
    try {
      await loadModel(payload.repoId);
    } catch (err: any) {
      console.error("[KittenTTS Worker] Load error:", err);
      self.postMessage({ type: "error", error: err.message || String(err) });
    }
    return;
  }

  if (action === "generate") {
    // generate() reports its own failures as "error" messages.
    await generate(payload.text, payload.voice, payload.speed);
  }
});

// Forward uncaught errors to the main thread as "error" messages.
self.addEventListener("error", (event) => {
  const error = event.message || "Unknown worker error";
  self.postMessage({ type: "error", error });
});

// Forward unhandled promise rejections the same way.
self.addEventListener("unhandledrejection", (event: any) => {
  const reason = event.reason;
  self.postMessage({ type: "error", error: reason?.message || String(reason) });
});