/**
 * Web Worker — KittenTTS inference via ONNX Runtime Web (WebGPU/WASM).
 *
 * Models: https://huggingface.co/KittenML
 * Phonemizer: https://github.com/xenova/phonemizer.js (Xenova)
 * ONNX Runtime Web: https://onnxruntime.ai
 *
 * Message protocol (as handled at the bottom of this file):
 *   in:  { action: "load", repoId } | { action: "generate", text, voice, speed }
 *   out: { type: "status" | "device" | "ready" | "progress" | "audio" | "error", ... }
 */

import { tokenize } from "./lib/text-cleaner";
import { loadVoices, type VoiceInfo } from "./lib/npz-reader";

// Dynamic imports — resolved at runtime to avoid Vite dev server transform issues.
let phonemize: (text: string, lang: string) => Promise<string[]>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- namespace of the dynamically imported onnxruntime-web
let ort: any;

const HF_BASE = "https://huggingface.co";
const SAMPLE_RATE = 24000;

/** Samples dropped from the end of each clip (mirrors Python's `audio[..., :-5000]`). */
const TRAILING_TRIM_SAMPLES = 5000;

// Only nano (fp32) confirmed working on WebGPU; micro/mini are int8 quantized.
const WEBGPU_SAFE_MODELS = ["Nano", "nano", "fp32"];

/**
 * Runs of punctuation (plus surrounding whitespace) that are kept verbatim
 * instead of being phonemized. Mirrors Python phonemizer's default punctuation
 * marks (`;:,.!?¡¿—…"«»“”(){}[]`) used with `preserve_punctuation=True`.
 * Hoisted to module scope; `matchAll` clones the regex, so the shared `g` flag
 * state is never mutated between calls.
 */
const PUNCT_RE = /(\s*[;:,.!?¡¿—…"«»“”()\[\]{}]+\s*)+/g;

type Device = "webgpu" | "wasm";

interface ModelConfig {
  name: string;
  version: string;
  type: string;
  model: string;
  model_file: string;
  voices: string;
  speed_priors: Record<string, number>;
  voice_aliases: Record<string, string>;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any -- ort.InferenceSession from the dynamic import
let session: any = null;
let voices: Record<string, VoiceInfo> = {};
let config: ModelConfig | null = null;
let currentDevice: Device = "wasm";

/** Builds the Hugging Face "resolve" URL for a file in the repo's main branch. */
function resolveUrl(repoId: string, filename: string): string {
  return `${HF_BASE}/${repoId}/resolve/main/${filename}`;
}

/** Extracts a human-readable message from any thrown/rejected value. */
function errorMessage(err: unknown): string {
  if (typeof err === "object" && err !== null && "message" in err) {
    const { message } = err as { message: unknown };
    if (typeof message === "string" && message) return message;
  }
  return String(err);
}

/** Resolves true when `navigator.gpu` exists and yields an adapter; never rejects. */
async function detectWebGPU(): Promise<boolean> {
  try {
    if (!("gpu" in navigator)) return false;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- WebGPU types may not be in the lib config
    const adapter = await (navigator as any).gpu.requestAdapter();
    return !!adapter;
  } catch {
    return false;
  }
}

/**
 * Fetches the TTS config. onnx-community repos use `kitten_config.json`;
 * original KittenML repos use `config.json`.
 *
 * @throws Error when neither file can be fetched.
 */
async function fetchConfig(repoId: string): Promise<ModelConfig> {
  let resp = await fetch(resolveUrl(repoId, "kitten_config.json"));
  if (!resp.ok) {
    // Fallback to config.json for original KittenML repos
    resp = await fetch(resolveUrl(repoId, "config.json"));
  }
  if (!resp.ok) {
    // Without this check, .json() on an error page throws an opaque parse error.
    throw new Error(`Failed to fetch config for "${repoId}": ${resp.status}`);
  }
  return (await resp.json()) as ModelConfig;
}

/**
 * Downloads the ONNX model, posting percentage progress when the server sends
 * a content-length.
 *
 * @throws Error on a non-2xx response.
 */
async function downloadModel(url: string): Promise<Uint8Array> {
  const resp = await fetch(url);
  if (!resp.ok) throw new Error(`Failed to fetch model: ${resp.status}`);
  if (!resp.body) {
    // No stream available: no incremental progress, but still load the bytes.
    return new Uint8Array(await resp.arrayBuffer());
  }

  const contentLength = parseInt(resp.headers.get("content-length") ?? "0", 10);
  const reader = resp.body.getReader();
  const chunks: Uint8Array[] = [];
  let loaded = 0;
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    if (contentLength > 0) {
      // content-length may describe the compressed transfer, so clamp at 100%.
      const pct = Math.min(100, Math.round((loaded / contentLength) * 100));
      const mb = (loaded / 1024 / 1024).toFixed(1);
      self.postMessage({
        type: "status",
        message: `Downloading model... ${pct}% (${mb} MB)`,
      });
    }
  }

  const modelData = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) {
    modelData.set(chunk, offset);
    offset += chunk.length;
  }
  return modelData;
}

/**
 * Creates the ONNX inference session on the preferred device. If WebGPU
 * session creation throws, falls back to WASM and posts the device change.
 */
async function createSession(
  model: Uint8Array,
  preferred: Device
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ort.InferenceSession
): Promise<{ session: any; device: Device }> {
  self.postMessage({
    type: "status",
    message: `Initializing ${preferred.toUpperCase()} session...`,
  });

  if (preferred === "webgpu") {
    try {
      const gpuSession = await ort.InferenceSession.create(model, {
        executionProviders: ["webgpu"],
      });
      return { session: gpuSession, device: "webgpu" };
    } catch (err) {
      console.warn(`[KittenTTS] WebGPU session failed, falling back to WASM: ${errorMessage(err)}`);
      self.postMessage({ type: "device", device: "wasm" });
      self.postMessage({ type: "status", message: "Initializing WASM session..." });
    }
  }

  ort.env.wasm.numThreads = 1;
  const wasmSession = await ort.InferenceSession.create(model, {
    executionProviders: ["wasm"],
  });
  return { session: wasmSession, device: "wasm" };
}

/**
 * Loads runtime, config, voices and model for `repoId`, then posts "ready".
 * Worker globals are only replaced once everything succeeded, so a failed
 * reload leaves the previously loaded model intact and consistent.
 */
async function loadModel(repoId: string): Promise<void> {
  self.postMessage({ type: "status", message: "Detecting hardware..." });
  const hasWebGPU = await detectWebGPU();

  // Load runtime dependencies
  self.postMessage({ type: "status", message: "Loading runtime..." });
  const [ortModule, phonemizerModule] = await Promise.all([
    import("onnxruntime-web"),
    import("phonemizer"),
  ]);
  ort = ortModule;
  phonemize = phonemizerModule.phonemize;

  self.postMessage({ type: "status", message: "Loading config..." });
  const cfg = await fetchConfig(repoId);

  // Only use WebGPU for models confirmed to work (nano-fp32)
  const modelName = cfg.model || repoId.split("/").pop() || "";
  const isSafe = WEBGPU_SAFE_MODELS.some((m) => modelName.includes(m));
  const preferredDevice: Device = hasWebGPU && isSafe ? "webgpu" : "wasm";
  if (hasWebGPU && !isSafe) {
    console.log(`[KittenTTS] Using WASM for "${modelName}" (WebGPU only confirmed for nano-fp32)`);
  }
  self.postMessage({ type: "device", device: preferredDevice });

  // Load voices (.npz) and ONNX model in parallel
  self.postMessage({ type: "status", message: "Downloading model & voices..." });
  // onnx-community repos store the model at onnx/model.onnx
  const isOnnxCommunity = repoId.startsWith("onnx-community/");
  const modelFile = isOnnxCommunity ? "onnx/model.onnx" : cfg.model_file;
  const [modelBytes, loadedVoices] = await Promise.all([
    downloadModel(resolveUrl(repoId, modelFile)),
    loadVoices(resolveUrl(repoId, cfg.voices)),
  ]);

  const created = await createSession(modelBytes, preferredDevice);

  // Commit all state together (see doc comment).
  session = created.session;
  voices = loadedVoices;
  config = cfg;
  currentDevice = created.device;

  const voiceNames = cfg.voice_aliases ? Object.keys(cfg.voice_aliases) : Object.keys(loadedVoices);
  self.postMessage({
    type: "ready",
    voices: voiceNames,
    device: created.device,
    modelName: cfg.name,
  });
}

/** Trims `text` and appends "." unless it already ends in `.!?,;:`. */
function ensurePunctuation(text: string): string {
  const trimmed = text.trim();
  if (!trimmed) return trimmed;
  return ".!?,;:".includes(trimmed[trimmed.length - 1]) ? trimmed : `${trimmed}.`;
}

/**
 * Splits text into sentence chunks of at most ~`maxLen` characters. Sentences
 * longer than that are split on whitespace; each chunk ends with punctuation.
 * Whitespace-only input yields an empty array.
 */
function chunkText(text: string, maxLen = 400): string[] {
  // Split on sentence boundaries but keep the punctuation
  const sentences = text.match(/[^.!?]*[.!?]+|[^.!?]+$/g) || [text];
  const chunks: string[] = [];

  for (const raw of sentences) {
    const sentence = raw.trim();
    if (!sentence) continue;

    if (sentence.length <= maxLen) {
      chunks.push(ensurePunctuation(sentence));
      continue;
    }

    let temp = "";
    for (const word of sentence.split(/\s+/)) {
      if (temp.length + word.length + 1 <= maxLen) {
        temp += (temp ? " " : "") + word;
      } else {
        if (temp) chunks.push(ensurePunctuation(temp));
        temp = word;
      }
    }
    if (temp) chunks.push(ensurePunctuation(temp));
  }
  return chunks;
}

/**
 * Equivalent of Python's `re.findall(r"\w+|[^\w\s]", text)`.
 * Python's \w matches Unicode word chars (including IPA symbols); JS \w only
 * matches [a-zA-Z0-9_], so Unicode property escapes are used instead.
 */
function basicTokenize(text: string): string[] {
  return text.match(/[\p{L}\p{N}_]+|[^\p{L}\p{N}_\s]/gu) || [];
}

/**
 * Phonemizes text while keeping punctuation runs verbatim (matching Python's
 * `preserve_punctuation=True`): only the non-punctuation segments go through
 * the phonemizer, then everything is rejoined in order.
 */
async function phonemizePreservingPunctuation(text: string): Promise<string> {
  const sections: { punct: boolean; text: string }[] = [];
  let lastIdx = 0;
  for (const m of text.matchAll(PUNCT_RE)) {
    const start = m.index!; // always defined for matchAll results
    if (lastIdx < start) {
      sections.push({ punct: false, text: text.slice(lastIdx, start) });
    }
    sections.push({ punct: true, text: m[0] });
    lastIdx = start + m[0].length;
  }
  if (lastIdx < text.length) {
    sections.push({ punct: false, text: text.slice(lastIdx) });
  }

  const parts = await Promise.all(
    sections.map(async (s) => (s.punct ? s.text : (await phonemize(s.text, "en-us")).join(" ")))
  );
  return parts.join("");
}

/**
 * Synthesizes one text chunk with the loaded model.
 *
 * @returns mono float32 audio samples at SAMPLE_RATE.
 * @throws Error when no model is loaded or the voice is unknown.
 */
async function generateChunk(text: string, voiceKey: string, speed: number): Promise<Float32Array> {
  // Snapshot state so a concurrent reload cannot mix models mid-chunk.
  const activeSession = session;
  const activeConfig = config;
  if (!activeSession || !activeConfig) throw new Error("Model not loaded");

  const voiceId = activeConfig.voice_aliases?.[voiceKey] || voiceKey;
  const voiceData = voices[voiceId];
  if (!voiceData) throw new Error(`Voice "${voiceKey}" not found`);

  const prior = activeConfig.speed_priors?.[voiceId];
  const effectiveSpeed = prior ? speed * prior : speed;

  const phonemesRaw = await phonemizePreservingPunctuation(text);
  const inputIds = tokenize(basicTokenize(phonemesRaw).join(" "));

  // Select voice style reference based on text length (matches Python logic)
  const refId = Math.min(text.length, voiceData.shape[0] - 1);
  const styleDim = voiceData.shape[1];
  const refStyle = voiceData.data.slice(refId * styleDim, (refId + 1) * styleDim);

  const inputIdsTensor = new ort.Tensor(
    "int64",
    BigInt64Array.from(inputIds.map(BigInt)),
    [1, inputIds.length]
  );
  const styleTensor = new ort.Tensor("float32", refStyle, [1, styleDim]);
  const speedTensor = new ort.Tensor("float32", new Float32Array([effectiveSpeed]), [1]);

  const results = await activeSession.run({
    input_ids: inputIdsTensor,
    style: styleTensor,
    speed: speedTensor,
  });

  const outputKey = activeSession.outputNames[0];
  const audioData = results[outputKey].data as Float32Array;

  // Check for NaN — if detected, the model doesn't work on this backend
  if (audioData.some(Number.isNaN)) {
    console.warn(`[KittenTTS] Model produced NaN audio — this model may not be compatible with ${currentDevice.toUpperCase()}`);
  }

  // Python trims audio[..., :-5000] but this can cut real audio on short clips.
  // Only trim if audio is long enough (>1 second of samples).
  if (audioData.length > SAMPLE_RATE) {
    return audioData.slice(0, audioData.length - TRAILING_TRIM_SAMPLES);
  }
  return audioData;
}

/**
 * Synthesizes the full text chunk by chunk, posting progress, then posts the
 * concatenated audio (buffer transferred). Any failure is posted as "error".
 */
async function generate(text: string, voice: string, speed: number): Promise<void> {
  try {
    const chunks = chunkText(text);
    if (chunks.length === 0) throw new Error("No text to synthesize");

    self.postMessage({
      type: "status",
      message: `Generating (${chunks.length} chunk${chunks.length > 1 ? "s" : ""})...`,
    });

    const audioChunks: Float32Array[] = [];
    for (let i = 0; i < chunks.length; i++) {
      self.postMessage({ type: "progress", current: i + 1, total: chunks.length });
      audioChunks.push(await generateChunk(chunks[i], voice, speed));
    }

    const totalLen = audioChunks.reduce((sum, c) => sum + c.length, 0);
    const fullAudio = new Float32Array(totalLen);
    let offset = 0;
    for (const chunk of audioChunks) {
      fullAudio.set(chunk, offset);
      offset += chunk.length;
    }

    self.postMessage(
      { type: "audio", audio: fullAudio.buffer, sampleRate: SAMPLE_RATE },
      { transfer: [fullAudio.buffer] }
    );
  } catch (err) {
    self.postMessage({ type: "error", error: errorMessage(err) });
  }
}

// Message handler
self.addEventListener("message", async (e) => {
  const { action, ...data } = e.data;
  switch (action) {
    case "load":
      try {
        await loadModel(data.repoId);
      } catch (err) {
        console.error("[KittenTTS Worker] Load error:", err);
        self.postMessage({ type: "error", error: errorMessage(err) });
      }
      break;
    case "generate":
      await generate(data.text, data.voice, data.speed);
      break;
    default:
      console.warn(`[KittenTTS Worker] Unknown action: ${String(action)}`);
      break;
  }
});

self.addEventListener("error", (e) => {
  self.postMessage({ type: "error", error: e.message || "Unknown worker error" });
});

self.addEventListener("unhandledrejection", (e: PromiseRejectionEvent) => {
  self.postMessage({ type: "error", error: errorMessage(e.reason) });
});