// KittenTTS-WebGPU — src/App.tsx
import { useState, useRef, useCallback, useEffect } from "react";
import WaveformPlayer from "./WaveformPlayer";
// Human-readable model label → Hugging Face ONNX repo id.
const MODELS: Record<string, string> = {
  "Nano (15M - Fastest)": "onnx-community/KittenTTS-Nano-v0.8-ONNX",
  "Micro (40M - Balanced)": "onnx-community/KittenTTS-Micro-v0.8-ONNX",
  "Mini (80M - Best Quality)": "onnx-community/KittenTTS-Mini-v0.8-ONNX",
};

// Model selected when the app first loads.
const DEFAULT_MODEL = "Nano (15M - Fastest)";

// Canned prompts; clicking one fills the form (and switches model if needed).
const EXAMPLES = [
  {
    text: "Space is a three-dimensional continuum containing positions and directions.",
    model: "Micro (40M - Balanced)",
    voice: "Jasper",
    speed: 1.0,
  },
  {
    text: "She picked up her coffee and walked toward the window.",
    model: "Mini (80M - Best Quality)",
    voice: "Luna",
    speed: 1.0,
  },
  {
    text: "The sun set slowly over the calm, quiet lake",
    model: "Nano (15M - Fastest)",
    voice: "Bella",
    speed: 1.1,
  },
];

// Worker/UI lifecycle as reflected in the controls and status line.
type Status = "idle" | "loading" | "ready" | "generating" | "error";
/**
 * Root component of the KittenTTS demo.
 *
 * Owns all UI state, spawns a module Web Worker (./worker.ts) that loads the
 * ONNX model and synthesizes audio, and renders the text/model/voice/speed
 * controls plus the waveform output. Generated audio arrives from the worker
 * as a raw Float32 buffer and is wrapped into a WAV blob URL for playback.
 */
export default function App() {
  // --- form state -----------------------------------------------------------
  const [text, setText] = useState("");
  const [model, setModel] = useState(DEFAULT_MODEL);
  const [voice, setVoice] = useState("Jasper");
  const [speed, setSpeed] = useState(1.0);
  const [voices, setVoices] = useState<string[]>([]); // populated by the worker's "ready" message
  // --- worker/output state --------------------------------------------------
  const [status, setStatus] = useState<Status>("idle");
  const [statusMsg, setStatusMsg] = useState("");
  // Device name is reported by the worker but not currently displayed.
  const [, setDevice] = useState("");
  const [progress, setProgress] = useState({ current: 0, total: 0 });
  const [audioUrl, setAudioUrl] = useState<string | null>(null); // blob URL for the latest WAV
  const [error, setError] = useState<string | null>(null);
  const [duration, setDuration] = useState<number | null>(null); // generation wall-clock time, ms
  const workerRef = useRef<Worker | null>(null);
  const genStartRef = useRef<number>(0); // performance.now() at generation start

  /**
   * (Re)create the worker and wire up its message protocol.
   * Messages: status / device / ready / progress / audio / error.
   */
  const initWorker = useCallback(() => {
    if (workerRef.current) workerRef.current.terminate();
    const worker = new Worker(new URL("./worker.ts", import.meta.url), {
      type: "module",
    });
    workerRef.current = worker;
    worker.addEventListener("error", (e) => {
      console.error("Worker error:", e);
      setError(`Worker failed: ${e.message}`);
      setStatus("error");
      setStatusMsg("");
    });
    worker.addEventListener("message", (e) => {
      const msg = e.data;
      switch (msg.type) {
        case "status":
          setStatusMsg(msg.message);
          break;
        case "device":
          setDevice(msg.device);
          break;
        case "ready":
          setStatus("ready");
          setVoices(msg.voices);
          setStatusMsg(`${msg.modelName} loaded`);
          break;
        case "progress":
          setProgress({ current: msg.current, total: msg.total });
          break;
        case "audio": {
          // Wrap the raw PCM in a WAV blob and swap it in, revoking the
          // previous object URL so blobs don't accumulate in memory.
          const audioData = new Float32Array(msg.audio);
          const blob = float32ToWav(audioData, msg.sampleRate);
          const url = URL.createObjectURL(blob);
          setAudioUrl((prev) => {
            if (prev) URL.revokeObjectURL(prev);
            return url;
          });
          setDuration(
            Math.round(performance.now() - genStartRef.current)
          );
          setStatus("ready");
          setStatusMsg("Done!");
          break;
        }
        case "error":
          setError(msg.error);
          setStatus("error");
          setStatusMsg("");
          break;
      }
    });
    return worker;
  }, []);

  /** Ask the worker to download/initialize the given model; resets output. */
  const loadModel = useCallback(
    (modelKey: string) => {
      const worker = workerRef.current || initWorker();
      setStatus("loading");
      setError(null);
      // Clear any previous output, revoking its blob URL so switching models
      // repeatedly doesn't leak object URLs.
      setAudioUrl((prev) => {
        if (prev) URL.revokeObjectURL(prev);
        return null;
      });
      setDuration(null);
      setStatusMsg("Starting...");
      worker.postMessage({ action: "load", repoId: MODELS[modelKey] });
    },
    [initWorker]
  );

  // Load the default model once on mount; terminate the worker on unmount.
  // (Any outstanding blob URL is released by the browser when the page goes away.)
  useEffect(() => {
    loadModel(model);
    return () => {
      workerRef.current?.terminate();
    };
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  const handleModelChange = (newModel: string) => {
    setModel(newModel);
    loadModel(newModel);
  };

  /** Kick off synthesis for the current text/voice/speed. */
  const handleGenerate = () => {
    if (!text.trim() || status !== "ready") return;
    setStatus("generating");
    setError(null);
    setDuration(null);
    setProgress({ current: 0, total: 0 });
    genStartRef.current = performance.now();
    workerRef.current?.postMessage({ action: "generate", text, voice, speed });
  };

  /** Fill the form from an example card, switching model if it differs. */
  const handleExample = (ex: (typeof EXAMPLES)[0]) => {
    setText(ex.text);
    setVoice(ex.voice);
    setSpeed(ex.speed);
    if (ex.model !== model) {
      handleModelChange(ex.model);
    }
  };

  // Cmd/Ctrl+Enter in the textarea triggers generation.
  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
    if ((e.metaKey || e.ctrlKey) && e.key === "Enter") {
      e.preventDefault();
      handleGenerate();
    }
  };

  return (
    <div className="container">
      <header>
        <h1>
          <a href="https://huggingface.co/KittenML" target="_blank" rel="noopener" className="title-link">
            <span className="logo">🐱</span> KittenTTS
          </a>
        </h1>
        <p className="subtitle">
          Text-to-speech running entirely in your browser
        </p>
      </header>
      <main>
        <div className="input-section">
          <label htmlFor="text-input">Text</label>
          <textarea
            id="text-input"
            value={text}
            onChange={(e) => setText(e.target.value)}
            onKeyDown={handleKeyDown}
            placeholder="Enter text to synthesize…"
            rows={5}
          />
          <div className="controls-row">
            <div className="control">
              <label htmlFor="model-select">Model</label>
              <select
                id="model-select"
                value={model}
                onChange={(e) => handleModelChange(e.target.value)}
                disabled={status === "loading" || status === "generating"}
              >
                {Object.keys(MODELS).map((m) => (
                  <option key={m} value={m}>
                    {m}
                  </option>
                ))}
              </select>
            </div>
            <div className="control">
              <label htmlFor="voice-select">Voice</label>
              <select
                id="voice-select"
                value={voice}
                onChange={(e) => setVoice(e.target.value)}
                disabled={voices.length === 0}
              >
                {voices.map((v) => (
                  <option key={v} value={v}>
                    {v}
                  </option>
                ))}
              </select>
            </div>
          </div>
          <div className="speed-row">
            <label htmlFor="speed-slider">Speed: {speed.toFixed(2)}x</label>
            <input
              id="speed-slider"
              type="range"
              min={0.5}
              max={2.0}
              step={0.05}
              value={speed}
              onChange={(e) => setSpeed(parseFloat(e.target.value))}
            />
          </div>
          <button
            className="generate-btn"
            onClick={handleGenerate}
            disabled={status !== "ready" || !text.trim()}
          >
            {status === "generating"
              ? progress.total > 0
                ? `Generating ${progress.current}/${progress.total}…`
                : "Generating…"
              : status === "loading"
              ? "Loading model…"
              : "Generate Speech"}
          </button>
        </div>
        <div className="output-section">
          <label>Output</label>
          {audioUrl ? (
            <WaveformPlayer audioUrl={audioUrl} duration={duration} />
          ) : (
            <div className="audio-placeholder">
              {status === "loading" || status === "generating"
                ? statusMsg
                : "Audio will appear here"}
            </div>
          )}
        </div>
        <div className="examples">
          <label>Examples</label>
          <div className="examples-grid">
            {EXAMPLES.map((ex, i) => (
              <button
                key={i}
                className="example-btn"
                onClick={() => handleExample(ex)}
                disabled={status === "loading" || status === "generating"}
              >
                <span className="example-voice">{ex.voice}</span>
                <span className="example-text">{ex.text}</span>
                <span className="example-meta">
                  {ex.model.split(" (")[0]}{ex.speed !== 1.0 ? ` · ${ex.speed}x` : ""}
                </span>
              </button>
            ))}
          </div>
        </div>
        {error && <div className="error-msg">{error}</div>}
      </main>
      <footer>
        <p className="footer-note">
          * Nano runs on WebGPU for faster inference. Micro and Mini use WASM (int8 quantized).
        </p>
        <p>
          Models by{" "}
          <a href="https://huggingface.co/KittenML" target="_blank" rel="noopener">KittenML</a>
          {" · "}
          <a href="https://huggingface.co/spaces/KittenML/KittenTTS-Demo" target="_blank" rel="noopener">Original demo</a>
        </p>
        <p>
          Powered by{" "}
          <a href="https://github.com/huggingface/transformers.js" target="_blank" rel="noopener">Transformers.js</a>
          {" · "}
          <a href="https://github.com/xenova/phonemizer.js" target="_blank" rel="noopener">phonemizer.js</a>
          {" by "}
          <a href="https://github.com/xenova" target="_blank" rel="noopener">Xenova</a>
          {" · "}
          <a href="https://onnxruntime.ai" target="_blank" rel="noopener">ONNX Runtime Web</a>
        </p>
      </footer>
    </div>
  );
}
/**
 * Convert mono Float32 PCM samples to a WAV Blob encoded as IEEE 32-bit
 * float (WAVE format tag 3).
 *
 * If the signal peaks beyond ±1.0 it is scaled to a 0.95 peak to avoid
 * clipping. Normalization happens on a copy — the caller's buffer is
 * never mutated.
 *
 * @param samples    mono PCM samples (read-only; not modified)
 * @param sampleRate sample rate in Hz written into the WAV header
 * @returns a Blob of type "audio/wav" (44-byte header + 4 bytes per sample)
 */
function float32ToWav(samples: Float32Array, sampleRate: number): Blob {
  // Find the peak amplitude.
  let maxAbs = 0;
  for (let i = 0; i < samples.length; i++) {
    const abs = Math.abs(samples[i]);
    if (abs > maxAbs) maxAbs = abs;
  }
  // Normalize out-of-range audio into a copy so the input stays untouched.
  let pcm = samples;
  if (maxAbs > 1) {
    const scale = 0.95 / maxAbs; // leave some headroom
    pcm = new Float32Array(samples.length);
    for (let i = 0; i < samples.length; i++) {
      pcm[i] = samples[i] * scale;
    }
  }
  // 44-byte RIFF/WAVE header, little-endian throughout.
  const bytesPerSample = 4;
  const dataSize = pcm.length * bytesPerSample;
  const buffer = new ArrayBuffer(44 + dataSize);
  const view = new DataView(buffer);
  const writeStr = (offset: number, str: string) => {
    for (let i = 0; i < str.length; i++)
      view.setUint8(offset + i, str.charCodeAt(i));
  };
  writeStr(0, "RIFF");
  view.setUint32(4, 36 + dataSize, true); // RIFF chunk size
  writeStr(8, "WAVE");
  writeStr(12, "fmt ");
  view.setUint32(16, 16, true); // fmt chunk size
  view.setUint16(20, 3, true); // format tag 3 = IEEE float
  view.setUint16(22, 1, true); // channels: mono
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * bytesPerSample, true); // byte rate (mono)
  view.setUint16(32, bytesPerSample, true); // block align (mono)
  view.setUint16(34, 32, true); // bits per sample
  writeStr(36, "data");
  view.setUint32(40, dataSize, true);
  // Sample payload: raw little-endian float32 values.
  let offset = 44;
  for (let i = 0; i < pcm.length; i++) {
    view.setFloat32(offset, pcm[i], true);
    offset += 4;
  }
  return new Blob([buffer], { type: "audio/wav" });
}