import {
  // VAD
  AutoModel,

  // LLM
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,

  // Speech recognition
  Tensor,
  pipeline,
} from "@huggingface/transformers";

import { KokoroTTS, TextSplitterStream } from "kokoro-js";

import {
  MAX_BUFFER_DURATION,
  INPUT_SAMPLE_RATE,
  SPEECH_THRESHOLD,
  EXIT_THRESHOLD,
  SPEECH_PAD_SAMPLES,
  MAX_NUM_PREV_BUFFERS,
  MIN_SILENCE_DURATION_SAMPLES,
  MIN_SPEECH_DURATION_SAMPLES,
} from "./constants";

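// Load the Kokoro text-to-speech model on WebGPU. The active voice is
// selected later via a "set_voice" message from the main thread.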
const model_id = "onnx-community/Kokoro-82M-v1.0-ONNX";
let voice;
const tts = await KokoroTTS.from_pretrained(model_id, {
  dtype: "fp32",
  device: "webgpu",
});

const device = "webgpu";
self.postMessage({ type: "info", message: `Using device: "${device}"` });
self.postMessage({
  type: "info",
  message: "Loading models...",
  duration: "until_next",
});

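// Load the Silero VAD model, used to detect when the user starts and stops speaking.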
const silero_vad = await AutoModel.from_pretrained(
  "onnx-community/silero-vad",
  {
    config: { model_type: "custom" },
    dtype: "fp32", // Full-precision
  },
).catch((error) => {
  // Forward loading errors to the main thread before rethrowing
  self.postMessage({ error });
  throw error;
});

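// Whisper dtypes per device: WebGPU can afford a full-precision decoder,
// while the WASM fallback uses an 8-bit quantized one.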
const DEVICE_DTYPE_CONFIGS = {
  webgpu: {
    encoder_model: "fp32",
    decoder_model_merged: "fp32",
  },
  wasm: {
    encoder_model: "fp32",
    decoder_model_merged: "q8",
  },
};
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-base",
  {
    device,
    dtype: DEVICE_DTYPE_CONFIGS[device],
  },
).catch((error) => {
  self.postMessage({ error });
  throw error;
});

// Warm up the transcriber with one second of silence to compile shaders up front.
await transcriber(new Float32Array(INPUT_SAMPLE_RATE));

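// Load the LLM (SmolLM2 1.7B Instruct) with 4-bit quantized weights on WebGPU.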
const llm_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";
const tokenizer = await AutoTokenizer.from_pretrained(llm_model_id);
const llm = await AutoModelForCausalLM.from_pretrained(llm_model_id, {
  dtype: "q4f16",
  device: "webgpu",
});

const SYSTEM_MESSAGE = {
  role: "system",
  content:
    "You're a helpful and conversational voice assistant. Keep your responses short, clear, and casual.",
};
// Warm up the LLM with a single-token generation
await llm.generate({ ...tokenizer("x"), max_new_tokens: 1 });

let messages = [SYSTEM_MESSAGE];
let past_key_values_cache;
let stopping_criteria;
self.postMessage({
  type: "status",
  status: "ready",
  message: "Ready!",
  voices: tts.voices,
});

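// Global audio buffer to store incoming audio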
const BUFFER = new Float32Array(MAX_BUFFER_DURATION * INPUT_SAMPLE_RATE);
let bufferPointer = 0;

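// Initial state for the VAD model: the sample-rate tensor and the recurrent state (2 x 1 x 128).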
const sr = new Tensor("int64", [INPUT_SAMPLE_RATE], []);
let state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);

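// Whether we are currently recording the user, and whether TTS output is playing back.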
let isRecording = false;
let isPlaying = false;

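/**
 * Perform Voice Activity Detection (VAD) on an incoming audio chunk.
 * @param {Float32Array} buffer The new audio chunk
 * @returns {Promise<boolean>} `true` if the chunk is classified as speech, `false` otherwise
 */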
async function vad(buffer) {
  const input = new Tensor("float32", buffer, [1, buffer.length]);

  const { stateN, output } = await silero_vad({ input, sr, state });
  state = stateN; // Update the recurrent state for the next chunk

  const isSpeech = output.data[0];

  // Use heuristics to decide whether the chunk is speech
  return (
    // Case 1: The probability is above the entry threshold (definitely speech)
    isSpeech > SPEECH_THRESHOLD ||
    // Case 2: We are already recording and the probability is above the (lower) exit threshold
    (isRecording && isSpeech >= EXIT_THRESHOLD)
  );
}

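/**
 * Run the speech-to-speech pipeline: transcribe the user's audio, generate a
 * response with the LLM, and stream the response through TTS.
 * @param {Float32Array} buffer The recorded audio
 * @param {Object} data Segment metadata ({ start, end, duration } in milliseconds)
 */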
const speechToSpeech = async (buffer, data) => {
  isPlaying = true;

  // 1. Transcribe the user's audio
  const text = await transcriber(buffer).then(({ text }) => text.trim());
  if (["", "[BLANK_AUDIO]"].includes(text)) {
    // Empty or blank transcription: nothing will be played back, so clear the
    // playback flag (otherwise incoming audio would be ignored forever) and skip
    isPlaying = false;
    return;
  }
  messages.push({ role: "user", content: text });

  // Set up the text-to-speech stream; audio chunks are posted back as they are generated
  const splitter = new TextSplitterStream();
  const stream = tts.stream(splitter, {
    voice,
  });
  (async () => {
    for await (const { text, phonemes, audio } of stream) {
      self.postMessage({ type: "output", text, result: audio });
    }
  })();

  // 2. Generate a response with the LLM, feeding text into the TTS splitter as it arrives
  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  });
  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (text) => {
      splitter.push(text);
    },
    token_callback_function: () => {},
  });

  stopping_criteria = new InterruptableStoppingCriteria();
  const { past_key_values, sequences } = await llm.generate({
    ...inputs,
    past_key_values: past_key_values_cache,

    // Greedy decoding
    do_sample: false,
    max_new_tokens: 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
  past_key_values_cache = past_key_values;

  // Close the stream to signal that no more text will be added
  splitter.close();

  // Decode only the newly-generated tokens (everything after the prompt)
  const decoded = tokenizer.batch_decode(
    sequences.slice(null, [inputs.input_ids.dims[1], null]),
    { skip_special_tokens: true },
  );

  messages.push({ role: "assistant", content: decoded[0] });
};

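// Number of samples received after the last detected speech chunk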
let postSpeechSamples = 0;
const resetAfterRecording = (offset = 0) => {
  self.postMessage({
    type: "status",
    status: "recording_end",
    message: "Transcribing...",
    duration: "until_next",
  });
  BUFFER.fill(0, offset);
  bufferPointer = offset;
  isRecording = false;
  postSpeechSamples = 0;
};

const dispatchForTranscriptionAndResetAudioBuffer = (overflow) => {
  // Compute the start/end timestamps of the speech segment, minus the padding
  const now = Date.now();
  const end =
    now - ((postSpeechSamples + SPEECH_PAD_SAMPLES) / INPUT_SAMPLE_RATE) * 1000;
  const start = end - (bufferPointer / INPUT_SAMPLE_RATE) * 1000;
  const duration = end - start;
  const overflowLength = overflow?.length ?? 0;

  // Slice out the recorded segment, plus padding after the speech ends
  const buffer = BUFFER.slice(0, bufferPointer + SPEECH_PAD_SAMPLES);

  // Prepend the buffered pre-speech chunks so the start of the utterance is not cut off
  const prevLength = prevBuffers.reduce((acc, b) => acc + b.length, 0);
  const paddedBuffer = new Float32Array(prevLength + buffer.length);
  let offset = 0;
  for (const prev of prevBuffers) {
    paddedBuffer.set(prev, offset);
    offset += prev.length;
  }
  paddedBuffer.set(buffer, offset);
  speechToSpeech(paddedBuffer, { start, end, duration });

  // Keep the overflow (if present) and reset the rest of the audio buffer
  if (overflow) {
    BUFFER.set(overflow, 0);
  }
  resetAfterRecording(overflowLength);
};

let prevBuffers = [];
self.onmessage = async (event) => {
  const { type, buffer } = event.data;

  // Ignore new audio while the assistant is speaking
  if (type === "audio" && isPlaying) return;

  switch (type) {
    case "start_call": {
      const name = tts.voices[voice ?? "af_heart"]?.name ?? "Heart";
      greet(`Hey there, my name is ${name}! How can I help you today?`);
      return;
    }
    case "end_call":
      messages = [SYSTEM_MESSAGE];
      past_key_values_cache = null;
    // fallthrough: ending a call also interrupts any in-flight generation
    case "interrupt":
      stopping_criteria?.interrupt();
      return;
    case "set_voice":
      voice = event.data.voice;
      return;
    case "playback_ended":
      isPlaying = false;
      return;
  }

  const wasRecording = isRecording; // Save current state
  const isSpeech = await vad(buffer);

  if (!wasRecording && !isSpeech) {
    // We are not recording and the current chunk is not speech, so it will
    // probably be discarded. Keep it in a fixed-size FIFO queue so a few
    // chunks of leading audio can be prepended when speech does start.
    if (prevBuffers.length >= MAX_NUM_PREV_BUFFERS) {
      // The queue is full, so discard the oldest chunk
      prevBuffers.shift();
    }
    prevBuffers.push(buffer);
    return;
  }

  const remaining = BUFFER.length - bufferPointer;
  if (buffer.length >= remaining) {
    // The incoming chunk is larger than (or equal to) the remaining space in
    // the global buffer, so fill the buffer and dispatch it for transcription
    BUFFER.set(buffer.subarray(0, remaining), bufferPointer);
    bufferPointer += remaining;

    // Carry the overflow over into the freshly-reset buffer
    const overflow = buffer.subarray(remaining);
    dispatchForTranscriptionAndResetAudioBuffer(overflow);
    return;
  } else {
    // The chunk fits, so append it to the global buffer
    BUFFER.set(buffer, bufferPointer);
    bufferPointer += buffer.length;
  }

  if (isSpeech) {
    if (!isRecording) {
      // Indicate the start of recording
      self.postMessage({
        type: "status",
        status: "recording_start",
        message: "Listening...",
        duration: "until_next",
      });
    }
    // Start or continue recording
    isRecording = true;
    postSpeechSamples = 0; // Reset the post-speech sample count
    return;
  }

  postSpeechSamples += buffer.length;

  // At this point, we were recording (wasRecording === true) but the latest
  // chunk is not speech, so check whether the utterance has actually ended.
  if (postSpeechSamples < MIN_SILENCE_DURATION_SAMPLES) {
    // The pause is too short to count as the end of the utterance
    // (e.g., the speaker took a breath), so keep recording
    return;
  }

  if (bufferPointer < MIN_SPEECH_DURATION_SAMPLES) {
    // The recorded audio is shorter than the minimum speech duration,
    // so discard it as a false positive
    resetAfterRecording();
    return;
  }

  dispatchForTranscriptionAndResetAudioBuffer();
};

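/**
 * Speak a scripted greeting through TTS (bypassing the LLM) and record it in the chat history.
 * @param {string} text The greeting to synthesize
 */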
function greet(text) {
  isPlaying = true;
  const splitter = new TextSplitterStream();
  const stream = tts.stream(splitter, { voice });
  (async () => {
    for await (const { text: chunkText, audio } of stream) {
      self.postMessage({ type: "output", text: chunkText, result: audio });
    }
  })();
  splitter.push(text);
  splitter.close();
  messages.push({ role: "assistant", content: text });
}