import { useChat } from "@ai-sdk/react";
import {
  DefaultChatTransport,
  lastAssistantMessageIsCompleteWithToolCalls,
} from "ai";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { Editor } from "@tiptap/core";
import type { UndoManager } from "yjs";
import type { UIMessage } from "ai";
import type { FrontmatterStore } from "../editor/frontmatter/frontmatter-store";
import { executeTiptapCommand, TIPTAP_TOOL_NAMES } from "../editor/agent-executor";
import { createAgentBatch } from "../editor/agent-undo-batch";
|
|
/** Inputs accepted by `useAgentChat`. */
type UseAgentChatOptions = {
  /** Tiptap editor the agent reads from and edits; `null` while unavailable. */
  editor: Editor | null;
  /** Yjs undo manager handed to `createAgentBatch` so agent edits can be grouped for undo. */
  undoManager: UndoManager | null;
  /** Frontmatter store sent as chat context and mutated by the frontmatter/author tools. */
  frontmatterStore: FrontmatterStore | null;
  /** Ref whose current value is sent as the `model` field with each request (read at send time). */
  modelRef: React.RefObject<string>;
  /** Messages the chat starts with. */
  initialMessages?: UIMessage[];
  /** Receives the message list after a request settles to `ready`, and `[]` when the chat is cleared. */
  onMessagesChange?: (messages: UIMessage[]) => void;
};
|
|
/** Route that chat requests are posted to. */
const CHAT_API_ENDPOINT = "/api/chat";

/** Single transport shared by every `useAgentChat` instance; constructed once at module load. */
const transport = new DefaultChatTransport({ api: CHAT_API_ENDPOINT });
|
|
/** Maximum number of tool-calling steps in one agent turn before automatic resubmission stops. */
const MAX_TOOL_ROUNDS = 8;

/**
 * Counts tool-calling rounds at the tail of the conversation.
 *
 * When the chat resubmits automatically, the AI SDK continues the existing
 * last assistant message instead of appending a new one. Successive rounds
 * therefore appear as extra `step-start`-delimited steps inside one message,
 * so counting messages alone never exceeds 1.
 *
 * This function counts every step that contains a `tool-*` part, walking
 * backwards through consecutive assistant messages. It stops at a
 * non-assistant message or at an assistant message with no tool steps.
 * Messages without `step-start` parts count as a single round if they hold
 * any tool part.
 *
 * @param messages Full chat history, oldest first.
 * @returns Number of tool-bearing steps at the end of the history.
 */
function countTrailingToolRounds(messages: UIMessage[]): number {
  let rounds = 0;
  for (let i = messages.length - 1; i >= 0; i--) {
    const message = messages[i];
    if (!message || message.role !== "assistant") break;

    let roundsInMessage = 0;
    let stepHasTool = false;
    for (const part of message.parts ?? []) {
      if (part.type === "step-start") {
        if (stepHasTool) roundsInMessage += 1;
        stepHasTool = false;
      } else if (typeof part.type === "string" && part.type.startsWith("tool-")) {
        stepHasTool = true;
      }
    }
    if (stepHasTool) roundsInMessage += 1;

    if (roundsInMessage === 0) break;
    rounds += roundsInMessage;
  }
  return rounds;
}

/**
 * Connects the AI chat agent to the Tiptap editor and the frontmatter store.
 *
 * The model's client-side tool calls are run against the editor or the
 * frontmatter store:
 * - editor: Tiptap commands, selection replacement, cursor insertion, text diffs;
 * - frontmatter store: fields and authors.
 *
 * Each call's result, or an error description, is reported back with
 * `addToolOutput`. Editor edits are grouped into an agent undo batch. The batch
 * is closed when a request finishes or errors, when a new user message is sent,
 * or when the undo manager changes or the hook unmounts.
 *
 * @returns An object containing:
 *   - `messages`, `isLoading`, `error`;
 *   - `sendMessage`, which attaches editor context and the model to each request;
 *   - `clearMessages` and `setMessages`;
 *   - the controlled `input` / `setInput` pair;
 *   - `stop`.
 */
export function useAgentChat({
  editor,
  undoManager,
  frontmatterStore,
  modelRef,
  initialMessages,
  onMessagesChange,
}: UseAgentChatOptions) {
  // Selection captured when the user sent the current message; consumed by the
  // `replaceSelection` tool. Reset on every send so a range from an earlier
  // turn can never be overwritten.
  const pendingSelectionRef = useRef<{ from: number; to: number } | null>(null);

  // One undo batch per UndoManager, memoized so it is not re-created on every
  // render. Callbacks read it through a ref so their identities stay stable.
  const agentBatch = useMemo(() => createAgentBatch(undoManager), [undoManager]);
  const agentBatchRef = useRef(agentBatch);
  agentBatchRef.current = agentBatch;
  // Close any batch still open when the UndoManager changes or the hook unmounts
  // (endAgentBatch is already called unconditionally before every send).
  useEffect(() => {
    return () => {
      agentBatch.endAgentBatch();
    };
  }, [agentBatch]);

  const [input, setInput] = useState("");
  const onMessagesChangeRef = useRef(onMessagesChange);
  onMessagesChangeRef.current = onMessagesChange;

  /**
   * Snapshot of the editor sent with each request: plain-text document, the
   * selected text (if any) and the frontmatter.
   * Side effect: records or clears the pending selection for `replaceSelection`.
   */
  const getEditorContext = useCallback(() => {
    if (!editor) {
      pendingSelectionRef.current = null;
      return {};
    }

    const { from, to } = editor.state.selection;
    const hasSelection = from !== to;
    pendingSelectionRef.current = hasSelection ? { from, to } : null;

    const document = editor.getText({ blockSeparator: "\n" });
    const selection = hasSelection
      ? editor.state.doc.textBetween(from, to, "\n")
      : undefined;
    const frontmatter = frontmatterStore?.getAll();

    return { document, selection, frontmatter };
  }, [editor, frontmatterStore]);

  /** Opens (or extends) the agent undo batch before an agent edit. */
  const startAgentBatch = useCallback(() => {
    agentBatchRef.current.startAgentBatch();
  }, []);

  /** Closes the current agent undo batch. */
  const endAgentBatch = useCallback(() => {
    agentBatchRef.current.endAgentBatch();
  }, []);

  /**
   * Flattens the document into plain text, mapping each character back to a
   * ProseMirror position.
   *
   * Text characters map to their own positions. Every block node after the
   * first character contributes a "\n" mapped to the block's start position.
   * Inline non-text nodes contribute nothing, so this text may not match
   * `editor.getText()` exactly.
   */
  const getDocTextWithPositions = useCallback((): {
    text: string;
    map: number[];
  } => {
    if (!editor) return { text: "", map: [] };
    const chunks: string[] = [];
    const map: number[] = [];

    editor.state.doc.descendants((node, pos) => {
      if (node.isText && node.text) {
        for (let i = 0; i < node.text.length; i++) {
          map.push(pos + i);
          chunks.push(node.text[i]);
        }
      } else if (node.isBlock && chunks.length > 0) {
        map.push(pos);
        chunks.push("\n");
      }
    });

    return { text: chunks.join(""), map };
  }, [editor]);

  /**
   * Locates `search` in the document and returns its ProseMirror range.
   *
   * With several matches, each one is scored on whether `contextBefore` and
   * `contextAfter` appear just before and after it:
   * - +2 per side for an exact match;
   * - +1 per side for a case-insensitive match;
   * - the search window is the context length plus 10 characters.
   * The highest score wins, and ties go to the earliest match.
   *
   * @returns The document range, or null when `search` is empty, the editor
   *   is missing, or no match exists.
   */
  const findTextPosition = useCallback(
    (
      search: string,
      contextBefore?: string,
      contextAfter?: string,
    ): { from: number; to: number } | null => {
      if (!editor || !search) return null;

      const { text, map } = getDocTextWithPositions();

      // All (possibly overlapping) match start indices in the flattened text.
      const candidates: number[] = [];
      for (let idx = text.indexOf(search); idx !== -1; idx = text.indexOf(search, idx + 1)) {
        candidates.push(idx);
      }
      if (candidates.length === 0) return null;

      // `to` is one past the document position of the match's last character.
      const toRange = (idx: number) => ({
        from: map[idx],
        to: map[idx + search.length - 1] + 1,
      });
      if (candidates.length === 1) return toRange(candidates[0]);

      const contextScore = (haystack: string, context?: string): number => {
        if (!context) return 0;
        if (haystack.includes(context)) return 2;
        return haystack.toLowerCase().includes(context.toLowerCase()) ? 1 : 0;
      };

      let bestIdx = candidates[0];
      let bestScore = -1;
      for (const idx of candidates) {
        const end = idx + search.length;
        const before = contextBefore
          ? text.slice(Math.max(0, idx - contextBefore.length - 10), idx)
          : "";
        const after = contextAfter
          ? text.slice(end, end + contextAfter.length + 10)
          : "";
        const score = contextScore(before, contextBefore) + contextScore(after, contextAfter);
        if (score > bestScore) {
          bestScore = score;
          bestIdx = idx;
        }
      }

      return toRange(bestIdx);
    },
    [editor, getDocTextWithPositions],
  );

  /**
   * Runs one client-side tool call and returns the text result reported to the model.
   * May throw on malformed arguments; the caller converts exceptions into tool output.
   */
  const executeToolCall = useCallback(
    (toolCall: { toolName: string; args: unknown; toolCallId: string }) => {
      if (TIPTAP_TOOL_NAMES.has(toolCall.toolName)) {
        if (!editor) return "Editor not available";
        startAgentBatch();
        const result = executeTiptapCommand(
          editor,
          toolCall.toolName,
          (toolCall.args as Record<string, unknown>) ?? {},
        );
        return result ?? `Unknown editor command: ${toolCall.toolName}`;
      }

      switch (toolCall.toolName) {
        case "updateFrontmatter": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const fields = (toolCall.args ?? {}) as Record<string, unknown>;
          // Only forward fields the model actually supplied.
          const updates: Record<string, unknown> = {};
          for (const [key, value] of Object.entries(fields)) {
            if (value !== undefined) updates[key] = value;
          }
          // NOTE(review): values are not validated here; this relies on the
          // tool's input schema — confirm before tightening this cast.
          // eslint-disable-next-line @typescript-eslint/no-explicit-any
          frontmatterStore.setAll(updates as any);
          return `Frontmatter updated: ${Object.keys(updates).join(", ")}`;
        }

        case "addAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { name, url, affiliations, newAffiliationName, newAffiliationUrl } =
            toolCall.args as {
              name: string;
              url?: string;
              affiliations?: number[];
              newAffiliationName?: string;
              newAffiliationUrl?: string;
            };
          // Copy so the push below never mutates the tool-call input, which may
          // be the same object stored on the chat message.
          const affIndices = [...(affiliations ?? [])];
          if (newAffiliationName) {
            const newIdx = frontmatterStore.addAffiliation({
              name: newAffiliationName,
              url: newAffiliationUrl,
            });
            affIndices.push(newIdx);
          }
          frontmatterStore.addAuthor({ name, url, affiliations: affIndices });
          return `Author "${name}" added`;
        }

        case "removeAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { index } = toolCall.args as { index: number };
          const authors = frontmatterStore.get("authors");
          // Reject non-integers too: `authors[1.5]` is undefined and `.name` would throw.
          if (!Number.isInteger(index) || index < 0 || index >= authors.length) {
            return `Invalid author index ${index} (${authors.length} authors)`;
          }
          const removed = authors[index].name;
          frontmatterStore.removeAuthor(index);
          return `Author "${removed}" removed`;
        }

        case "replaceSelection": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { newText } = toolCall.args as { newText: string };
          // Prefer the selection captured at send time; the live selection may
          // have been lost when focus moved to the chat input.
          const sel = pendingSelectionRef.current;
          if (sel) {
            editor
              .chain()
              .focus()
              .setTextSelection(sel)
              .insertContent(newText)
              .run();
            pendingSelectionRef.current = null;
            return "Selection replaced successfully";
          }
          const { from, to } = editor.state.selection;
          if (from !== to) {
            editor.chain().focus().insertContent(newText).run();
            return "Selection replaced successfully";
          }
          return "No text selected to replace";
        }

        case "insertAtCursor": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { text } = toolCall.args as { text: string };
          editor.chain().focus().insertContent(text).run();
          return "Text inserted successfully";
        }

        case "applyDiff": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { contextBefore, contentToDelete, contentToInsert, contextAfter } =
            toolCall.args as {
              contextBefore: string;
              contentToDelete: string;
              contentToInsert: string;
              contextAfter: string;
            };
          const pos = findTextPosition(contentToDelete, contextBefore, contextAfter);
          if (pos) {
            editor
              .chain()
              .focus()
              .setTextSelection(pos)
              .insertContent(contentToInsert || "")
              .run();
            return "Diff applied successfully";
          }
          return `Text "${contentToDelete.slice(0, 50)}..." not found in document`;
        }

        default:
          return `Unknown tool: ${toolCall.toolName}`;
      }
    },
    [editor, frontmatterStore, startAgentBatch, findTextPosition],
  );

  // useChat may keep the callbacks from its first render. Read the latest
  // executor through a ref so tool calls never run against a stale (e.g.
  // still-null) editor or store.
  const executeToolCallRef = useRef(executeToolCall);
  executeToolCallRef.current = executeToolCall;

  const {
    messages,
    status,
    error,
    stop,
    sendMessage: sendChatMessage,
    setMessages: setChatMessages,
    addToolOutput,
  } = useChat({
    transport,
    messages: initialMessages,

    // Auto-resubmit once every tool call in the last step has output, up to
    // MAX_TOOL_ROUNDS tool steps per agent turn.
    sendAutomaticallyWhen: ({ messages: history }) =>
      countTrailingToolRounds(history) < MAX_TOOL_ROUNDS &&
      lastAssistantMessageIsCompleteWithToolCalls({ messages: history }),

    async onToolCall({ toolCall }) {
      if (toolCall.dynamic) return;

      const toolName = toolCall.toolName as string;
      let output: unknown;
      try {
        output = executeToolCallRef.current({
          toolName,
          args: toolCall.input,
          toolCallId: toolCall.toolCallId,
        });
      } catch (err) {
        // Always report an output: a tool call with no output blocks automatic
        // resubmission and leaves the agent turn unfinished.
        console.error(`[agent] tool "${toolName}" failed:`, err);
        output = `Tool "${toolName}" failed: ${err instanceof Error ? err.message : String(err)}`;
      }

      (addToolOutput as (args: { tool: string; toolCallId: string; output: unknown }) => void)({
        tool: toolName,
        toolCallId: toolCall.toolCallId,
        output,
      });
    },

    onFinish() {
      endAgentBatch();
    },

    onError(err) {
      console.error("[agent] chat error:", err);
      endAgentBatch();
    },
  });

  // Report messages to the owner when a request settles (streaming/submitted -> ready).
  const prevStatusRef = useRef(status);
  useEffect(() => {
    const wasActive = prevStatusRef.current === "streaming" || prevStatusRef.current === "submitted";
    prevStatusRef.current = status;

    if (wasActive && status === "ready" && messages.length > 0) {
      onMessagesChangeRef.current?.(messages as UIMessage[]);
    }
  }, [status, messages]);

  /** Sends a user message with a fresh editor snapshot and the current model. */
  const sendMessage = useCallback(
    (content: string) => {
      endAgentBatch();
      const context = getEditorContext();
      const model = modelRef.current;
      void sendChatMessage({ text: content }, { body: { context, model } });
    },
    [sendChatMessage, getEditorContext, endAgentBatch, modelRef],
  );

  const isLoading = status === "streaming" || status === "submitted";

  /** Empties the chat and notifies the owner with `[]`. */
  const clearMessages = useCallback(() => {
    setChatMessages([]);
    onMessagesChangeRef.current?.([]);
  }, [setChatMessages]);

  /** Replaces the chat history without notifying the owner. */
  const setMessages = useCallback(
    (msgs: UIMessage[]) => {
      setChatMessages(msgs);
    },
    [setChatMessages],
  );

  return {
    messages: messages as UIMessage[],
    isLoading,
    error,
    sendMessage,
    clearMessages,
    setMessages,
    input,
    setInput,
    stop,
  };
}
|
|