diff --git a/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py b/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py
index 0c3a24d..2729ff5 100644
--- a/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py
+++ b/core/ai_assistant_core/assistant/infrastructure/sqlalchemy_thread_repository.py
@@ -49,14 +49,27 @@ def list(
order: Literal["asc", "desc"] = None,
) -> SyncCursorPage[Thread]:
query = self.db.query(ThreadModel)
+
if after:
- query = query.filter(ThreadModel.id > after)
+ after_created_at = (
+ self.db.query(ThreadModel.created_at)
+ .filter(ThreadModel.id == after)
+ .scalar()
+ )
+ query = query.filter(ThreadModel.created_at > after_created_at)
if before:
- query = query.filter(ThreadModel.id < before)
- if order == "desc":
- query = query.order_by(ThreadModel.created_at.desc())
- else:
+ before_created_at = (
+ self.db.query(ThreadModel.created_at)
+ .filter(ThreadModel.id == before)
+ .scalar()
+ )
+ query = query.filter(ThreadModel.created_at < before_created_at)
+
+ if order == "asc":
query = query.order_by(ThreadModel.created_at.asc())
+ else:
+ query = query.order_by(ThreadModel.created_at.desc())
+
if limit is not None:
query = query.limit(limit)
@@ -65,9 +78,6 @@ def list(
return SyncCursorPage(
data=threads,
- order=order,
- next_after=models[-1].created_at if models else None,
- next_before=models[0].created_at if models else None,
)
def retreive(self, thread_id: str) -> Thread:
diff --git a/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py b/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py
index 1b0eaa0..8fd8d2a 100644
--- a/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py
+++ b/core/tests/assistant/infrastructure/test_sqlalchemy_thread_repository.py
@@ -1,5 +1,5 @@
import pytest
-from sqlalchemy import create_engine
+from sqlalchemy import StaticPool, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from ai_assistant_core.assistant.infrastructure.sqlalchemy_thread_repository import (
@@ -8,22 +8,60 @@
from ai_assistant_core.infrastructure.sqlalchemy import Base
-@pytest.fixture(scope="session")
+@pytest.fixture
def session() -> Session:
- engine = create_engine("sqlite:///:memory:", echo=False)
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+ database_url = "sqlite:///:memory:"
+ engine = create_engine(
+ database_url,
+ poolclass=StaticPool,
+ )
+ SessionLocal = sessionmaker(autocommit=False, bind=engine)
db = SessionLocal()
Base.metadata.create_all(bind=engine)
return db
-@pytest.fixture(scope="session")
+@pytest.fixture
def instance(session) -> SqlalchemyThreadRepository:
return SqlalchemyThreadRepository(db=session)
class TestList:
+ def test_list_after(self, instance: SqlalchemyThreadRepository, session: Session):
+ thread1 = instance.create(created_at=10000)
+ thread2 = instance.create(created_at=20000)
+ thread3 = instance.create(created_at=30000)
+
+ result = instance.list(after=thread1.id).data
+
+ assert len(result) == 2
+ assert result[0].id == thread3.id
+ assert result[1].id == thread2.id
+
+ def test_list_after_asc(
+ self, instance: SqlalchemyThreadRepository, session: Session
+ ):
+ thread1 = instance.create(created_at=10000)
+ thread2 = instance.create(created_at=20000)
+ thread3 = instance.create(created_at=30000)
+
+ result = instance.list(after=thread1.id, order="asc").data
+
+ assert len(result) == 2
+ assert result[0].id == thread2.id
+ assert result[1].id == thread3.id
+
+ def test_list_before(self, instance: SqlalchemyThreadRepository):
+ thread1 = instance.create(created_at=10000)
+ thread2 = instance.create(created_at=20000)
+ instance.create(created_at=30000)
+
+ result = instance.list(before=thread2.id)
+
+ assert len(result.data) == 1
+ assert result.data[0].id == thread1.id
+
def test_list(self, instance: SqlalchemyThreadRepository):
thread1 = instance.create()
thread2 = instance.create()
@@ -78,24 +116,3 @@ def test_list_order_desc_with_limit(self, instance: SqlalchemyThreadRepository):
assert result[0].id == thread3.id
assert result[1].id == thread2.id
-
- def test_list_after(self, instance: SqlalchemyThreadRepository):
- thread1 = instance.create(created_at=10000)
- thread2 = instance.create(created_at=20000)
- thread3 = instance.create(created_at=30000)
-
- result = instance.list(after=thread1.id).data
-
- assert len(result) == 2
- assert result[0].id == thread2.id
- assert result[1].id == thread3.id
-
- def test_list_before(self, instance: SqlalchemyThreadRepository):
- thread1 = instance.create(created_at=10000)
- thread2 = instance.create(created_at=20000)
- instance.create(created_at=30000)
-
- result = instance.list(before=thread2.id).data
-
- assert len(result) == 1
- assert result[0].id == thread1.id
diff --git a/webapp/components/message-content.tsx b/webapp/components/message-content.tsx
index 4eb2473..3953fa3 100644
--- a/webapp/components/message-content.tsx
+++ b/webapp/components/message-content.tsx
@@ -1,6 +1,7 @@
import { Markdown } from './markdown';
import { ChatMessageContentDto } from './chat.type';
import { Card } from './ui/card';
+import { useMemo } from 'react';
export interface ChatMessageProps {
@@ -9,10 +10,12 @@ export interface ChatMessageProps {
export function MessageContent({ content }: ChatMessageProps) {
const textContent = content?.filter((contentItem) => contentItem.type === 'text')?.map((contentItem) => contentItem.text.value).join('');
+ const markdown = useMemo(() => {textContent}, [textContent]);
const images = content?.filter((contentItem) => contentItem.type === 'image_file' || contentItem.type === 'image_url');
+
return (
<>
- {textContent}
+ { markdown }
{ images?.map((image, index) => {JSON.stringify(image)})}
>
);