// Gemma-4-WebGPU — src/worker.js
// Uploaded via huggingface_hub by shreyask (commit 45f314a, verified).
// web/src/worker.js
import {
AutoProcessor,
Gemma4ForConditionalGeneration,
TextStreamer,
InterruptableStoppingCriteria,
load_image,
RawImage,
} from "@huggingface/transformers";
// Hugging Face repo of the ONNX export loaded for in-browser inference.
const MODEL_ID = "onnx-community/gemma-4-E2B-it-ONNX";
// UI-facing delimiters wrapped around the model's "thinking" output by
// cleanGemmaOutput(). The doubled-guillemet characters are deliberately
// outside the <|...|> special-token alphabet so later tag-stripping
// passes cannot touch them.
const THINK_START = "‹‹THINK››";
const THINK_END = "‹‹/THINK››";
/**
 * Normalizes raw Gemma decoder text for display.
 *
 * Pass 1 rewrites the "<|channel|>thought" opener (pipes, closing
 * bracket and surrounding whitespace all optional) into THINK_START.
 * Pass 2 rewrites a bare channel tag (also tolerating a doubled-l
 * "channell" spelling) into THINK_END. Pass 3 strips any remaining
 * <|token|>-style special tags, then the result is trimmed.
 *
 * @param {string} raw - raw streamed model output, special tokens included
 * @returns {string} cleaned text with ‹‹THINK›› / ‹‹/THINK›› markers
 */
function cleanGemmaOutput(raw) {
  const opened = raw.replace(/<\|?channel\|?>?\s*thought\s*/gi, THINK_START);
  const closed = opened.replace(/<\|?channell?\|?>/gi, THINK_END);
  const stripped = closed.replace(/<\|?[a-z_]+\|?>/gi, "");
  return stripped.trim();
}
// Lazily-initialized singletons; populated once by loadModel() and read
// by generate().
let processor = null;
let model = null;
// Shared criteria object so the "interrupt" message handler can cancel
// an in-flight generate() call; reset() is invoked before each run.
const stoppingCriteria = new InterruptableStoppingCriteria();
/**
 * Probes the browser for WebGPU support and reports the outcome to the
 * main thread as a "status" message ("webgpu-available" or
 * "webgpu-unavailable"). A thrown error from the adapter request is
 * treated the same as no adapter.
 */
async function checkWebGPU() {
  let status = "webgpu-unavailable";
  try {
    const adapter = await navigator.gpu?.requestAdapter();
    if (adapter) {
      status = "webgpu-available";
    }
  } catch {
    // requestAdapter rejected — report unavailable below.
  }
  self.postMessage({ type: "status", status });
}
/**
 * Downloads and initializes the processor and model, forwarding download
 * progress events to the UI as "progress" messages. Posts
 * status:"loading" before starting, status:"ready" on success, and an
 * "error" message with the failure reason otherwise. Results are stored
 * in the module-level `processor` / `model` singletons.
 */
async function loadModel() {
  try {
    self.postMessage({ type: "status", status: "loading" });
    // Relay every byte-level progress event straight to the main thread.
    const progress_callback = (p) => self.postMessage({ type: "progress", ...p });

    processor = await AutoProcessor.from_pretrained(MODEL_ID, { progress_callback });
    model = await Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, {
      dtype: "q4f16", // 4-bit weights with fp16 activations
      device: "webgpu",
      progress_callback,
    });

    self.postMessage({ type: "status", status: "ready" });
  } catch (err) {
    self.postMessage({ type: "error", message: err.message });
  }
}
/**
 * Runs one chat-completion turn.
 *
 * Builds a multimodal prompt from the message history plus optional
 * image / video-frame / audio inputs, streams cleaned partial text to
 * the UI as "update" messages, and posts the final cleaned transcript
 * as "complete". Errors (including "Model not loaded") are reported as
 * "error" messages rather than thrown.
 *
 * @param {Object} req
 * @param {Array}  req.messages       - chat history for the template
 * @param {string} [req.imageUrl]     - single image to attach
 * @param {Object} [req.videoData]    - sampled frames ({frames:[{data,width,height,channels}]})
 * @param {*}      [req.audioData]    - audio input passed through to the processor
 * @param {boolean}[req.enableThinking] - toggles the template's thinking mode
 */
async function generate({ messages, imageUrl, videoData, audioData, enableThinking }) {
  if (!model || !processor) {
    self.postMessage({ type: "error", message: "Model not loaded" });
    return;
  }
  try {
    self.postMessage({ type: "status", status: "generating" });
    stoppingCriteria.reset();

    const prompt = processor.apply_chat_template(messages, {
      enable_thinking: enableThinking,
      add_generation_prompt: true,
    });

    // Gemma4ImageProcessor expects RawImage | RawImage[], not RawVideo,
    // so sampled video frames are wrapped as individual RawImages.
    let visual = null;
    if (videoData) {
      visual = videoData.frames.map(
        (frame) =>
          new RawImage(new Uint8ClampedArray(frame.data), frame.width, frame.height, frame.channels),
      );
    } else if (imageUrl) {
      visual = await load_image(imageUrl);
    }

    const inputs = await processor(prompt, visual, audioData ?? null, {
      add_special_tokens: false,
    });

    // Accumulate raw tokens and push a freshly-cleaned snapshot on every
    // streamed chunk. Special tokens are kept so cleanGemmaOutput can map
    // the thinking-channel tags to UI markers.
    let fullText = "";
    const streamer = new TextStreamer(processor.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      callback_function: (text) => {
        fullText += text;
        self.postMessage({ type: "update", text: cleanGemmaOutput(fullText) });
      },
    });

    await model.generate({
      ...inputs,
      max_new_tokens: 512,
      do_sample: false, // greedy decoding for reproducible output
      streamer,
      stopping_criteria: [stoppingCriteria],
    });

    self.postMessage({ type: "complete", text: cleanGemmaOutput(fullText) });
  } catch (err) {
    self.postMessage({ type: "error", message: err.message });
  }
}
// Message router: dispatches UI commands to the matching worker action.
// Unknown message types are silently ignored.
self.onmessage = (e) => {
  const { type } = e.data;
  if (type === "check") {
    checkWebGPU();
  } else if (type === "load") {
    loadModel();
  } else if (type === "generate") {
    generate(e.data);
  } else if (type === "interrupt") {
    stoppingCriteria.interrupt();
  }
};