Skip to content

Commit

Permalink
🦄 refactor(api): refactor chat api
Browse files Browse the repository at this point in the history
  • Loading branch information
centonhuang committed Apr 20, 2024
1 parent 88e8a95 commit 238a439
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/api/router/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.responses import StreamingResponse
from sqlalchemy import or_

from src.langchain_aris.callback import OUTPUT_PARSER_NAME
from src.langchain_aris.callback import DOCUMENT_STUFFER__NAME, OUTPUT_PARSER_NAME
from src.langchain_aris.chain import init_chat_chain, init_retriever_qa_chain
from src.logger import logger
from src.middleware.mysql import session
Expand Down Expand Up @@ -270,6 +270,11 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
conn.commit()
logger.debug(f"Bind LLM: {request.llm_name} to Session: {session_id}")

chain_kwargs = {
"llm_schema": _llm,
"temperature": request.temperature,
"session_id": session_id,
}
if request.vector_db_id:
with session() as conn:
query = (
Expand All @@ -296,20 +301,9 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
return StandardResponse(code=1, status="error", message="Embedding not exist")

chain_func = init_retriever_qa_chain
chain_kwargs = {
"llm_schema": _llm,
"embedding_schema": _embedding,
"temperature": request.temperature,
"session_id": session_id,
"vector_db_id": request.vector_db_id,
}
chain_kwargs.update({"embedding_schema": _embedding, "vector_db_id": request.vector_db_id})
else:
chain_func = init_chat_chain
chain_kwargs = {
"llm_schema": _llm,
"temperature": request.temperature,
"session_id": session_id,
}
try:
chain = chain_func(**chain_kwargs)
except Exception as e:
Expand All @@ -319,9 +313,12 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
r.delete(f"session:{session_id}")
r.delete(f"uid:{_uid}:sessions")

# async for event in chain.astream_events(request.message, version="v1", include_names=[OUTPUT_PARSER_NAME, DOCUMENT_STUFFER__NAME]):
# print(event)

async def _filter_event_stream() -> AsyncGenerator[str, None]:
async for event in chain.astream_events({"user_prompt": request.message}, version="v1", include_names=[OUTPUT_PARSER_NAME]):
if event["event"] not in ["on_parser_stream"]:
async for event in chain.astream_events(request.message, version="v1", include_names=[OUTPUT_PARSER_NAME, DOCUMENT_STUFFER__NAME]):
if event["event"] not in ["on_parser_stream", "on_chain_stream"]:
continue
yield f"data: {dumps(event, ensure_ascii=False)}\n\n"
r.delete(redis_lock)
Expand Down

0 comments on commit 238a439

Please sign in to comment.