Spaces:
Running
Running
| // web/src/hooks/useModel.js | |
| import { useState, useEffect, useRef, useCallback } from "react"; | |
| import { read_audio, load_video } from "@huggingface/transformers"; | |
/**
 * React hook that owns the model web worker and exposes its lifecycle.
 *
 * The hook creates one module worker from "../worker.js" on mount and
 * terminates it on unmount. It posts "check", "load", "generate" and
 * "interrupt" messages to the worker and reacts to the worker's "status",
 * "progress", "error", "update" and "complete" messages.
 *
 * @returns {{
 *   status: string,
 *   loadProgress: (object|null),
 *   error: (string|null),
 *   checkWebGPU: function(): void,
 *   loadModel: function(): void,
 *   generate: function(object): Promise<void>,
 *   interrupt: function(): void,
 * }} Current state plus stable (memoized) command functions.
 */
export function useModel() {
  // idle | webgpu-available | webgpu-unavailable | loading | ready | generating | error
  const [status, setStatus] = useState("idle");
  const [loadProgress, setLoadProgress] = useState(null);
  const [error, setError] = useState(null);
  const workerRef = useRef(null);
  // Callbacks ({ onUpdate, onComplete }) of the generation currently in flight.
  const callbacksRef = useRef(null);

  useEffect(() => {
    const worker = new Worker(new URL("../worker.js", import.meta.url), {
      type: "module",
    });

    // Detach the pending callbacks BEFORE invoking them: a callback may start a
    // follow-up generate() synchronously, and clearing afterwards would wipe the
    // callbacks that the new request has just registered.
    const takeCallbacks = () => {
      const pending = callbacksRef.current;
      callbacksRef.current = null;
      return pending;
    };

    // Shared failure path: surface the message, flip to "error", and finish the
    // in-flight request (if any) exactly once with an empty text.
    const fail = (message) => {
      setError(message);
      setStatus("error");
      takeCallbacks()?.onComplete?.("", message);
    };

    worker.onmessage = (e) => {
      const { type, ...data } = e.data;
      switch (type) {
        case "status":
          setStatus(data.status);
          // Progress information is only meaningful while loading.
          if (data.status === "ready") setLoadProgress(null);
          break;
        case "progress":
          setLoadProgress(data);
          break;
        case "error":
          fail(data.message);
          break;
        case "update":
          callbacksRef.current?.onUpdate?.(data.text);
          break;
        case "complete":
          setStatus("ready");
          takeCallbacks()?.onComplete?.(data.text);
          break;
        default:
          break;
      }
    };

    // Uncaught exceptions inside the worker (or a failure to load its script)
    // arrive here rather than as an "error" message; without this handler the
    // hook would stay in its current status indefinitely.
    worker.onerror = (e) => {
      fail(e.message || "Worker error");
    };

    workerRef.current = worker;
    return () => {
      worker.terminate();
      // Do not keep a handle to a terminated worker around.
      if (workerRef.current === worker) workerRef.current = null;
    };
  }, []);

  /** Ask the worker to run its "check" step (reported back via "status"). */
  const checkWebGPU = useCallback(() => {
    workerRef.current?.postMessage({ type: "check" });
  }, []);

  /** Ask the worker to load the model. */
  const loadModel = useCallback(() => {
    workerRef.current?.postMessage({ type: "load" });
  }, []);

  /**
   * Decode any attached media on the main thread and send a "generate"
   * request to the worker.
   *
   * Media decoding is best-effort: if audio or video decoding fails, the error
   * is logged and the request is sent without that attachment.
   *
   * @param {object} options
   * @param {Array} options.messages - Conversation forwarded to the worker as-is.
   * @param {string} [options.imageUrl] - Forwarded as-is (null when falsy).
   * @param {string} [options.videoUrl] - Video to sample 4 frames from.
   * @param {string} [options.audioUrl] - Audio to decode with read_audio.
   * @param {boolean} [options.enableThinking] - Forwarded (false when falsy).
   * @param {function(string): void} [options.onUpdate] - Called on each "update".
   * @param {function(string, string=): void} [options.onComplete] - Called once
   *   with the final text, or with ("", message) on failure.
   * @returns {Promise<void>} Resolves once the request has been posted.
   */
  const generate = useCallback(
    async ({ messages, imageUrl, videoUrl, audioUrl, enableThinking, onUpdate, onComplete }) => {
      callbacksRef.current = { onUpdate, onComplete };

      // Buffers listed here are transferred (moved, not copied) to the worker.
      const transferables = [];

      let audioData = null;
      if (audioUrl) {
        try {
          // 16000 is passed as read_audio's second argument (the sampling rate,
          // per the transformers.js docs).
          audioData = await read_audio(audioUrl, 16000);
        } catch (err) {
          console.error("Audio decode failed:", err);
        }
      }
      if (audioData) transferables.push(audioData.buffer);

      // Extract video frames on main thread (load_video needs DOM)
      let videoData = null;
      if (videoUrl) {
        try {
          const video = await load_video(videoUrl, { num_frames: 4 });
          const frames = video.frames.map(({ image, timestamp }) => ({
            // Copy exactly the bytes this typed array views. Slicing the
            // backing ArrayBuffer instead would ignore byteOffset/length and
            // ship unrelated bytes whenever the array is a view into a larger
            // buffer.
            data: image.data.slice().buffer,
            width: image.width,
            height: image.height,
            channels: image.channels,
            timestamp,
          }));
          videoData = { duration: video.duration, frames };
          for (const frame of frames) transferables.push(frame.data);
        } catch (err) {
          console.error("Video frame extraction failed:", err);
        }
      }

      const msg = {
        type: "generate",
        messages,
        imageUrl: imageUrl || null,
        videoData,
        audioData,
        enableThinking: enableThinking || false,
      };
      workerRef.current?.postMessage(msg, transferables);
    },
    [],
  );

  /** Ask the worker to stop the generation in progress. */
  const interrupt = useCallback(() => {
    workerRef.current?.postMessage({ type: "interrupt" });
  }, []);

  return { status, loadProgress, error, checkWebGPU, loadModel, generate, interrupt };
}