import type { ChatCompletionChunk } from "@mlc-ai/web-llm";
import {
getLlmTextSearchResults,
getQuery,
getSearchPromise,
getSettings,
getTextGenerationState,
updateTextGenerationState,
} from "./pubSub";
import { getSystemPrompt } from "./systemPrompt";
import type { ChatMessage } from "./types";
/**
 * Default context size (in tokens) used for text generation when the user
 * has not configured an explicit context length in settings.
 */
export const defaultContextSize = 4096;
/**
 * Error thrown when chat generation cannot proceed (e.g. the user
 * interrupted an in-flight generation).
 */
export class ChatGenerationError extends Error {
  // Class-field initializer runs after super(), overriding the default
  // "Error" name so instances are identifiable in logs and catch blocks.
  override name = "ChatGenerationError";
}
/**
 * Formats search results for inclusion in chat prompts as a bulleted list,
 * one result per line.
 *
 * @param shouldIncludeUrl - Whether to render each title as a markdown link
 * to its URL
 * @returns Formatted search results, or "None." when there are no results
 */
export function getFormattedSearchResults(shouldIncludeUrl: boolean) {
  const results = getLlmTextSearchResults();

  if (results.length === 0) return "None.";

  return results
    .map(([title, snippet, url]) =>
      shouldIncludeUrl
        ? `• [${title}](${url}) | ${snippet}`
        : `• ${title} | ${snippet}`,
    )
    .join("\n");
}
/**
 * Blocks until search results are available when the current settings
 * require them; resolves immediately otherwise.
 */
export async function canStartResponding() {
  // No search results are consulted, so generation can begin right away.
  if (getSettings().searchResultsToConsider <= 0) return;

  updateTextGenerationState("awaitingSearchResults");
  await getSearchPromise();
}
/**
* Gets default parameters for streaming chat completion requests
* @returns Default chat completion parameters
*/
export function getDefaultChatCompletionCreateParamsStreaming() {
const settings = getSettings();
return {
stream: true,
max_tokens: settings.openAiContextLength ?? defaultContextSize,
temperature: settings.inferenceTemperature,
top_p: settings.inferenceTopP,
min_p: settings.minP,
frequency_penalty: settings.inferenceFrequencyPenalty,
presence_penalty: settings.inferencePresencePenalty,
} as const;
}
/**
 * Consumes a streaming chat completion, accumulating generated text and
 * reporting progress after each content delta.
 *
 * @param completion - Async iterable of completion chunks to consume
 * @param onChunk - Callback invoked with the full accumulated message so far
 * after each non-empty content delta
 * @param options - Optional abort controller (aborted on interruption) and
 * whether to keep the generation state set to "generating" while streaming
 * @returns The complete streamed message
 * @throws ChatGenerationError when the text generation state becomes
 * "interrupted" mid-stream
 */
export async function handleStreamingResponse(
  completion: AsyncIterable<ChatCompletionChunk>,
  onChunk: (streamedMessage: string) => void,
  options?: {
    abortController?: { abort: () => void };
    shouldUpdateGeneratingState?: boolean;
  },
) {
  let streamedMessage = "";
  for await (const chunk of completion) {
    // Some OpenAI-compatible servers emit chunks whose `choices` array is
    // empty (e.g. trailing usage-only chunks); guard the index so we don't
    // read `delta` off `undefined`.
    const deltaContent = chunk.choices[0]?.delta.content;
    if (deltaContent) {
      streamedMessage += deltaContent;
      onChunk(streamedMessage);
    }
    if (getTextGenerationState() === "interrupted") {
      options?.abortController?.abort();
      throw new ChatGenerationError("Chat generation interrupted");
    }
    if (
      options?.shouldUpdateGeneratingState &&
      getTextGenerationState() !== "generating"
    ) {
      updateTextGenerationState("generating");
    }
  }
  return streamedMessage;
}
/**
 * Builds the default three-message conversation seed: the system prompt
 * (sent as a user turn), a canned assistant acknowledgment, and the
 * current user query.
 *
 * @param searchResults - Formatted search results injected into the
 * system prompt
 * @returns The initial chat messages for a generation request
 */
export function getDefaultChatMessages(searchResults: string): ChatMessage[] {
  const promptMessage: ChatMessage = {
    role: "user",
    content: getSystemPrompt(searchResults),
  };
  const acknowledgment: ChatMessage = { role: "assistant", content: "Ok!" };
  const queryMessage: ChatMessage = { role: "user", content: getQuery() };

  return [promptMessage, acknowledgment, queryMessage];
}