// carbon-tokenization/frontend/src/hooks/useAgentChat.ts
// Commit d63dcfe by tfrere (HF Staff):
// "chore(release): prep v0.1.0 with LICENSE, env example and agent undo test"
import { useChat } from "@ai-sdk/react";
import {
DefaultChatTransport,
lastAssistantMessageIsCompleteWithToolCalls,
} from "ai";
import { useCallback, useEffect, useRef, useState } from "react";
import type { Editor } from "@tiptap/core";
import type { UndoManager } from "yjs";
import type { UIMessage } from "ai";
import type { FrontmatterStore } from "../editor/frontmatter/frontmatter-store";
import { executeTiptapCommand, TIPTAP_TOOL_NAMES } from "../editor/agent-executor";
import { createAgentBatch } from "../editor/agent-undo-batch";
/**
 * Options accepted by `useAgentChat`.
 *
 * The nullable dependencies may be absent (e.g. while the editor is still
 * mounting); tools that need a missing dependency report that it is not
 * available instead of running.
 */
interface UseAgentChatOptions {
  /** Tiptap editor the agent reads context from and applies edits to. */
  editor: Editor | null;
  /** Store backing the frontmatter tools (update fields, add/remove authors). */
  frontmatterStore: FrontmatterStore | null;
  /** Yjs undo manager used to group the agent's edits into undo steps. */
  undoManager: UndoManager | null;
  /** Ref holding the model identifier; read each time a message is sent. */
  modelRef: React.RefObject<string>;
  /** Messages to seed the conversation with (e.g. restored history). */
  initialMessages?: UIMessage[];
  /** Receives the full message list when a round finishes, and `[]` on clear. */
  onMessagesChange?: (messages: UIMessage[]) => void;
}
/**
 * Single transport shared by every hook instance: all chat requests go to
 * the app's `/api/chat` endpoint.
 */
const transport = new DefaultChatTransport({
  api: "/api/chat",
});
/**
 * Hard cap on tool-call rounds the client will auto-continue within one
 * agent turn. Prevents runaway loops, e.g. a model that keeps calling tools
 * instead of emitting a final text response.
 */
const MAX_TOOL_ROUNDS = 8;

/**
 * Extra characters of slack around `contextBefore` / `contextAfter` when
 * scoring candidate matches in `findInDocText`.
 */
const CONTEXT_WINDOW_SLACK = 10;

/** Longest `contentToDelete` prefix echoed back in a "not found" message. */
const NOT_FOUND_PREVIEW_LENGTH = 50;

/** True for static tool-invocation parts, whose type is `tool-<toolName>`. */
function isToolPart(part: { type: unknown }): boolean {
  return typeof part.type === "string" && part.type.startsWith("tool-");
}

/**
 * Count the tool-call rounds at the tail of the conversation.
 *
 * With @ai-sdk/react v5+, an automatic continuation (sent after tool outputs
 * are added) is streamed into the *same* assistant message as a new step,
 * each step opening with a `step-start` part. Counting whole assistant
 * messages would therefore never exceed 1, so every step that contains a
 * tool part is counted instead. Messages without `step-start` parts count
 * as a single round when they contain any tool part.
 *
 * Walks backwards over the trailing run of assistant messages and stops at
 * the first non-assistant message or the first assistant message with no
 * tool parts.
 */
function countTrailingToolRounds(messages: readonly UIMessage[]): number {
  let rounds = 0;
  for (let i = messages.length - 1; i >= 0; i--) {
    const message = messages[i];
    if (!message || message.role !== "assistant") break;
    let messageRounds = 0;
    let stepHasTool = false;
    for (const part of message.parts ?? []) {
      if (part.type === "step-start") {
        if (stepHasTool) messageRounds += 1;
        stepHasTool = false;
      } else if (isToolPart(part)) {
        stepHasTool = true;
      }
    }
    if (stepHasTool) messageRounds += 1;
    if (messageRounds === 0) break;
    rounds += messageRounds;
  }
  return rounds;
}

/**
 * Plain-text projection of a ProseMirror document.
 *
 * `map[i]` is the document position of `text[i]`. A "\n" is emitted for
 * every block node that starts after some text has already been collected,
 * mapped to that block node's position.
 */
interface DocTextIndex {
  text: string;
  map: number[];
}

/** Build a `DocTextIndex` for the editor's current document. */
function buildDocTextIndex(editor: Editor): DocTextIndex {
  const chunks: string[] = [];
  const map: number[] = [];
  editor.state.doc.descendants((node, pos) => {
    if (node.isText && node.text) {
      // ProseMirror positions advance by one per UTF-16 code unit, exactly
      // like string indices, so each character maps to pos + offset.
      chunks.push(node.text);
      for (let offset = 0; offset < node.text.length; offset++) {
        map.push(pos + offset);
      }
    } else if (node.isBlock && chunks.length > 0) {
      map.push(pos);
      chunks.push("\n");
    }
  });
  return { text: chunks.join(""), map };
}

/**
 * Score how well `context` occurs in `haystack`: 2 for an exact match,
 * 1 for a case-insensitive match, 0 when absent or when no context is given.
 */
function contextMatchScore(haystack: string, context: string | undefined): number {
  if (!context) return 0;
  if (haystack.includes(context)) return 2;
  return haystack.toLowerCase().includes(context.toLowerCase()) ? 1 : 0;
}

/**
 * Find `search` in the flattened document text. When it occurs several times
 * (overlapping occurrences included), pick the occurrence whose surroundings
 * best match `contextBefore` / `contextAfter`; ties go to the earliest one.
 *
 * @returns the document range covering the match, or null if `search` is
 *          empty or not found.
 */
function findInDocText(
  { text, map }: DocTextIndex,
  search: string,
  contextBefore?: string,
  contextAfter?: string,
): { from: number; to: number } | null {
  if (!search) return null;
  let bestIdx = -1;
  let bestScore = -1;
  for (let idx = text.indexOf(search); idx !== -1; idx = text.indexOf(search, idx + 1)) {
    const beforeStart = Math.max(0, idx - (contextBefore?.length ?? 0) - CONTEXT_WINDOW_SLACK);
    const afterStart = idx + search.length;
    const afterEnd = afterStart + (contextAfter?.length ?? 0) + CONTEXT_WINDOW_SLACK;
    const score =
      contextMatchScore(text.slice(beforeStart, idx), contextBefore) +
      contextMatchScore(text.slice(afterStart, afterEnd), contextAfter);
    // Strict ">" keeps the earliest candidate when scores tie.
    if (score > bestScore) {
      bestScore = score;
      bestIdx = idx;
    }
  }
  if (bestIdx === -1) return null;
  return { from: map[bestIdx], to: map[bestIdx + search.length - 1] + 1 };
}

/**
 * Chat hook wiring an AI agent to the Tiptap editor and frontmatter store.
 *
 * Tool calls streamed by the model are executed client-side. Supported tools
 * are the Tiptap catalog commands, the frontmatter tools, and the text-editing
 * tools `replaceSelection`, `insertAtCursor` and `applyDiff`. Their outputs are
 * fed back to the model, which is re-prompted automatically while the last
 * assistant message has complete tool calls, up to `MAX_TOOL_ROUNDS`. Editor
 * edits made by the agent are grouped into undo batches.
 *
 * @returns chat messages and status, plus `sendMessage`, `clearMessages`,
 *          `setMessages`, `stop`, and a controlled `input` / `setInput` pair.
 */
export function useAgentChat({
  editor,
  undoManager,
  frontmatterStore,
  modelRef,
  initialMessages,
  onMessagesChange,
}: UseAgentChatOptions) {
  // Selection captured when the latest message was sent; consumed by the
  // `replaceSelection` tool.
  const pendingSelectionRef = useRef<{ from: number; to: number } | null>(null);

  // Undo-batching helper for the current UndoManager. Built in an effect so
  // it is created once per UndoManager instead of on every render (a
  // `useRef(createAgentBatch(...))` argument is evaluated each render).
  const agentBatchRef = useRef<ReturnType<typeof createAgentBatch> | null>(null);
  useEffect(() => {
    const batch = createAgentBatch(undoManager);
    agentBatchRef.current = batch;
    return () => {
      // Close any capture group still open on the outgoing UndoManager so it
      // does not keep merging later edits into the agent's undo step.
      batch.endAgentBatch();
      if (agentBatchRef.current === batch) agentBatchRef.current = null;
    };
  }, [undoManager]);

  const [input, setInput] = useState("");

  // Always call the latest persistence callback without re-running effects.
  const onMessagesChangeRef = useRef(onMessagesChange);
  onMessagesChangeRef.current = onMessagesChange;

  /**
   * Snapshot the editor state sent with each user message: the plain-text
   * document, the selected text (if any) and the frontmatter. Also records
   * the selection range for `replaceSelection`, or clears it when there is
   * no selection so a range from an earlier turn is never reused.
   */
  const getEditorContext = useCallback(() => {
    if (!editor) {
      pendingSelectionRef.current = null;
      return {};
    }
    const { from, to } = editor.state.selection;
    const hasSelection = from !== to;
    pendingSelectionRef.current = hasSelection ? { from, to } : null;
    return {
      document: editor.getText({ blockSeparator: "\n" }),
      selection: hasSelection ? editor.state.doc.textBetween(from, to, "\n") : undefined,
      frontmatter: frontmatterStore?.getAll(),
    };
  }, [editor, frontmatterStore]);

  /**
   * Start a new undo capture group. All subsequent edits will be
   * merged into one undo step until endAgentBatch() is called.
   */
  const startAgentBatch = useCallback(() => {
    agentBatchRef.current?.startAgentBatch();
  }, []);

  /** Close the current undo capture group. */
  const endAgentBatch = useCallback(() => {
    agentBatchRef.current?.endAgentBatch();
  }, []);

  /**
   * Execute one client-side tool call and return the text reported back to
   * the model. Editor-modifying tools open an undo batch first.
   */
  const executeToolCall = useCallback(
    (toolCall: { toolName: string; args: unknown; toolCallId: string }) => {
      // --- Tiptap catalog commands (select, toggleBold, toggleHeading, ...) ---
      if (TIPTAP_TOOL_NAMES.has(toolCall.toolName)) {
        if (!editor) return "Editor not available";
        startAgentBatch();
        const result = executeTiptapCommand(
          editor,
          toolCall.toolName,
          (toolCall.args as Record<string, unknown>) ?? {},
        );
        return result ?? `Unknown editor command: ${toolCall.toolName}`;
      }

      switch (toolCall.toolName) {
        // --- Frontmatter tools (no editor needed) ---
        case "updateFrontmatter": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const fields = (toolCall.args as Record<string, unknown> | null) ?? {};
          // Only forward fields the model actually set.
          const updates = Object.fromEntries(
            Object.entries(fields).filter(([, value]) => value !== undefined),
          );
          // eslint-disable-next-line @typescript-eslint/no-explicit-any -- model-provided fields are passed through unvalidated
          frontmatterStore.setAll(updates as any);
          return `Frontmatter updated: ${Object.keys(updates).join(", ")}`;
        }
        case "addAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { name, url, affiliations, newAffiliationName, newAffiliationUrl } =
            toolCall.args as {
              name: string;
              url?: string;
              affiliations?: number[];
              newAffiliationName?: string;
              newAffiliationUrl?: string;
            };
          // Copy: `affiliations` belongs to the tool-call input object, which
          // must not be mutated (it is presumably also kept in the message
          // history by the chat SDK).
          const affIndices = [...(affiliations ?? [])];
          if (newAffiliationName) {
            const newIdx = frontmatterStore.addAffiliation({
              name: newAffiliationName,
              url: newAffiliationUrl,
            });
            affIndices.push(newIdx);
          }
          frontmatterStore.addAuthor({ name, url, affiliations: affIndices });
          return `Author "${name}" added`;
        }
        case "removeAuthor": {
          if (!frontmatterStore) return "Frontmatter store not available";
          const { index } = toolCall.args as { index: number };
          const authors = frontmatterStore.get("authors");
          // Non-integer indices (e.g. 1.5) would read `undefined` below.
          if (!Number.isInteger(index) || index < 0 || index >= authors.length) {
            return `Invalid author index ${index} (${authors.length} authors)`;
          }
          const removed = authors[index].name;
          frontmatterStore.removeAuthor(index);
          return `Author "${removed}" removed`;
        }

        // --- Editor tools ---
        case "replaceSelection": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { newText } = toolCall.args as { newText: string };
          const sel = pendingSelectionRef.current;
          if (sel) {
            // Restore the selection captured at send time: focus may have
            // moved to the chat input since then.
            editor.chain().focus().setTextSelection(sel).insertContent(newText).run();
            pendingSelectionRef.current = null;
            return "Selection replaced successfully";
          }
          const { from, to } = editor.state.selection;
          if (from !== to) {
            editor.chain().focus().insertContent(newText).run();
            return "Selection replaced successfully";
          }
          return "No text selected to replace";
        }
        case "insertAtCursor": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const { text } = toolCall.args as { text: string };
          editor.chain().focus().insertContent(text).run();
          return "Text inserted successfully";
        }
        case "applyDiff": {
          if (!editor) return "Editor not available";
          startAgentBatch();
          const {
            contextBefore,
            contentToDelete = "",
            contentToInsert,
            contextAfter,
          } = (toolCall.args ?? {}) as {
            contextBefore: string;
            contentToDelete: string;
            contentToInsert: string;
            contextAfter: string;
          };
          const pos = findInDocText(
            buildDocTextIndex(editor),
            contentToDelete,
            contextBefore,
            contextAfter,
          );
          if (!pos) {
            const preview =
              contentToDelete.length > NOT_FOUND_PREVIEW_LENGTH
                ? `${contentToDelete.slice(0, NOT_FOUND_PREVIEW_LENGTH)}...`
                : contentToDelete;
            return `Text "${preview}" not found in document`;
          }
          editor
            .chain()
            .focus()
            .setTextSelection(pos)
            .insertContent(contentToInsert || "")
            .run();
          return "Diff applied successfully";
        }
        default:
          return `Unknown tool: ${toolCall.toolName}`;
      }
    },
    [editor, frontmatterStore, startAgentBatch],
  );

  // NOTE(review): depending on the @ai-sdk/react version, useChat may keep
  // the callbacks passed on its first render. Route tool execution through a
  // ref so onToolCall always sees the current editor / frontmatter store.
  const executeToolCallRef = useRef(executeToolCall);
  executeToolCallRef.current = executeToolCall;

  const { addToolOutput, ...chat } = useChat({
    transport,
    messages: initialMessages,
    sendAutomaticallyWhen: ({ messages }) =>
      countTrailingToolRounds(messages) < MAX_TOOL_ROUNDS &&
      lastAssistantMessageIsCompleteWithToolCalls({ messages }),
    async onToolCall({ toolCall }) {
      if (toolCall.dynamic) return;
      let output: unknown;
      try {
        output = executeToolCallRef.current({
          toolName: toolCall.toolName as string,
          args: toolCall.input,
          toolCallId: toolCall.toolCallId,
        });
      } catch (error) {
        // Report the failure to the model instead of leaving the tool call
        // without an output, which would stall the automatic continuation.
        console.error(`[agent] tool "${toolCall.toolName as string}" failed:`, error);
        const message = error instanceof Error ? error.message : String(error);
        output = `Error executing ${toolCall.toolName as string}: ${message}`;
      }
      (addToolOutput as (args: { tool: string; toolCallId: string; output: unknown }) => void)({
        tool: toolCall.toolName as string,
        toolCallId: toolCall.toolCallId,
        output,
      });
    },
    onFinish() {
      endAgentBatch();
    },
    onError(error) {
      console.error("[agent] chat error:", error);
      endAgentBatch();
    },
  });

  // Persist when a conversation round completes (status back to "ready").
  const prevStatusRef = useRef(chat.status);
  useEffect(() => {
    const wasActive =
      prevStatusRef.current === "streaming" || prevStatusRef.current === "submitted";
    prevStatusRef.current = chat.status;
    if (wasActive && chat.status === "ready" && chat.messages.length > 0) {
      onMessagesChangeRef.current?.(chat.messages as UIMessage[]);
    }
  }, [chat.status, chat.messages]);

  // Depend on the individual methods rather than on `chat`, a fresh rest
  // object every render, so the callbacks below keep a stable identity.
  // (The original already called them detached from the Chat instance.)
  const { sendMessage: sendChatMessage, setMessages: setChatMessages } = chat;

  /** Send a user message together with the current editor context and model. */
  const sendMessage = useCallback(
    (content: string) => {
      // Close any batch left open by a previous turn so this turn's agent
      // edits start a fresh undo step.
      endAgentBatch();
      const context = getEditorContext();
      const model = modelRef.current;
      void sendChatMessage({ text: content }, { body: { context, model } });
    },
    [sendChatMessage, getEditorContext, endAgentBatch, modelRef],
  );

  const isLoading = chat.status === "streaming" || chat.status === "submitted";

  /** Drop the whole conversation and notify the persistence callback. */
  const clearMessages = useCallback(() => {
    setChatMessages([]);
    onMessagesChangeRef.current?.([]);
  }, [setChatMessages]);

  /** Replace the conversation without notifying the persistence callback. */
  const setMessages = useCallback(
    (msgs: UIMessage[]) => {
      setChatMessages(msgs);
    },
    [setChatMessages],
  );

  return {
    messages: chat.messages as UIMessage[],
    isLoading,
    error: chat.error,
    sendMessage,
    clearMessages,
    setMessages,
    input,
    setInput,
    stop: chat.stop,
  };
}