Skip to content

Commit

Permalink
fix: thread limits (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelint authored Jul 9, 2024
1 parent cd086da commit 4756e59
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,27 @@ def list(
order: Literal["asc", "desc"] = None,
) -> SyncCursorPage[Thread]:
query = self.db.query(ThreadModel)

if after:
query = query.filter(ThreadModel.id > after)
after_created_at = (
self.db.query(ThreadModel.created_at)
.filter(ThreadModel.id == after)
.scalar()
)
query = query.filter(ThreadModel.created_at > after_created_at)
if before:
query = query.filter(ThreadModel.id < before)
if order == "desc":
query = query.order_by(ThreadModel.created_at.desc())
else:
before_created_at = (
self.db.query(ThreadModel.created_at)
.filter(ThreadModel.id == before)
.scalar()
)
query = query.filter(ThreadModel.created_at < before_created_at)

if order == "asc":
query = query.order_by(ThreadModel.created_at.asc())
else:
query = query.order_by(ThreadModel.created_at.desc())

if limit is not None:
query = query.limit(limit)

Expand All @@ -65,9 +78,6 @@ def list(

return SyncCursorPage(
data=threads,
order=order,
next_after=models[-1].created_at if models else None,
next_before=models[0].created_at if models else None,
)

def retreive(self, thread_id: str) -> Thread:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy import StaticPool, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from ai_assistant_core.assistant.infrastructure.sqlalchemy_thread_repository import (
Expand All @@ -8,22 +8,60 @@
from ai_assistant_core.infrastructure.sqlalchemy import Base


@pytest.fixture(scope="session")
@pytest.fixture
def session() -> Session:
engine = create_engine("sqlite:///:memory:", echo=False)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
database_url = "sqlite:///:memory:"
engine = create_engine(
database_url,
poolclass=StaticPool,
)
SessionLocal = sessionmaker(autocommit=False, bind=engine)
db = SessionLocal()
Base.metadata.create_all(bind=engine)

return db


@pytest.fixture(scope="session")
@pytest.fixture
def instance(session) -> SqlalchemyThreadRepository:
return SqlalchemyThreadRepository(db=session)


class TestList:
def test_list_after(self, instance: SqlalchemyThreadRepository, session: Session):
thread1 = instance.create(created_at=10000)
thread2 = instance.create(created_at=20000)
thread3 = instance.create(created_at=30000)

result = instance.list(after=thread1.id).data

assert len(result) == 2
assert result[0].id == thread3.id
assert result[1].id == thread2.id

def test_list_after_asc(
self, instance: SqlalchemyThreadRepository, session: Session
):
thread1 = instance.create(created_at=10000)
thread2 = instance.create(created_at=20000)
thread3 = instance.create(created_at=30000)

result = instance.list(after=thread1.id, order="asc").data

assert len(result) == 2
assert result[0].id == thread2.id
assert result[1].id == thread3.id

def test_list_before(self, instance: SqlalchemyThreadRepository):
thread1 = instance.create(created_at=10000)
thread2 = instance.create(created_at=20000)
instance.create(created_at=30000)

result = instance.list(before=thread2.id)

assert len(result.data) == 1
assert result.data[0].id == thread1.id

def test_list(self, instance: SqlalchemyThreadRepository):
thread1 = instance.create()
thread2 = instance.create()
Expand Down Expand Up @@ -78,24 +116,3 @@ def test_list_order_desc_with_limit(self, instance: SqlalchemyThreadRepository):

assert result[0].id == thread3.id
assert result[1].id == thread2.id

def test_list_after(self, instance: SqlalchemyThreadRepository):
thread1 = instance.create(created_at=10000)
thread2 = instance.create(created_at=20000)
thread3 = instance.create(created_at=30000)

result = instance.list(after=thread1.id).data

assert len(result) == 2
assert result[0].id == thread2.id
assert result[1].id == thread3.id

def test_list_before(self, instance: SqlalchemyThreadRepository):
thread1 = instance.create(created_at=10000)
thread2 = instance.create(created_at=20000)
instance.create(created_at=30000)

result = instance.list(before=thread2.id).data

assert len(result) == 1
assert result[0].id == thread1.id
5 changes: 4 additions & 1 deletion webapp/components/message-content.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Markdown } from './markdown';
import { ChatMessageContentDto } from './chat.type';
import { Card } from './ui/card';
import { useMemo } from 'react';


export interface ChatMessageProps {
Expand All @@ -9,10 +10,12 @@ export interface ChatMessageProps {

export function MessageContent({ content }: ChatMessageProps) {
const textContent = content?.filter((contentItem) => contentItem.type === 'text')?.map((contentItem) => contentItem.text.value).join('');
const markdown = useMemo(() => <Markdown>{textContent}</Markdown>, [textContent]);
const images = content?.filter((contentItem) => contentItem.type === 'image_file' || contentItem.type === 'image_url');

return (
<>
<Markdown>{textContent}</Markdown>
{ markdown }
{ images?.map((image, index) => <Card key={index}>{JSON.stringify(image)}</Card>)}
</>
);
Expand Down

0 comments on commit 4756e59

Please sign in to comment.