diff --git a/webapp/src/lib/openai-fetch-mock/assistant-response-builder.ts b/webapp/src/lib/openai-fetch-mock/assistant-response-builder.ts index 4c8f402..ee1a30b 100644 --- a/webapp/src/lib/openai-fetch-mock/assistant-response-builder.ts +++ b/webapp/src/lib/openai-fetch-mock/assistant-response-builder.ts @@ -18,7 +18,7 @@ export abstract class AssistantResponseBuilder { return {}; } - public getResponse(requestBody: any): Response { + public getResponse(requestBody: any, _config: RequestInit): Response { const jsonBody = this.getResponseJson(requestBody); return { ok: true, diff --git a/webapp/src/lib/openai-fetch-mock/create-run-mock.ts b/webapp/src/lib/openai-fetch-mock/create-run-mock.ts index ffd3b8f..68c5717 100644 --- a/webapp/src/lib/openai-fetch-mock/create-run-mock.ts +++ b/webapp/src/lib/openai-fetch-mock/create-run-mock.ts @@ -132,7 +132,7 @@ export class CreateRunMock extends AssistantResponseBuilder { } - public getResponse(): Response { + public getResponse(_requestBody: never, config: RequestInit): Response { const chunks = this.chunks ?? 
[]; const finishGenerationPromise = this.finishGenerationPromise; async function* generateChunks() { @@ -153,10 +153,24 @@ export class CreateRunMock extends AssistantResponseBuilder { headers: new Map([['content-type', 'text/event-stream']]), body: new ReadableStream({ async start(controller) { - for await (const chunk of chunkGenerator) { - controller.enqueue(chunk); + if (config && config.signal) { + if (config.signal.aborted) { + controller.error(new DOMException('Aborted', 'AbortError')); + return; + } + config.signal.addEventListener('abort', () => { + controller.error(new DOMException('Aborted', 'AbortError')); + }); + } + + try { + for await (const chunk of chunkGenerator) { + controller.enqueue(chunk); + } + controller.close(); + } catch (error) { + controller.error(error); } - controller.close(); }, }), } as unknown as Response; diff --git a/webapp/src/lib/openai-fetch-mock/openai-fetch-mock.ts b/webapp/src/lib/openai-fetch-mock/openai-fetch-mock.ts index 75988f2..961b48e 100644 --- a/webapp/src/lib/openai-fetch-mock/openai-fetch-mock.ts +++ b/webapp/src/lib/openai-fetch-mock/openai-fetch-mock.ts @@ -11,7 +11,7 @@ export function buildOpenAiApiFetchMock(builders: AssistantResponseBuilder[], de for (const builder of builders) { if (builder.doesMatch(url, config)) { await new Promise(resolve => setTimeout(resolve, delay)); - return builder.getResponse(requestBody); + return builder.getResponse(requestBody, config); } } }; diff --git a/webapp/src/lib/use-delete-thread.ts b/webapp/src/lib/use-delete-thread.ts index 46139c8..520479b 100644 --- a/webapp/src/lib/use-delete-thread.ts +++ b/webapp/src/lib/use-delete-thread.ts @@ -1,15 +1,18 @@ import { useCallback } from 'react'; import { useOpenaiClient } from './openai-client'; import { useListThreads } from './use-list-threads'; +import { useLocation } from 'wouter'; export function useDeleteThread(): (id: string) => Promise { const openai = useOpenaiClient(); + const [_, setLocation] = useLocation(); const { 
revalidate } = useListThreads(); return useCallback(async (id) => { await openai.beta.threads.del(id); revalidate(); + setLocation('/'); - }, [openai.beta.threads, revalidate]); + }, [openai.beta.threads, revalidate, setLocation]); } \ No newline at end of file diff --git a/webapp/src/lib/use-openai-assistant.ts b/webapp/src/lib/use-openai-assistant.ts index f95ea64..2c22bb0 100644 --- a/webapp/src/lib/use-openai-assistant.ts +++ b/webapp/src/lib/use-openai-assistant.ts @@ -20,11 +20,10 @@ interface Props { threadId: string; model?: string; temperature?: number; - initialInput?: string; } -export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai:gpt-3.5-turbo', temperature, initialInput }: Props) { +export function useOpenAiAssistant({ assistantId = '', threadId, model, temperature }: Props) { const [messages, setMessages] = useState([]); - const [input, setInput] = useState(initialInput ?? ''); + const [input, setInput] = useState(''); const { imageAttachments, removeImageAttachment, addImageAttachments, setImageAttachments } = useImageAttachments(); const [status, setStatus] = useState('awaiting_message'); const [error, setError] = useState(undefined); @@ -32,24 +31,12 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai const abortControlerRef = useRef(null); const openai = useOpenaiClient(); + const setUnknownError = useCallback((e: unknown) => { if (e instanceof Error) setError(e); else setError(new Error(`${e}`)); }, []); - useEffect(() => { - const fetchMessages = async () => { - try { - const newMessages = await openai.beta.threads.messages.list(threadId); - setMessages(newMessages.data); - } catch (e) { - setUnknownError(e); - } - }; - fetchMessages(); - - }, [openai.beta.threads.messages, threadId, setUnknownError]); - const handleInputChange = ( event: | React.ChangeEvent @@ -74,18 +61,14 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai temperature, }, { signal }) 
.on('messageCreated', (message: Message) => setMessages(messages => [...messages, message])) - .on('messageDelta', (_delta: MessageDelta, snapshot: Message) => setMessages(messages => { - return [ - ...messages.slice(0, messages.length - 1), - snapshot - ]; - })) - .on('messageDone', (message: Message) => { - return [ - ...messages.slice(0, messages.length - 1), - message - ]; - }) + .on('messageDelta', (_delta: MessageDelta, snapshot: Message) => setMessages(messages => [ + ...messages.slice(0, messages.length - 1), + snapshot + ])) + .on('messageDone', (message: Message) => setMessages(messages => [ + ...messages.slice(0, messages.length - 1), + message + ])) .on('error', (error) => rejects(error)) .on('abort', () => resolve()) .on('end', () => resolve()); @@ -93,11 +76,9 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai } catch (e) { setUnknownError(e); - setMessages(messages => { - return [ - ...messages.slice(0, messages.length - 1), - ]; - }); + setMessages(messages => [ + ...messages.slice(0, messages.length - 1), + ]); } finally { streamRef.current = null; @@ -119,6 +100,7 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai const append = useCallback(async ( message?: CreateMessage, ) => { + if (status === 'in_progress') throw new Error('Cannot append message while in progress'); try { if (message) { @@ -131,11 +113,12 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai created_message, ]); } + } catch (e) { setUnknownError(e); } - }, [openai.beta.threads.messages, threadId, setUnknownError]); + }, [openai.beta.threads.messages, threadId, setUnknownError, status]); const abort = useCallback(() => { if (abortControlerRef.current) { @@ -144,6 +127,24 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai } }, []); + const streamRunRef = useRef(streamRun); + + useEffect(() => { + streamRunRef.current = streamRun; + }, [streamRun]); + + useEffect(() => { + 
const fetchMessages = async () => { + try { + const newMessages = await openai.beta.threads.messages.list(threadId); + setMessages(newMessages.data); + } catch (e) { + setUnknownError(e); + } + }; + fetchMessages(); + + }, [openai.beta.threads.messages, threadId, setUnknownError]); const submitMessage = async ( event?: React.FormEvent, @@ -154,7 +155,7 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai return; } - append(createUserMessage({ input, imageAttachments })); + await append(createUserMessage({ input, imageAttachments })); setInput(''); setImageAttachments([]); }; diff --git a/webapp/src/lib/use-openai-assistant.ui.test.tsx b/webapp/src/lib/use-openai-assistant.ui.test.tsx index e389561..21e2670 100644 --- a/webapp/src/lib/use-openai-assistant.ui.test.tsx +++ b/webapp/src/lib/use-openai-assistant.ui.test.tsx @@ -10,6 +10,7 @@ import OpenAI from 'openai'; import { useOpenAiAssistant } from './use-openai-assistant'; import { buildOpenAiApiFetchMock, CreateMessageMock, CreateRunMock, CreateThreadMock, ErrorMock } from '@/lib/openai-fetch-mock'; import { useOpenaiClient } from './openai-client'; +import { useState } from 'react'; vi.mock('./openai-client'); @@ -23,7 +24,8 @@ describe('new-conversation', () => { dangerouslyAllowBrowser: true, })); - const { status, messages, error, append } = useOpenAiAssistant({ threadId }); + const { status, messages, error, append, abort } = useOpenAiAssistant({ threadId }); + const [appendError, setAppendError] = useState(); return (
@@ -38,10 +40,20 @@ describe('new-conversation', () => {
); }; @@ -200,6 +212,26 @@ describe('new-conversation', () => { expect(screen.getByTestId('message-1')).toHaveTextContent('Hello human'); }); }); + + it('should not submit when in progress', async () => { + await userEvent.click(screen.getByTestId('do-append')); + await waitFor(async () => { + expect(screen.getByTestId('status')).toHaveTextContent('in_progress'); + }); + + await userEvent.click(screen.getByTestId('do-append')); + + + await waitFor(async () => { + expect(screen.getByTestId('message-0')).toHaveTextContent('Hello AI'); + expect(screen.getByTestId('message-1')).toHaveTextContent('Hello human'); + expect(screen.queryByTestId('message-2')).not.toBeInTheDocument(); + }); + + await waitFor(async () => { + expect(screen.queryByTestId('append-error')).toBeInTheDocument(); + }); + }); }); describe('error', () => { @@ -227,12 +259,13 @@ describe('new-conversation', () => { beforeEach(() => { fetch.mockImplementation(buildOpenAiApiFetchMock([ new CreateMessageMock(), + new CreateThreadMock(), new CreateRunMock({ finishGenerationPromise: new Promise(resolve => { finishGeneration = resolve; }) }), - new CreateThreadMock(), + ])); render(); @@ -251,6 +284,31 @@ describe('new-conversation', () => { }); }); }); + + describe('abort', () => { + beforeEach(() => { + fetch.mockImplementation(buildOpenAiApiFetchMock([ + new CreateMessageMock(), + new CreateThreadMock(), + new CreateRunMock({ + finishGenerationPromise: new Promise(() => { + // Never resolve, will be resolved by abort controller + }) + }), + ])); + + render(); + }); + + it('should stop generation', async () => { + await userEvent.click(screen.getByTestId('do-append')); + await waitFor(async () => expect(screen.getByTestId('status')).toHaveTextContent('in_progress')); + + await userEvent.click(screen.getByTestId('abort')); + + await waitFor(async () => expect(screen.getByTestId('status')).toHaveTextContent('awaiting_message'), { timeout: 1000 }); + }); + }); }); describe('existing-thread', () => {