// Gemma-4-WebGPU / src/hooks/useModel.js
// Uploaded by shreyask via huggingface_hub (commit 45f314a, verified)
// web/src/hooks/useModel.js
import { useState, useEffect, useRef, useCallback } from "react";
import { read_audio, load_video } from "@huggingface/transformers";
/**
 * Copy decoded video frames into standalone ArrayBuffers so they can be
 * transferred to the worker.
 *
 * The input shape is assumed from how `load_video`'s result is read here:
 * `{ duration, frames: [{ image: { data, width, height, channels }, timestamp }] }`,
 * where `image.data` is a typed array.
 *
 * @param {object} video - Result of `load_video`.
 * @returns {{ videoData: { duration: number, frames: Array<{ data: ArrayBuffer, width: number, height: number, channels: number, timestamp: number }> }, buffers: ArrayBuffer[] }}
 *   The message payload, plus the buffers to pass as transferables.
 */
function serializeVideo(video) {
  const buffers = [];
  const frames = video.frames.map(({ image, timestamp }) => {
    // `image.data.buffer` may be larger than the view (non-zero byteOffset or a
    // shared backing store). Slicing the typed array copies exactly its bytes
    // into a fresh, exactly-sized buffer.
    const data = image.data.slice().buffer;
    buffers.push(data);
    return { data, width: image.width, height: image.height, channels: image.channels, timestamp };
  });
  return { videoData: { duration: video.duration, frames }, buffers };
}

/**
 * React hook that owns a module Web Worker (`../worker.js`) and exposes its
 * state and controls.
 *
 * State is driven by messages from the worker:
 * - "status" sets `status`.
 * - "progress" sets `loadProgress`.
 * - "error" sets `error` and moves `status` to "error".
 * - "update" and "complete" are forwarded to the callbacks of the active
 *   `generate()` call.
 *
 * @returns {{
 *   status: string,
 *   loadProgress: object|null,
 *   error: string|null,
 *   checkWebGPU: () => void,
 *   loadModel: () => void,
 *   generate: (opts: object) => Promise<void>,
 *   interrupt: () => void,
 * }}
 */
export function useModel() {
  const [status, setStatus] = useState("idle"); // idle | webgpu-available | webgpu-unavailable | loading | ready | generating | error
  const [loadProgress, setLoadProgress] = useState(null);
  const [error, setError] = useState(null);
  const workerRef = useRef(null);
  // Callbacks of the in-flight generate() call; cleared when it completes or errors.
  const callbacksRef = useRef(null);

  useEffect(() => {
    const worker = new Worker(new URL("../worker.js", import.meta.url), {
      type: "module",
    });
    worker.onmessage = (e) => {
      const { type, ...data } = e.data;
      switch (type) {
        case "status":
          setStatus(data.status);
          if (data.status === "ready") setLoadProgress(null);
          break;
        case "progress":
          setLoadProgress(data);
          break;
        case "error":
          setError(data.message);
          setStatus("error");
          callbacksRef.current?.onComplete?.("", data.message);
          // Drop the finished request's callbacks (as "complete" does) so a later
          // error, e.g. during a reload, is not reported to a stale caller.
          callbacksRef.current = null;
          break;
        case "update":
          callbacksRef.current?.onUpdate?.(data.text);
          break;
        case "complete":
          setStatus("ready");
          callbacksRef.current?.onComplete?.(data.text);
          callbacksRef.current = null;
          break;
        default:
          // Unknown message types are ignored.
          break;
      }
    };
    workerRef.current = worker;
    return () => {
      worker.terminate();
      // Turn later postMessage calls into no-ops, e.g. from a generate() that is
      // still awaiting media decoding. The guard avoids clobbering a newer worker
      // if the effect re-runs.
      if (workerRef.current === worker) workerRef.current = null;
      callbacksRef.current = null;
    };
  }, []);

  /** Ask the worker to check WebGPU availability. */
  const checkWebGPU = useCallback(() => {
    workerRef.current?.postMessage({ type: "check" });
  }, []);

  /** Ask the worker to load the model; clears any previous error first. */
  const loadModel = useCallback(() => {
    setError(null);
    workerRef.current?.postMessage({ type: "load" });
  }, []);

  /**
   * Decode optional media on the main thread, then post a "generate" request
   * to the worker.
   *
   * Media decoding is best-effort: a failed audio or video decode is logged and
   * generation proceeds without that input.
   *
   * @param {object} opts
   * @param {Array} opts.messages - Chat messages forwarded as-is to the worker.
   * @param {string} [opts.imageUrl] - Forwarded as a URL; falsy values become null.
   * @param {string} [opts.videoUrl] - Decoded here into 4 frames.
   * @param {string} [opts.audioUrl] - Decoded here with `read_audio(url, 16000)`.
   * @param {boolean} [opts.enableThinking] - Falsy values become false.
   * @param {(text: string) => void} [opts.onUpdate] - Called for each worker "update".
   * @param {(text: string, error?: string) => void} [opts.onComplete] - Called on
   *   "complete" with the text, or on "error" with ("", message).
   */
  const generate = useCallback(async ({ messages, imageUrl, videoUrl, audioUrl, enableThinking, onUpdate, onComplete }) => {
    callbacksRef.current = { onUpdate, onComplete };
    setError(null);

    let audioData = null;
    if (audioUrl) {
      try {
        // read_audio's second argument is the target sampling rate.
        audioData = await read_audio(audioUrl, 16000);
      } catch (err) {
        console.error("Audio decode failed:", err);
      }
    }
    const transferables = audioData ? [audioData.buffer] : [];

    // Extract video frames on main thread (load_video needs DOM)
    let videoData = null;
    if (videoUrl) {
      try {
        const video = await load_video(videoUrl, { num_frames: 4 });
        const serialized = serializeVideo(video);
        videoData = serialized.videoData;
        // Only register buffers once serialization fully succeeded.
        transferables.push(...serialized.buffers);
      } catch (err) {
        console.error("Video frame extraction failed:", err);
      }
    }

    workerRef.current?.postMessage(
      {
        type: "generate",
        messages,
        imageUrl: imageUrl || null,
        videoData,
        audioData,
        enableThinking: enableThinking || false,
      },
      transferables,
    );
  }, []);

  /** Ask the worker to stop the current generation. */
  const interrupt = useCallback(() => {
    workerRef.current?.postMessage({ type: "interrupt" });
  }, []);

  return { status, loadProgress, error, checkWebGPU, loadModel, generate, interrupt };
}