import { useChat } from "@ai-sdk/react";
import {
  DefaultChatTransport,
  lastAssistantMessageIsCompleteWithToolCalls,
} from "ai";
import { useCallback, useEffect, useRef, useState } from "react";
import type { Editor } from "@tiptap/core";
import type { UndoManager } from "yjs";
import type { UIMessage } from "ai";
import type { FrontmatterStore } from "../editor/frontmatter/frontmatter-store";
import { executeTiptapCommand, TIPTAP_TOOL_NAMES } from "../editor/agent-executor";
import { createAgentBatch } from "../editor/agent-undo-batch";

/** Options accepted by {@link useAgentChat}. */
interface UseAgentChatOptions {
  editor: Editor | null;
  undoManager: UndoManager | null;
  frontmatterStore: FrontmatterStore | null;
  /** Ref whose current value is forwarded verbatim as `model` in the request body. */
  modelRef: React.RefObject<unknown>;
  initialMessages?: UIMessage[];
  /** Called with the full message list when a round completes or the chat is cleared. */
  onMessagesChange?: (messages: UIMessage[]) => void;
}

const transport = new DefaultChatTransport({ api: "/api/chat" });

/**
 * Hard cap on consecutive auto-sent tool-call rounds to prevent
 * runaway loops (e.g. a model that keeps calling tools instead of
 * emitting a final text response).
 */
const MAX_TOOL_ROUNDS = 8;

/**
 * Count the tool-call rounds at the tail of the conversation, walking
 * backwards until a non-assistant message (or an assistant message without
 * any tool part) is reached.
 *
 * A round is a step that contains at least one `tool-*` part. Steps inside a
 * single assistant message are delimited by `step-start` parts; a message
 * without `step-start` parts counts as one round if it has any tool part.
 * NOTE(review): this relies on the AI SDK appending automatic follow-up
 * responses to the existing assistant message as new steps — confirm against
 * the installed SDK version.
 */
function countTrailingToolRounds(messages: readonly UIMessage[]): number {
  let rounds = 0;
  for (let i = messages.length - 1; i >= 0; i--) {
    const message = messages[i];
    if (!message || message.role !== "assistant") break;

    let roundsInMessage = 0;
    let currentStepHasTool = false;
    for (const part of message.parts ?? []) {
      if (part.type === "step-start") {
        if (currentStepHasTool) roundsInMessage += 1;
        currentStepHasTool = false;
      } else if (typeof part.type === "string" && part.type.startsWith("tool-")) {
        currentStepHasTool = true;
      }
    }
    if (currentStepHasTool) roundsInMessage += 1;

    if (roundsInMessage === 0) break;
    rounds += roundsInMessage;
  }
  return rounds;
}

/**
 * Score how well `context` matches `window`: 2 for an exact substring match,
 * 1 for a case-insensitive match, 0 otherwise (or when no context is given).
 */
function scoreContext(window: string, context: string | undefined): number {
  if (!context) return 0;
  if (window.includes(context)) return 2;
  if (window.toLowerCase().includes(context.toLowerCase())) return 1;
  return 0;
}

/**
 * Chat hook that wires the AI SDK `useChat` to the editor: it sends editor
 * context with every user message, executes tool calls coming back from the
 * model against the editor / frontmatter store, and groups the resulting
 * edits into undo batches.
 *
 * @returns The chat messages, loading/error state, the local `input` state,
 * and `sendMessage` / `clearMessages` / `setMessages` / `stop` actions.
 */
export function useAgentChat({
  editor,
  undoManager,
  frontmatterStore,
  modelRef,
  initialMessages,
  onMessagesChange,
}: UseAgentChatOptions) {
  // Selection captured when the user message was sent, so `replaceSelection`
  // can target it even if the live selection has moved since.
  const pendingSelectionRef = useRef<{ from: number; to: number } | null>(null);

  const agentBatchRef = useRef(createAgentBatch(undoManager));
  useEffect(() => {
    agentBatchRef.current = createAgentBatch(undoManager);
  }, [undoManager]);

  const [input, setInput] = useState("");

  // Keep the latest callback in a ref so effects/callbacks below do not need
  // to re-subscribe when the parent passes a new function identity.
  const onMessagesChangeRef = useRef(onMessagesChange);
  onMessagesChangeRef.current = onMessagesChange;

  /**
   * Snapshot the editor state that is sent along with a user message:
   * plain-text document, selected text (if any) and frontmatter.
   * Also records — or clears — the pending selection used by `replaceSelection`.
   */
  const getEditorContext = useCallback(() => {
    if (!editor) {
      pendingSelectionRef.current = null;
      return {};
    }
    const { from, to } = editor.state.selection;
    const hasSelection = from !== to;
    const document = editor.getText({ blockSeparator: "\n" });
    const selection = hasSelection
      ? editor.state.doc.textBetween(from, to, "\n")
      : undefined;
    // Always overwrite: a selection remembered from an earlier message must
    // not be replaced later when this message was sent without one.
    pendingSelectionRef.current = hasSelection ? { from, to } : null;
    const frontmatter = frontmatterStore?.getAll();
    return { document, selection, frontmatter };
  }, [editor, frontmatterStore]);

  /**
   * Start a new undo capture group. All subsequent edits will be
   * merged into one undo step until endAgentBatch() is called.
   */
  const startAgentBatch = useCallback(() => {
    agentBatchRef.current.startAgentBatch();
  }, []);

  const endAgentBatch = useCallback(() => {
    agentBatchRef.current.endAgentBatch();
  }, []);

  /**
   * Extract all text from the ProseMirror document as a flat string,
   * keeping a mapping from string offsets to document positions.
   * `map[i]` is the document position of `text[i]`; both always have the
   * same length.
   */
  const getDocTextWithPositions = useCallback((): {
    text: string;
    map: number[];
  } => {
    if (!editor) return { text: "", map: [] };
    const chunks: string[] = [];
    const map: number[] = [];
    editor.state.doc.descendants((node, pos) => {
      if (node.isText && node.text) {
        for (let i = 0; i < node.text.length; i++) {
          map.push(pos + i);
          chunks.push(node.text[i]);
        }
      } else if (node.isBlock && chunks.length > 0) {
        // Separate blocks with a newline (mapped to the block's own position),
        // except before the very first text.
        map.push(pos);
        chunks.push("\n");
      }
    });
    return { text: chunks.join(""), map };
  }, [editor]);

  /**
   * Find text in the document, optionally using surrounding context
   * to disambiguate when the same text appears multiple times.
   *
   * @returns The document range of the best match, or `null` when the editor
   * is unavailable, `search` is empty, or the text does not occur. On equal
   * scores the earliest occurrence wins.
   */
  const findTextPosition = useCallback(
    (
      search: string,
      contextBefore?: string,
      contextAfter?: string,
    ): { from: number; to: number } | null => {
      if (!editor || !search) return null;
      const { text, map } = getDocTextWithPositions();

      // Collect every (possibly overlapping) occurrence.
      const candidates: number[] = [];
      for (
        let idx = text.indexOf(search);
        idx !== -1;
        idx = text.indexOf(search, idx + 1)
      ) {
        candidates.push(idx);
      }
      if (candidates.length === 0) return null;

      let bestIdx = candidates[0];
      if (candidates.length > 1) {
        // Score each candidate by how well the context matches. The windows
        // are 10 characters wider than the context to tolerate small drift.
        let bestScore = -1;
        for (const idx of candidates) {
          const end = idx + search.length;
          const before = contextBefore
            ? text.slice(Math.max(0, idx - contextBefore.length - 10), idx)
            : "";
          const after = contextAfter
            ? text.slice(end, end + contextAfter.length + 10)
            : "";
          const score =
            scoreContext(before, contextBefore) + scoreContext(after, contextAfter);
          if (score > bestScore) {
            bestScore = score;
            bestIdx = idx;
          }
        }
      }

      return { from: map[bestIdx], to: map[bestIdx + search.length - 1] + 1 };
    },
    [editor, getDocTextWithPositions],
  );

  /**
   * Run a single tool call against the editor or the frontmatter store and
   * return a human-readable result that is fed back to the model.
   * Editor-mutating tools open an undo batch first.
   */
  const executeToolCall = useCallback(
    (toolCall: { toolName: string; args: unknown; toolCallId: string }) => {
      // --- Tiptap catalog commands (select, toggleBold, toggleHeading, ...) ---
      if (TIPTAP_TOOL_NAMES.has(toolCall.toolName)) {
        if (!editor) return "Editor not available";
        startAgentBatch();
        const result = executeTiptapCommand(
          editor,
          toolCall.toolName,
          (toolCall.args ?? {}) as Record<string, unknown>,
        );
        return result ?? `Unknown editor command: ${toolCall.toolName}`;
      }

      switch (toolCall.toolName) {
        // --- Frontmatter tools (no editor needed) ---
        case "updateFrontmatter": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const fields = toolCall.args as Record<string, unknown>;
          // Drop undefined values so they do not overwrite existing fields.
          const updates: Record<string, unknown> = {};
          for (const [key, value] of Object.entries(fields)) {
            if (value !== undefined) updates[key] = value;
          }
          // eslint-disable-next-line @typescript-eslint/no-explicit-any -- tool args are untyped at this boundary
          frontmatterStore.setAll(updates as any);
          return `Frontmatter updated: ${Object.keys(updates).join(", ")}`;
        }

        case "addAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { name, url, affiliations, newAffiliationName, newAffiliationUrl } =
            toolCall.args as {
              name: string;
              url?: string;
              affiliations?: number[];
              newAffiliationName?: string;
              newAffiliationUrl?: string;
            };
          // Copy so the tool-call arguments are not mutated by the push below.
          const affIndices = [...(affiliations ?? [])];
          if (newAffiliationName) {
            const newIdx = frontmatterStore.addAffiliation({
              name: newAffiliationName,
              url: newAffiliationUrl,
            });
            affIndices.push(newIdx);
          }
          frontmatterStore.addAuthor({ name, url, affiliations: affIndices });
          return `Author "${name}" added`;
        }

        case "removeAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { index } = toolCall.args as { index: number };
          const authors = frontmatterStore.get("authors");
          if (!Number.isInteger(index) || index < 0 || index >= authors.length) {
            return `Invalid author index ${index} (${authors.length} authors)`;
          }
          const removed = authors[index].name;
          frontmatterStore.removeAuthor(index);
          return `Author "${removed}" removed`;
        }

        // --- Editor tools ---
        case "replaceSelection": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { newText } = toolCall.args as { newText: string };
          // Prefer the selection captured when the message was sent.
          const sel = pendingSelectionRef.current;
          if (sel) {
            editor
              .chain()
              .focus()
              .setTextSelection(sel)
              .insertContent(newText)
              .run();
            pendingSelectionRef.current = null;
            return "Selection replaced successfully";
          }
          // Otherwise fall back to whatever is selected right now.
          const { from, to } = editor.state.selection;
          if (from !== to) {
            editor.chain().focus().insertContent(newText).run();
            return "Selection replaced successfully";
          }
          return "No text selected to replace";
        }

        case "insertAtCursor": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { text } = toolCall.args as { text: string };
          editor.chain().focus().insertContent(text).run();
          return "Text inserted successfully";
        }

        case "applyDiff": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { contextBefore, contentToDelete, contentToInsert, contextAfter } =
            toolCall.args as {
              contextBefore: string;
              contentToDelete: string;
              contentToInsert: string;
              contextAfter: string;
            };
          const pos = findTextPosition(contentToDelete, contextBefore, contextAfter);
          if (pos) {
            editor
              .chain()
              .focus()
              .setTextSelection(pos)
              .insertContent(contentToInsert || "")
              .run();
            return "Diff applied successfully";
          }
          return `Text "${contentToDelete.slice(0, 50)}..." not found in document`;
        }

        default:
          return `Unknown tool: ${toolCall.toolName}`;
      }
    },
    [editor, frontmatterStore, startAgentBatch, findTextPosition],
  );

  const { addToolOutput, ...chat } = useChat({
    transport,
    messages: initialMessages,
    // Auto-continue after tool calls, but only up to MAX_TOOL_ROUNDS in a row.
    sendAutomaticallyWhen: ({ messages }) =>
      countTrailingToolRounds(messages) < MAX_TOOL_ROUNDS &&
      lastAssistantMessageIsCompleteWithToolCalls({ messages }),
    async onToolCall({ toolCall }) {
      if (toolCall.dynamic) return;
      const toolName = toolCall.toolName as string;

      // A tool that throws must still produce an output; otherwise the tool
      // call is never resolved. Report the failure to the model instead.
      let output: unknown;
      try {
        output = executeToolCall({
          toolName,
          args: toolCall.input,
          toolCallId: toolCall.toolCallId,
        });
      } catch (error: unknown) {
        console.error(`[agent] tool "${toolName}" failed:`, error);
        const reason = error instanceof Error ? error.message : String(error);
        output = `Tool "${toolName}" failed: ${reason}`;
      }

      (addToolOutput as (args: { tool: string; toolCallId: string; output: unknown }) => void)({
        tool: toolName,
        toolCallId: toolCall.toolCallId,
        output,
      });
    },
    onFinish() {
      endAgentBatch();
    },
    onError(error) {
      console.error("[agent] chat error:", error);
      endAgentBatch();
    },
  });

  // `chat` is a fresh rest object on every render, so it must not be used as
  // a hook dependency; depend on the individual functions instead.
  const { sendMessage: sendChatMessage, setMessages: setChatMessages } = chat;

  // Persist when a conversation round completes (status back to "ready")
  const prevStatusRef = useRef(chat.status);
  useEffect(() => {
    const wasActive =
      prevStatusRef.current === "streaming" || prevStatusRef.current === "submitted";
    prevStatusRef.current = chat.status;
    if (wasActive && chat.status === "ready" && chat.messages.length > 0) {
      onMessagesChangeRef.current?.(chat.messages as UIMessage[]);
    }
  }, [chat.status, chat.messages]);

  /** Send a user message together with the current editor context and model. */
  const sendMessage = useCallback(
    (content: string) => {
      // Close any batch left open so this round gets its own undo step.
      endAgentBatch();
      const context = getEditorContext();
      const model = modelRef.current;
      void sendChatMessage({ text: content }, { body: { context, model } });
    },
    [sendChatMessage, getEditorContext, endAgentBatch, modelRef],
  );

  const isLoading = chat.status === "streaming" || chat.status === "submitted";

  /** Remove all messages and notify the persistence callback. */
  const clearMessages = useCallback(() => {
    pendingSelectionRef.current = null;
    setChatMessages([]);
    onMessagesChangeRef.current?.([]);
  }, [setChatMessages]);

  /** Replace the message list without notifying the persistence callback. */
  const setMessages = useCallback(
    (msgs: UIMessage[]) => {
      setChatMessages(msgs);
    },
    [setChatMessages],
  );

  return {
    messages: chat.messages as UIMessage[],
    isLoading,
    error: chat.error,
    sendMessage,
    clearMessages,
    setMessages,
    input,
    setInput,
    stop: chat.stop,
  };
}