import {
  createElement,
  useRef,
  useState,
  useEffect,
  useCallback,
  type ReactNode,
} from "react";
import {
  pipeline,
  TextStreamer,
  InterruptableStoppingCriteria,
  type TextGenerationPipeline,
} from "@huggingface/transformers";
import {
  LLMContext,
  createMessageId,
  type ChatMessage,
  type LoadingStatus,
} from "./LLMContext";

/** Props accepted by {@link LLMProvider}. */
interface LLMProviderProps {
  /** Model identifier passed to `pipeline("text-generation", …)`. */
  modelId: string;
  children: ReactNode;
  /** Called when the loading status turns `"ready"`. */
  onReady?: () => void;
}

/**
 * Loads a WebGPU text-generation pipeline for `modelId` and exposes the chat
 * state (loading status, messages, generating flag, tokens per second) plus
 * the chat actions (`send`, `stop`, `clearChat`, `editMessage`,
 * `retryMessage`) to descendants through `LLMContext`.
 */
export function LLMProvider({ modelId, children, onReady }: LLMProviderProps) {
  // Holds the in-flight (or resolved) pipeline promise so concurrent callers
  // await the same load instead of starting another one.
  const generatorRef = useRef<Promise<TextGenerationPipeline> | null>(null);
  const stoppingCriteria = useRef(new InterruptableStoppingCriteria());

  const [status, setStatus] = useState<LoadingStatus>({ state: "idle" });
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const messagesRef = useRef<ChatMessage[]>([]);
  const [isGenerating, setIsGenerating] = useState(false);
  // Synchronous "a run is pending or active" lock. It is written directly by
  // runGeneration rather than mirrored from `isGenerating` in an effect, so a
  // second action issued before the next render still sees it.
  const isGeneratingRef = useRef(false);
  const [tps, setTps] = useState(0);

  // Mirror the committed message list so the stable callbacks below can read
  // it without being re-created on every message change.
  useEffect(() => {
    messagesRef.current = messages;
  }, [messages]);

  // Keep the latest onReady in a ref so the effect below depends only on the
  // status and does not re-fire when the parent passes a new function.
  const onReadyRef = useRef(onReady);
  onReadyRef.current = onReady;

  useEffect(() => {
    if (status.state === "ready") onReadyRef.current?.();
  }, [status.state]);

  useEffect(() => {
    // A load is already pending or finished. Because of this guard, a later
    // change of `modelId` does not start another load while the ref is set.
    if (generatorRef.current) return;

    const loading: Promise<TextGenerationPipeline> = (async () => {
      setStatus({ state: "loading", message: "Downloading model…" });
      try {
        const gen = await pipeline("text-generation", modelId, {
          dtype: "q4f16",
          device: "webgpu",
          progress_callback: (info) => {
            if (info.status !== "progress_total") return;
            const loaded = Number(info.loaded ?? 0);
            const total = Number(info.total ?? 0);
            const pct = Number(info.progress ?? 0);
            const toGB = (b: number) => (b / 1e9).toFixed(2);
            setStatus({
              state: "loading",
              progress: pct,
              message:
                total > 0
                  ? `${toGB(loaded)} GB of ${toGB(total)} GB (${Math.round(pct)}%)`
                  : `Downloading model…`,
            });
          },
        });
        setStatus({ state: "ready" });
        return gen;
      } catch (err) {
        const msg = err instanceof Error ? err.message : String(err);
        setStatus({ state: "error", error: msg });
        // Clear the ref so a later effect run may attempt the load again.
        generatorRef.current = null;
        throw err;
      }
    })();

    // The failure is already surfaced through `status`; this no-op handler
    // only keeps the rejection from being reported as unhandled. Callers that
    // await `generatorRef.current` still observe the rejection.
    void loading.catch(() => undefined);
    generatorRef.current = loading;
  }, [modelId]);

  /**
   * Generates one assistant reply for `chatHistory`, streaming the text into
   * a new assistant message appended to `messages`. Never rejects: failures
   * are logged and the generating state is always released.
   */
  const runGeneration = useCallback(async (chatHistory: ChatMessage[]): Promise<void> => {
    // Take the lock before the first await so overlapping runs cannot start.
    isGeneratingRef.current = true;

    let generator: TextGenerationPipeline;
    try {
      const pending = generatorRef.current;
      if (!pending) throw new Error("Model is not loaded");
      generator = await pending;
    } catch (err) {
      console.error("Generation error:", err);
      isGeneratingRef.current = false;
      return;
    }

    setIsGenerating(true);
    setTps(0);
    stoppingCriteria.current.reset();

    let tokenCount = 0;
    let firstTokenTime = 0;

    // The reply is located by id rather than by array position, so updates
    // stay correct (and cannot dereference a missing entry) even if the list
    // is not exactly `chatHistory` plus the placeholder.
    const assistantId = createMessageId();
    const patchAssistant = (patch: (message: ChatMessage) => ChatMessage) => {
      setMessages((prev) =>
        prev.map((message) => (message.id === assistantId ? patch(message) : message)),
      );
    };

    setMessages((prev) => [...prev, { id: assistantId, role: "assistant", content: "" }]);

    try {
      const streamer = new TextStreamer(generator.tokenizer, {
        skip_prompt: true,
        skip_special_tokens: true,
        callback_function: (output: string) => {
          if (!output) return;
          patchAssistant((message) => ({
            ...message,
            content: message.content + output,
          }));
        },
        token_callback_function: () => {
          tokenCount++;
          if (tokenCount === 1) {
            // Timing starts at the first token, so the rate below counts
            // only the tokens produced after it.
            firstTokenTime = performance.now();
          } else {
            const elapsed = (performance.now() - firstTokenTime) / 1000;
            if (elapsed > 0) {
              setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10);
            }
          }
        },
      });

      await generator(
        chatHistory.map((message) => ({
          role: message.role,
          content: message.content,
        })),
        {
          max_new_tokens: 4096,
          do_sample: false,
          streamer,
          stopping_criteria: stoppingCriteria.current,
        },
      );
    } catch (err) {
      console.error("Generation error:", err);
    }

    // Guard the division: a zero elapsed time would otherwise yield Infinity.
    const elapsedSeconds = (performance.now() - firstTokenTime) / 1000;
    const finalTps =
      tokenCount > 1 && elapsedSeconds > 0
        ? Math.round(((tokenCount - 1) / elapsedSeconds) * 10) / 10
        : 0;

    patchAssistant((message) => ({
      ...message,
      content: message.content.trim(),
      tps: finalTps > 0 ? finalTps : undefined,
    }));

    isGeneratingRef.current = false;
    setIsGenerating(false);
  }, []);

  /** Appends a user message and generates a reply; ignored while busy or before a load has started. */
  const send = useCallback(
    (text: string) => {
      if (!generatorRef.current || isGeneratingRef.current) return;
      const userMsg: ChatMessage = {
        id: createMessageId(),
        role: "user",
        content: text,
      };
      setMessages((prev) => [...prev, userMsg]);
      void runGeneration([...messagesRef.current, userMsg]);
    },
    [runGeneration],
  );

  /** Interrupts the generation currently in progress. */
  const stop = useCallback(() => {
    stoppingCriteria.current.interrupt();
  }, []);

  /** Removes every message; ignored while a generation is pending or running. */
  const clearChat = useCallback(() => {
    if (isGeneratingRef.current) return;
    setMessages([]);
  }, []);

  /**
   * Replaces the content of the message at `index` and drops everything after
   * it. When the edited message is a user message, a new reply is generated.
   * Ignored while busy or when `index` does not address a message.
   */
  const editMessage = useCallback(
    (index: number, newContent: string) => {
      if (isGeneratingRef.current) return;
      const target = messagesRef.current[index];
      if (!target) return;

      const updatedHistory = [
        ...messagesRef.current.slice(0, index),
        { ...target, content: newContent },
      ];
      setMessages(updatedHistory);
      if (target.role === "user") {
        // Called directly (no timer) so the lock is taken before any other
        // action can run; state updates are still applied in call order.
        void runGeneration(updatedHistory);
      }
    },
    [runGeneration],
  );

  /** Drops the message at `index` and everything after it, then generates a fresh reply. */
  const retryMessage = useCallback(
    (index: number) => {
      if (isGeneratingRef.current) return;
      const history = messagesRef.current.slice(0, index);
      setMessages(history);
      void runGeneration(history);
    },
    [runGeneration],
  );

  // createElement is used instead of JSX so this module is valid whether it
  // is compiled as .ts or .tsx.
  // NOTE(review): the value keys are assumed to match the context value type
  // declared in ./LLMContext; confirm against that module.
  return createElement(
    LLMContext.Provider,
    {
      value: {
        status,
        messages,
        isGenerating,
        tps,
        send,
        stop,
        clearChat,
        editMessage,
        retryMessage,
      },
    },
    children,
  );
}