import {
  pipeline,
  TextStreamer,
  DynamicCache,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

/** Maps the model keys accepted by "load" messages to model repository ids. */
const MODEL_IDS = {
  "1.7b": "onnx-community/Bonsai-1.7B-ONNX",
};

/**
 * Reports a failure to the main thread using the worker's "error" status.
 *
 * @param {unknown} error - The thrown value; stringified with `String()` so
 *   that even a thrown `null`/`undefined` cannot make the reporter itself throw.
 */
function postError(error) {
  self.postMessage({ status: "error", data: String(error) });
}

/**
 * Checks that a WebGPU adapter can be obtained and posts an "error" message
 * if it cannot. Nothing is posted on success.
 *
 * @returns {Promise<void>}
 */
async function check() {
  try {
    const adapter = await navigator.gpu?.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
  } catch (e) {
    postError(e);
  }
}

/**
 * Lazily creates and caches one text-generation pipeline promise per model key.
 */
class TextGenerationPipeline {
  /** @type {Map<string, Promise<any>>} Pending or resolved pipelines by model key. */
  static instances = new Map();

  /**
   * Returns the (possibly still loading) pipeline for `modelKey`, creating it
   * on first use.
   *
   * @param {string} modelKey - One of the own keys of `MODEL_IDS`.
   * @param {((info: object) => void) | null} [progress_callback=null] -
   *   Forwarded to `pipeline()`; only used when this call creates the pipeline.
   * @returns {Promise<any>} Promise for the text-generation pipeline.
   * @throws {Error} Synchronously, if `modelKey` is not a known model.
   */
  static getInstance(modelKey, progress_callback = null) {
    // Own-property check: the key arrives via postMessage, so inherited names
    // such as "constructor" must not be accepted as model ids.
    const modelId = Object.hasOwn(MODEL_IDS, modelKey)
      ? MODEL_IDS[modelKey]
      : undefined;
    if (!modelId) throw new Error(`Unknown model: ${modelKey}`);

    if (!this.instances.has(modelKey)) {
      const instance = pipeline("text-generation", modelId, {
        device: "webgpu",
        dtype: "q1",
        progress_callback,
      });
      // A failed load must not be cached forever, otherwise a later retry
      // would just receive the same rejected promise. Callers awaiting
      // `instance` still observe the original rejection.
      instance.catch(() => {
        if (this.instances.get(modelKey) === instance) {
          this.instances.delete(modelKey);
        }
      });
      this.instances.set(modelKey, instance);
    }
    return this.instances.get(modelKey);
  }
}

// Shared across generations so an "interrupt" message can stop the active one.
const stopping_criteria = new InterruptableStoppingCriteria();
// Key/value cache handed to every generate() call until it is disposed.
let past_key_values_cache = null;
// Model key most recently passed to load(); generate() uses this model.
let current_model_key = null;

/**
 * Releases the cached key/value state (when the cache exposes `dispose`) and
 * clears the reference so the next generation starts with a fresh cache.
 */
function disposePastKeyValues() {
  past_key_values_cache?.dispose?.();
  past_key_values_cache = null;
}

/**
 * Loads the model for `modelKey`, forwarding download progress, runs a
 * one-token generation, and then posts "ready". Any failure is posted as an
 * "error" message instead of surfacing as an unhandled rejection.
 *
 * @param {string} modelKey - One of the own keys of `MODEL_IDS`.
 * @returns {Promise<void>}
 */
async function load(modelKey) {
  try {
    // A cache built with another model cannot be reused after a switch.
    if (current_model_key && current_model_key !== modelKey) {
      disposePastKeyValues();
    }
    current_model_key = modelKey;

    self.postMessage({ status: "loading", data: "Loading model..." });

    const generator = await TextGenerationPipeline.getInstance(
      modelKey,
      (info) => {
        // Only aggregate progress events are forwarded to the main thread.
        if (info.status === "progress_total") {
          self.postMessage({
            status: "progress_total",
            progress: Number(info.progress ?? 0),
            loaded: Number(info.loaded ?? 0),
            total: Number(info.total ?? 0),
          });
        }
      },
    );

    self.postMessage({
      status: "loading",
      data: "Optimizing model for 1-bit execution",
    });

    // One-token dummy generation before reporting readiness
    // (presumably a warm-up so the first real request is not slowed down).
    const inputs = generator.tokenizer("a");
    await generator.model.generate({ ...inputs, max_new_tokens: 1 });

    self.postMessage({ status: "ready" });
  } catch (e) {
    postError(e);
  }
}

/**
 * Generates a reply for `messages` with the currently loaded model, streaming
 * partial text as "update" messages and finishing with "complete". Any
 * failure, including calling this before a model was loaded, is posted as an
 * "error" message.
 *
 * @param {Array<object>} messages - Passed to the pipeline unchanged
 *   (presumably chat messages of the form `{ role, content }` - confirm with caller).
 * @returns {Promise<void>}
 */
async function generate(messages) {
  try {
    // Inside the try so an unknown or unset model key is reported, not thrown.
    const generator = await TextGenerationPipeline.getInstance(
      current_model_key,
    );

    let startTime;
    let numTokens = 0;
    let tps;

    const streamer = new TextStreamer(generator.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (output) => {
        self.postMessage({ status: "update", output, tps, numTokens });
      },
      token_callback_function: () => {
        // The clock starts at the first token; tokens/second is computed from
        // the second token onward, so `tps` is undefined until then.
        startTime ??= performance.now();
        if (numTokens++ > 0) {
          tps = (numTokens / (performance.now() - startTime)) * 1000;
        }
      },
    });

    self.postMessage({ status: "start" });

    past_key_values_cache ??= new DynamicCache();

    const output = await generator(messages, {
      max_new_tokens: 1024,
      do_sample: false,
      streamer,
      stopping_criteria,
      past_key_values: past_key_values_cache,
    });

    self.postMessage({
      status: "complete",
      // The last entry of the returned conversation is the new reply.
      output: output[0].generated_text.at(-1).content,
    });
  } catch (e) {
    postError(e);
  }
}

// Dispatches commands from the main thread. The async tasks report their own
// failures via "error" messages, so their promises are intentionally not awaited.
self.addEventListener("message", (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      void check();
      break;

    case "load":
      void load(data);
      break;

    case "generate":
      // Clear any earlier interrupt so it does not stop this run immediately.
      stopping_criteria.reset();
      void generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      disposePastKeyValues();
      stopping_criteria.reset();
      break;
  }
});