diff --git a/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py b/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py index 0c3a24d..2729ff5 100644 --- a/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py +++ b/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py @@ -49,14 +49,27 @@ def list( order: Literal["asc", "desc"] = None, ) -> SyncCursorPage[Thread]: query = self.db.query(ThreadModel) + if after: - query = query.filter(ThreadModel.id > after) + after_created_at = ( + self.db.query(ThreadModel.created_at) + .filter(ThreadModel.id == after) + .scalar() + ) + query = query.filter(ThreadModel.created_at > after_created_at) if before: - query = query.filter(ThreadModel.id < before) - if order == "desc": - query = query.order_by(ThreadModel.created_at.desc()) - else: + before_created_at = ( + self.db.query(ThreadModel.created_at) + .filter(ThreadModel.id == before) + .scalar() + ) + query = query.filter(ThreadModel.created_at < before_created_at) + + if order == "asc": query = query.order_by(ThreadModel.created_at.asc()) + else: + query = query.order_by(ThreadModel.created_at.desc()) + if limit is not None: query = query.limit(limit) @@ -65,9 +78,6 @@ def list( return SyncCursorPage( data=threads, - order=order, - next_after=models[-1].created_at if models else None, - next_before=models[0].created_at if models else None, ) def retreive(self, thread_id: str) -> Thread: diff --git a/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py b/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py index 1b0eaa0..8fd8d2a 100644 --- a/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py +++ b/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import create_engine +from sqlalchemy import StaticPool, create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.session import Session from ai_assistant_core.assistant.infrastructure.sqlalchemy_thread_repository import ( @@ -8,22 +8,60 @@ from ai_assistant_core.infrastructure.sqlalchemy import Base -@pytest.fixture(scope="session") +@pytest.fixture def session() -> Session: - engine = create_engine("sqlite:///:memory:", echo=False) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + database_url = "sqlite:///:memory:" + engine = create_engine( + database_url, + poolclass=StaticPool, + ) + SessionLocal = sessionmaker(autocommit=False, bind=engine) db = SessionLocal() Base.metadata.create_all(bind=engine) return db -@pytest.fixture(scope="session") +@pytest.fixture def instance(session) -> SqlalchemyThreadRepository: return SqlalchemyThreadRepository(db=session) class TestList: + def test_list_after(self, instance: SqlalchemyThreadRepository, session: Session): + thread1 = instance.create(created_at=10000) + thread2 = instance.create(created_at=20000) + thread3 = instance.create(created_at=30000) + + result = instance.list(after=thread1.id).data + + assert len(result) == 2 + assert result[0].id == thread3.id + assert result[1].id == thread2.id + + def test_list_after_asc( + self, instance: SqlalchemyThreadRepository, session: Session + ): + thread1 = instance.create(created_at=10000) + thread2 = instance.create(created_at=20000) + thread3 = instance.create(created_at=30000) + + result = instance.list(after=thread1.id, order="asc").data + + assert len(result) == 2 + assert result[0].id == thread2.id + assert result[1].id == thread3.id + + def test_list_before(self, instance: SqlalchemyThreadRepository): + thread1 = instance.create(created_at=10000) + thread2 = instance.create(created_at=20000) + instance.create(created_at=30000) + + result = instance.list(before=thread2.id) + + assert len(result.data) == 1 + assert result.data[0].id == thread1.id + def test_list(self, instance: SqlalchemyThreadRepository): thread1 = instance.create() thread2 = instance.create() @@ -78,24 +116,3 @@ def test_list_order_desc_with_limit(self, instance: SqlalchemyThreadRepository): assert result[0].id == thread3.id assert result[1].id == thread2.id - - def test_list_after(self, instance: SqlalchemyThreadRepository): - thread1 = instance.create(created_at=10000) - thread2 = instance.create(created_at=20000) - thread3 = instance.create(created_at=30000) - - result = instance.list(after=thread1.id).data - - assert len(result) == 2 - assert result[0].id == thread2.id - assert result[1].id == thread3.id - - def test_list_before(self, instance: SqlalchemyThreadRepository): - thread1 = instance.create(created_at=10000) - thread2 = instance.create(created_at=20000) - instance.create(created_at=30000) - - result = instance.list(before=thread2.id).data - - assert len(result) == 1 - assert result[0].id == thread1.id diff --git a/webapp/components/message-content.tsx b/webapp/components/message-content.tsx index 4eb2473..3953fa3 100644 --- a/webapp/components/message-content.tsx +++ b/webapp/components/message-content.tsx @@ -1,6 +1,7 @@ import { Markdown } from './markdown'; import { ChatMessageContentDto } from './chat.type'; import { Card } from './ui/card'; +import { useMemo } from 'react'; export interface ChatMessageProps { @@ -9,10 +10,12 @@ export interface ChatMessageProps { export function MessageContent({ content }: ChatMessageProps) { const textContent = content?.filter((contentItem) => contentItem.type === 'text')?.map((contentItem) => contentItem.text.value).join(''); + const markdown = useMemo(() => {textContent}, [textContent]); const images = content?.filter((contentItem) => contentItem.type === 'image_file' || contentItem.type === 'image_url'); + return ( <> - {textContent} + { markdown } { images?.map((image, index) => {JSON.stringify(image)})} );