Spaces:
Running
Running
| /** | |
| * Web Worker — KittenTTS inference via ONNX Runtime Web (WebGPU/WASM). | |
| * | |
| * Models: https://huggingface.co/KittenML | |
| * Phonemizer: https://github.com/xenova/phonemizer.js (Xenova) | |
| * ONNX Runtime Web: https://onnxruntime.ai | |
| */ | |
| import { tokenize } from "./lib/text-cleaner"; | |
| import { loadVoices, type VoiceInfo } from "./lib/npz-reader"; | |
// Runtime dependencies. Both are resolved with dynamic import() at load time
// (rather than static imports) to avoid Vite dev server transform issues.

/** Phonemizer entry point; assigned once the "phonemizer" module has been imported. */
let phonemize: (text: string, lang: string) => Promise<string[]>;

/** The onnxruntime-web module namespace; assigned once it has been imported. */
let ort: any;

/** Origin that all model, config and voice files are fetched from. */
const HF_BASE = "https://huggingface.co";

/** Sample rate (Hz) reported to the main thread alongside generated audio. */
const SAMPLE_RATE = 24000;

/**
 * Substrings that mark a model name as safe to run on WebGPU.
 * Only nano (fp32) is confirmed working on WebGPU; micro/mini are int8 quantized.
 */
const WEBGPU_SAFE_MODELS = ["Nano", "nano", "fp32"];
/**
 * Shape of the TTS config JSON fetched from the model repository
 * (`kitten_config.json`, falling back to `config.json`).
 */
interface ModelConfig {
  /** Display name, reported to the main thread once the model is ready. */
  name: string;
  version: string;
  type: string;
  /** Model identifier; matched against the WebGPU allow-list when choosing a backend. */
  model: string;
  /** Repo-relative path of the ONNX file (not used for onnx-community repos). */
  model_file: string;
  /** Repo-relative path of the voices archive. */
  voices: string;
  /** Per-voice multiplier applied to the requested speed, keyed by resolved voice id. */
  speed_priors: Record<string, number>;
  /** Maps user-facing voice names to the voice ids stored in the voices archive. */
  voice_aliases: Record<string, string>;
}
// ---- Mutable worker state, populated by loadModel() ----

/** Active ONNX inference session; null while no model is loaded. */
let session: any = null;

/** Voice data from the voices archive, keyed by voice id. */
let voices: Record<string, VoiceInfo> = {};

/** Parsed config of the current model; null while no model is loaded. */
let config: ModelConfig | null = null;

/** Execution backend chosen for the current model. */
let currentDevice: "webgpu" | "wasm" = "wasm";
/**
 * Builds the download URL of a file on the `main` revision of a repository.
 *
 * @param repoId Repository id in `owner/name` form.
 * @param filename Repo-relative path of the file.
 * @returns Absolute URL under `HF_BASE`.
 */
function resolveUrl(repoId: string, filename: string): string {
  return [HF_BASE, repoId, "resolve", "main", filename].join("/");
}
/**
 * Reports whether a WebGPU adapter can be obtained in this context.
 *
 * @returns `true` only when `navigator.gpu` exists and `requestAdapter()`
 *   resolves to a truthy adapter; any thrown error counts as "no WebGPU".
 */
async function detectWebGPU(): Promise<boolean> {
  try {
    if ("gpu" in navigator) {
      const gpuAdapter = await (navigator as any).gpu.requestAdapter();
      return Boolean(gpuAdapter);
    }
    return false;
  } catch {
    // Missing navigator or a failing adapter request both mean WebGPU is unusable.
    return false;
  }
}
/**
 * Loads a KittenTTS model from a Hugging Face repository and makes it the
 * worker's active model.
 *
 * Steps: detect WebGPU, import the runtime dependencies, fetch the config,
 * choose the execution backend, download the ONNX model and the voices in
 * parallel, create the inference session, then post a "ready" message.
 * Progress is reported to the main thread through "status" and "device"
 * messages along the way.
 *
 * @param repoId Repository id in `owner/name` form.
 * @throws Error when the config or model cannot be fetched, or when the
 *   config lacks the file paths needed to download the model or voices.
 */
async function loadModel(repoId: string) {
  self.postMessage({ type: "status", message: "Detecting hardware..." });
  const hasWebGPU = await detectWebGPU();

  // Load runtime dependencies
  self.postMessage({ type: "status", message: "Loading runtime..." });
  const [ortModule, phonemizerModule] = await Promise.all([
    import("onnxruntime-web"),
    import("phonemizer"),
  ]);
  ort = ortModule;
  phonemize = phonemizerModule.phonemize;

  // Load config (onnx-community repos use kitten_config.json for the TTS config)
  self.postMessage({ type: "status", message: "Loading config..." });
  let configResp = await fetch(resolveUrl(repoId, "kitten_config.json"));
  if (!configResp.ok) {
    // Fallback to config.json for original KittenML repos
    configResp = await fetch(resolveUrl(repoId, "config.json"));
  }
  if (!configResp.ok) {
    // Without this check a 404 page would surface as an opaque JSON parse error.
    throw new Error(`Failed to fetch config: ${configResp.status}`);
  }
  const loadedConfig = (await configResp.json()) as ModelConfig;
  config = loadedConfig;

  // Only use WebGPU for models confirmed to work (nano-fp32)
  const modelName = loadedConfig.model || repoId.split("/").pop() || "";
  const isSafe = WEBGPU_SAFE_MODELS.some((m) => modelName.includes(m));
  currentDevice = hasWebGPU && isSafe ? "webgpu" : "wasm";
  if (hasWebGPU && !isSafe) {
    console.log(`[KittenTTS] Using WASM for "${modelName}" (WebGPU only confirmed for nano-fp32)`);
  }
  self.postMessage({ type: "device", device: currentDevice });

  // Load voices (.npz) and ONNX model in parallel
  self.postMessage({ type: "status", message: "Downloading model & voices..." });

  // onnx-community repos store the model at onnx/model.onnx
  const isOnnxCommunity = repoId.startsWith("onnx-community/");
  const modelFile = isOnnxCommunity ? "onnx/model.onnx" : loadedConfig.model_file;
  if (!modelFile) {
    throw new Error('Model config is missing "model_file"');
  }
  if (!loadedConfig.voices) {
    throw new Error('Model config is missing "voices"');
  }
  const modelUrl = resolveUrl(repoId, modelFile);

  const modelPromise = (async () => {
    const resp = await fetch(modelUrl);
    if (!resp.ok) throw new Error(`Failed to fetch model: ${resp.status}`);
    // No readable stream available: download in one go, without progress updates.
    if (!resp.body) return resp.arrayBuffer();

    const contentLength = parseInt(resp.headers.get("content-length") || "0", 10);
    const reader = resp.body.getReader();
    const chunks: Uint8Array[] = [];
    let loaded = 0;
    let lastPct = -1;
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      loaded += value.length;
      if (contentLength > 0) {
        // content-length may describe the compressed payload, so clamp to 100.
        const pct = Math.min(100, Math.round((loaded / contentLength) * 100));
        // Post only when the percentage moves, not once per network chunk.
        if (pct !== lastPct) {
          lastPct = pct;
          const mb = (loaded / 1024 / 1024).toFixed(1);
          self.postMessage({
            type: "status",
            message: `Downloading model... ${pct}% (${mb} MB)`,
          });
        }
      }
    }

    // Stitch the received chunks into one contiguous buffer.
    const modelData = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
      modelData.set(chunk, offset);
      offset += chunk.length;
    }
    return modelData.buffer;
  })();

  const voicesUrl = resolveUrl(repoId, loadedConfig.voices);
  const voicesPromise = loadVoices(voicesUrl);
  const [modelBuffer, loadedVoices] = await Promise.all([modelPromise, voicesPromise]);
  voices = loadedVoices;

  // Create ONNX inference session
  self.postMessage({
    type: "status",
    message: `Initializing ${currentDevice.toUpperCase()} session...`,
  });
  const sessionOptions: any = {
    executionProviders: currentDevice === "webgpu" ? ["webgpu"] : ["wasm"],
  };
  if (currentDevice === "wasm") {
    ort.env.wasm.numThreads = 1;
  }

  // Free the previous model's resources before replacing the session.
  if (session) {
    const previous = session;
    session = null;
    try {
      await previous.release?.();
    } catch (err) {
      console.warn("[KittenTTS] Failed to release previous session:", err);
    }
  }
  session = await ort.InferenceSession.create(modelBuffer, sessionOptions);

  // Prefer the user-facing alias names when the config provides them.
  const voiceNames = loadedConfig.voice_aliases
    ? Object.keys(loadedConfig.voice_aliases)
    : Object.keys(voices);
  self.postMessage({
    type: "ready",
    voices: voiceNames,
    device: currentDevice,
    modelName: loadedConfig.name,
  });
}
/**
 * Trims `text` and makes sure it ends with a punctuation mark.
 *
 * @param text Raw text fragment.
 * @returns The trimmed text, with "." appended unless it is empty or already
 *   ends in one of `. ! ? , ; :`.
 */
function ensurePunctuation(text: string): string {
  const trimmed = text.trim();
  if (trimmed.length === 0) return trimmed;
  return /[.!?,;:]$/.test(trimmed) ? trimmed : `${trimmed}.`;
}
/**
 * Splits text into chunks for synthesis.
 *
 * Text is first split on sentence boundaries (runs of `.`, `!` or `?`, which
 * stay attached to their sentence). A sentence longer than `maxLen` is then
 * split on whitespace into groups of words, and a single word that is itself
 * longer than `maxLen` is cut into `maxLen`-sized pieces so that no chunk
 * body exceeds the limit. Every chunk is passed through `ensurePunctuation`,
 * which may append one trailing ".".
 *
 * @param text Text to split.
 * @param maxLen Maximum chunk length in UTF-16 code units, before the
 *   optional trailing "." is added.
 * @returns The chunks in reading order; empty when `text` is blank.
 */
function chunkText(text: string, maxLen = 400): string[] {
  // Split on sentence boundaries but keep the punctuation
  const sentences = text.match(/[^.!?]*[.!?]+|[^.!?]+$/g) || [text];
  const chunks: string[] = [];
  // Hard-splitting needs a limit of at least 1 to make progress.
  const canHardSplit = maxLen >= 1;

  for (const rawSentence of sentences) {
    const sentence = rawSentence.trim();
    if (!sentence) continue;

    if (sentence.length <= maxLen) {
      chunks.push(ensurePunctuation(sentence));
      continue;
    }

    // Sentence is too long: pack whole words into chunks of at most maxLen.
    let current = "";
    const flush = () => {
      if (current) {
        chunks.push(ensurePunctuation(current));
        current = "";
      }
    };

    for (const word of sentence.split(/\s+/)) {
      let rest = word;
      // A single token longer than the limit (e.g. a long URL) cannot be
      // packed; cut it into limit-sized pieces instead of emitting it whole.
      while (canHardSplit && rest.length > maxLen) {
        flush();
        chunks.push(ensurePunctuation(rest.slice(0, maxLen)));
        rest = rest.slice(maxLen);
      }
      if (!rest) continue;

      if (current.length + rest.length + 1 <= maxLen) {
        current += (current ? " " : "") + rest;
      } else {
        flush();
        current = rest;
      }
    }
    flush();
  }
  return chunks;
}
/**
 * Splits text into word tokens and single-character punctuation tokens,
 * dropping whitespace.
 *
 * Python's \w matches Unicode word characters (including IPA symbols), while
 * JS \w only matches [a-zA-Z0-9_]; Unicode property escapes are used here to
 * get the Unicode-aware behaviour.
 *
 * @param text Text to tokenize.
 * @returns Tokens in order of appearance; empty when nothing matches.
 */
function basicTokenize(text: string): string[] {
  const tokenPattern = /[\p{L}\p{N}_]+|[^\p{L}\p{N}_\s]/gu;
  return Array.from(text.matchAll(tokenPattern), (found) => found[0]);
}
/**
 * Converts text to a phoneme string while passing punctuation through
 * untouched (matching Python's preserve_punctuation=True).
 *
 * The text is cut into alternating punctuation and non-punctuation sections;
 * only the non-punctuation sections are sent to the phonemizer, and the
 * results are rejoined in the original order.
 */
async function phonemizePreservingPunctuation(text: string): Promise<string> {
  const PUNCT_RE = /(\s*[;:,.!?¡¿—…"«»""()\[\]{}]+\s*)+/g;
  const sections: { match: boolean; text: string }[] = [];
  let lastIdx = 0;
  for (const m of text.matchAll(PUNCT_RE)) {
    const start = m.index ?? lastIdx;
    if (lastIdx < start) {
      sections.push({ match: false, text: text.slice(lastIdx, start) });
    }
    sections.push({ match: true, text: m[0] });
    lastIdx = start + m[0].length;
  }
  if (lastIdx < text.length) {
    sections.push({ match: false, text: text.slice(lastIdx) });
  }

  // Phonemize only non-punctuation sections
  const phonemeParts = await Promise.all(
    sections.map(async (s) => {
      if (s.match) return s.text; // keep punctuation as-is
      const result = await phonemize(s.text, "en-us");
      return result.join(" ");
    })
  );
  return phonemeParts.join("");
}

/**
 * Synthesizes audio for one chunk of text with the loaded model.
 *
 * @param text Chunk text to speak.
 * @param voiceKey Voice name as chosen by the user; resolved through the
 *   config's `voice_aliases` when an alias exists.
 * @param speed Requested speed; multiplied by the voice's speed prior when
 *   the config defines one.
 * @returns The generated audio samples.
 * @throws Error when no model is loaded or the voice is unknown.
 */
async function generateChunk(
  text: string,
  voiceKey: string,
  speed: number
): Promise<Float32Array> {
  if (!session || !config) throw new Error("Model not loaded");

  let voiceId = voiceKey;
  if (config.voice_aliases?.[voiceKey]) {
    voiceId = config.voice_aliases[voiceKey];
  }
  const voiceData = voices[voiceId];
  if (!voiceData) throw new Error(`Voice "${voiceKey}" not found`);
  if (config.speed_priors?.[voiceId]) {
    speed = speed * config.speed_priors[voiceId];
  }

  // Text -> phonemes -> space-separated tokens -> model input ids.
  const phonemesRaw = await phonemizePreservingPunctuation(text);
  const phonemeTokens = basicTokenize(phonemesRaw);
  const phonemesJoined = phonemeTokens.join(" ");
  const inputIds = tokenize(phonemesJoined);

  // Select voice style reference based on text length (matches Python logic)
  const refId = Math.min(text.length, voiceData.shape[0] - 1);
  const styleDim = voiceData.shape[1];
  const refStyle = voiceData.data.slice(refId * styleDim, (refId + 1) * styleDim);

  // Create ONNX tensors
  const inputIdsTensor = new ort.Tensor(
    "int64",
    BigInt64Array.from(inputIds.map(BigInt)),
    [1, inputIds.length]
  );
  const styleTensor = new ort.Tensor("float32", refStyle, [1, styleDim]);
  const speedTensor = new ort.Tensor("float32", new Float32Array([speed]), [1]);

  // Run inference
  const results = await session.run({
    input_ids: inputIdsTensor,
    style: styleTensor,
    speed: speedTensor,
  });

  // Get output audio
  const outputKey = session.outputNames[0];
  const audioData = results[outputKey].data as Float32Array;

  // Check for NaN — if detected, the model doesn't work on this backend
  const hasNaN = audioData.length > 0 && isNaN(audioData[0]);
  if (hasNaN) {
    console.warn(`[KittenTTS] Model produced NaN audio — this model may not be compatible with ${currentDevice.toUpperCase()}`);
  }

  // Python trims audio[..., :-5000] but this can cut real audio on short clips.
  // Only trim if audio is longer than one second (SAMPLE_RATE samples).
  const TAIL_TRIM_SAMPLES = 5000;
  if (audioData.length > SAMPLE_RATE) {
    return audioData.slice(0, audioData.length - TAIL_TRIM_SAMPLES);
  }
  return audioData;
}
/**
 * Synthesizes `text` chunk by chunk and posts the concatenated audio to the
 * main thread.
 *
 * Posts a "status" message up front, a "progress" message before each chunk,
 * and finally an "audio" message whose buffer is transferred rather than
 * copied. Any failure, including text that contains nothing to speak, is
 * reported as an "error" message instead of being thrown.
 *
 * @param text Text to speak.
 * @param voice Voice name to use.
 * @param speed Requested speech speed.
 */
async function generate(text: string, voice: string, speed: number) {
  try {
    const chunks = chunkText(text);
    if (chunks.length === 0) {
      // Blank input would otherwise produce a zero-length "audio" message.
      throw new Error("No text to synthesize");
    }
    self.postMessage({
      type: "status",
      message: `Generating (${chunks.length} chunk${chunks.length === 1 ? "" : "s"})...`,
    });

    const audioChunks: Float32Array[] = [];
    for (let i = 0; i < chunks.length; i++) {
      self.postMessage({
        type: "progress",
        current: i + 1,
        total: chunks.length,
      });
      const audio = await generateChunk(chunks[i], voice, speed);
      audioChunks.push(audio);
    }

    // Concatenate the per-chunk audio into one buffer.
    const totalLen = audioChunks.reduce((s, c) => s + c.length, 0);
    const fullAudio = new Float32Array(totalLen);
    let offset = 0;
    for (const chunk of audioChunks) {
      fullAudio.set(chunk, offset);
      offset += chunk.length;
    }

    self.postMessage(
      {
        type: "audio",
        audio: fullAudio.buffer,
        sampleRate: SAMPLE_RATE,
      },
      { transfer: [fullAudio.buffer] }
    );
  } catch (err: unknown) {
    // Optional chaining keeps this safe when a null/undefined value was thrown.
    const message =
      (err as { message?: string } | null | undefined)?.message || String(err);
    self.postMessage({ type: "error", error: message });
  }
}
/** Sends an "error" message carrying `error` to the main thread. */
function postWorkerError(error: string): void {
  self.postMessage({ type: "error", error });
}

// Dispatch requests from the main thread by their `action` field.
self.addEventListener("message", async (event) => {
  const { action, ...payload } = event.data;

  if (action === "load") {
    try {
      await loadModel(payload.repoId);
    } catch (err: any) {
      console.error("[KittenTTS Worker] Load error:", err);
      postWorkerError(err.message || String(err));
    }
    return;
  }

  if (action === "generate") {
    // generate() reports its own failures, so no try/catch is needed here.
    await generate(payload.text, payload.voice, payload.speed);
  }
});

// Forward uncaught errors to the main thread.
self.addEventListener("error", (event) => {
  postWorkerError(event.message || "Unknown worker error");
});

// Forward unhandled promise rejections to the main thread.
self.addEventListener("unhandledrejection", (event: any) => {
  postWorkerError(event.reason?.message || String(event.reason));
});