Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bug where generation stop is not working #44

Merged
merged 4 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export abstract class AssistantResponseBuilder {
return {};
}

public getResponse(requestBody: any): Response {
public getResponse(requestBody: any, _config: RequestInit): Response {
const jsonBody = this.getResponseJson(requestBody);
return {
ok: true,
Expand Down
22 changes: 18 additions & 4 deletions webapp/src/lib/openai-fetch-mock/create-run-mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ export class CreateRunMock extends AssistantResponseBuilder {
}


public getResponse(): Response {
public getResponse(_requestBody: never, config: RequestInit): Response {
const chunks = this.chunks ?? [];
const finishGenerationPromise = this.finishGenerationPromise;
async function* generateChunks() {
Expand All @@ -153,10 +153,24 @@ export class CreateRunMock extends AssistantResponseBuilder {
headers: new Map([['content-type', 'text/event-stream']]),
body: new ReadableStream({
async start(controller) {
for await (const chunk of chunkGenerator) {
controller.enqueue(chunk);
if (config && config.signal) {
if (config.signal.aborted) {
controller.error(new DOMException('Aborted', 'AbortError'));
return;
}
config.signal.addEventListener('abort', () => {
controller.error(new DOMException('Aborted', 'AbortError'));
});
}

try {
for await (const chunk of chunkGenerator) {
controller.enqueue(chunk);
}
controller.close();
} catch (error) {
controller.error(error);
}
controller.close();
},
}),
} as unknown as Response;
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/lib/openai-fetch-mock/openai-fetch-mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export function buildOpenAiApiFetchMock(builders: AssistantResponseBuilder[], de
for (const builder of builders) {
if (builder.doesMatch(url, config)) {
await new Promise(resolve => setTimeout(resolve, delay));
return builder.getResponse(requestBody);
return builder.getResponse(requestBody, config);
}
}
};
Expand Down
5 changes: 4 additions & 1 deletion webapp/src/lib/use-delete-thread.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import { useCallback } from 'react';
import { useOpenaiClient } from './openai-client';
import { useListThreads } from './use-list-threads';
import { useLocation } from 'wouter';


export function useDeleteThread(): (id: string) => Promise<void> {
const openai = useOpenaiClient();
const [_, setLocation] = useLocation();
const { revalidate } = useListThreads();

return useCallback(async (id) => {
await openai.beta.threads.del(id);
revalidate();
setLocation('/');

}, [openai.beta.threads, revalidate]);
}, [openai.beta.threads, revalidate, setLocation]);
}
71 changes: 36 additions & 35 deletions webapp/src/lib/use-openai-assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,23 @@ interface Props {
threadId: string;
model?: string;
temperature?: number;
initialInput?: string;
}
export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai:gpt-3.5-turbo', temperature, initialInput }: Props) {
export function useOpenAiAssistant({ assistantId = '', threadId, model, temperature }: Props) {
const [messages, setMessages] = useState<Message[]>([]);
const [input, setInput] = useState(initialInput ?? '');
const [input, setInput] = useState('');
const { imageAttachments, removeImageAttachment, addImageAttachments, setImageAttachments } = useImageAttachments();
const [status, setStatus] = useState<AssistantStatus>('awaiting_message');
const [error, setError] = useState<undefined | Error>(undefined);
const streamRef = useRef<AssistantStream | null>(null);
const abortControlerRef = useRef<AbortController | null>(null);
const openai = useOpenaiClient();


const setUnknownError = useCallback((e: unknown) => {
if (e instanceof Error) setError(e);
else setError(new Error(`${e}`));
}, []);

useEffect(() => {
const fetchMessages = async () => {
try {
const newMessages = await openai.beta.threads.messages.list(threadId);
setMessages(newMessages.data);
} catch (e) {
setUnknownError(e);
}
};
fetchMessages();

}, [openai.beta.threads.messages, threadId, setUnknownError]);

const handleInputChange = (
event:
| React.ChangeEvent<HTMLInputElement>
Expand All @@ -74,30 +61,24 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai
temperature,
}, { signal })
.on('messageCreated', (message: Message) => setMessages(messages => [...messages, message]))
.on('messageDelta', (_delta: MessageDelta, snapshot: Message) => setMessages(messages => {
return [
...messages.slice(0, messages.length - 1),
snapshot
];
}))
.on('messageDone', (message: Message) => {
return [
...messages.slice(0, messages.length - 1),
message
];
})
.on('messageDelta', (_delta: MessageDelta, snapshot: Message) => setMessages(messages => [
...messages.slice(0, messages.length - 1),
snapshot
]))
.on('messageDone', (message: Message) => [
...messages.slice(0, messages.length - 1),
message
])
.on('error', (error) => rejects(error))
.on('abort', () => resolve())
.on('end', () => resolve());
});

} catch (e) {
setUnknownError(e);
setMessages(messages => {
return [
...messages.slice(0, messages.length - 1),
];
});
setMessages(messages => [
...messages.slice(0, messages.length - 1),
]);
}
finally {
streamRef.current = null;
Expand All @@ -119,6 +100,7 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai
const append = useCallback(async (
message?: CreateMessage,
) => {
if (status === 'in_progress') throw new Error('Cannot append message while in progress');

try {
if (message) {
Expand All @@ -131,11 +113,12 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai
created_message,
]);
}

} catch (e) {
setUnknownError(e);
}

}, [openai.beta.threads.messages, threadId, setUnknownError]);
}, [openai.beta.threads.messages, threadId, setUnknownError, status]);

const abort = useCallback(() => {
if (abortControlerRef.current) {
Expand All @@ -144,6 +127,24 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai
}
}, []);

const streamRunRef = useRef(streamRun);

useEffect(() => {
streamRunRef.current = streamRun;
}, [streamRun]);

useEffect(() => {
const fetchMessages = async () => {
try {
const newMessages = await openai.beta.threads.messages.list(threadId);
setMessages(newMessages.data);
} catch (e) {
setUnknownError(e);
}
};
fetchMessages();

}, [openai.beta.threads.messages, threadId, setUnknownError]);

const submitMessage = async (
event?: React.FormEvent<HTMLFormElement>,
Expand All @@ -154,7 +155,7 @@ export function useOpenAiAssistant({ assistantId = '', threadId, model = 'openai
return;
}

append(createUserMessage({ input, imageAttachments }));
await append(createUserMessage({ input, imageAttachments }));
setInput('');
setImageAttachments([]);
};
Expand Down
66 changes: 62 additions & 4 deletions webapp/src/lib/use-openai-assistant.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import OpenAI from 'openai';
import { useOpenAiAssistant } from './use-openai-assistant';
import { buildOpenAiApiFetchMock, CreateMessageMock, CreateRunMock, CreateThreadMock, ErrorMock } from '@/lib/openai-fetch-mock';
import { useOpenaiClient } from './openai-client';
import { useState } from 'react';


vi.mock('./openai-client');
Expand All @@ -23,7 +24,8 @@ describe('new-conversation', () => {
dangerouslyAllowBrowser: true,
}));

const { status, messages, error, append } = useOpenAiAssistant({ threadId });
const { status, messages, error, append, abort } = useOpenAiAssistant({ threadId });
const [appendError, setAppendError] = useState<Error>();

return (
<div>
Expand All @@ -38,10 +40,20 @@ describe('new-conversation', () => {

<button
data-testid="do-append"
onClick={() => {
append({ role: 'user', content: 'Hello AI' });
onClick={async () => {
try {
await append({ role: 'user', content: 'Hello AI' });
} catch (e) {
setAppendError(e as Error);
}
}}
/>
<button
data-testid="abort"
onClick={() => abort()}
/>

{ appendError && <div data-testid="append-error">{appendError.toString()}</div>}
</div>
);
};
Expand Down Expand Up @@ -200,6 +212,26 @@ describe('new-conversation', () => {
expect(screen.getByTestId('message-1')).toHaveTextContent('Hello human');
});
});

it('should not submit when in progress', async () => {
await userEvent.click(screen.getByTestId('do-append'));
await waitFor(async () => {
expect(screen.getByTestId('status')).toHaveTextContent('in_progress');
});

await userEvent.click(screen.getByTestId('do-append'));


await waitFor(async () => {
expect(screen.getByTestId('message-0')).toHaveTextContent('Hello AI');
expect(screen.getByTestId('message-1')).toHaveTextContent('Hello human');
expect(screen.queryByTestId('message-2')).not.toBeInTheDocument();
});

await waitFor(async () => {
expect(screen.queryByTestId('append-error')).toBeInTheDocument();
});
});
});

describe('error', () => {
Expand Down Expand Up @@ -227,12 +259,13 @@ describe('new-conversation', () => {
beforeEach(() => {
fetch.mockImplementation(buildOpenAiApiFetchMock([
new CreateMessageMock(),
new CreateThreadMock(),
new CreateRunMock({
finishGenerationPromise: new Promise(resolve => {
finishGeneration = resolve;
})
}),
new CreateThreadMock(),

]));

render(<TestComponent />);
Expand All @@ -251,6 +284,31 @@ describe('new-conversation', () => {
});
});
});

describe('abort', () => {
beforeEach(() => {
fetch.mockImplementation(buildOpenAiApiFetchMock([
new CreateMessageMock(),
new CreateThreadMock(),
new CreateRunMock({
finishGenerationPromise: new Promise(() => {
// Never resolve, will be resolved by abort controller
})
}),
]));

render(<TestComponent />);
});

it('should stop generation', async () => {
await userEvent.click(screen.getByTestId('do-append'));
await waitFor(async () => expect(screen.getByTestId('status')).toHaveTextContent('in_progress'));

await userEvent.click(screen.getByTestId('abort'));

await waitFor(async () => expect(screen.getByTestId('status')).toHaveTextContent('awaiting_message'), { timeout: 1000 });
});
});
});

describe('existing-thread', () => {
Expand Down