Skip to content

Commit

Permalink
🐞 fix(api): fix some cache bug and add chat lock
Browse files Browse the repository at this point in the history
  • Loading branch information
centonhuang committed Feb 9, 2024
1 parent 339abc7 commit 5d98df9
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions internal/api/router/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from sqlalchemy import or_

from internal.langchain.llm import init_llm
from internal.langchain.memory import (init_history, init_msg_memory,
init_str_memory)
from internal.langchain.memory import init_history, init_msg_memory, init_str_memory
from internal.langchain.prompt import init_msg_prompt, init_str_prompt
from internal.logger import logger
from internal.middleware.mysql import session
from internal.middleware.mysql.model import (LLMSchema, MessageSchema,
SessionSchema)
from internal.middleware.mysql.model import LLMSchema, MessageSchema, SessionSchema
from internal.middleware.redis import r

from ...auth import sk_auth
Expand Down Expand Up @@ -70,16 +68,16 @@ async def create_session(info: Tuple[int, int] = Depends(sk_auth)):
r.delete(f"session:{_session.session_id}")

r.delete(f"uid:{uid}:sessions")

return StandardResponse(code=0, status="success", data=data)


@session_router.get("/sessions", response_model=StandardResponse, dependencies=[Depends(sk_auth)])
async def list_session(page_id: int = 0, per_page_num: int = 20, info: Tuple[int, int] = Depends(sk_auth)):
uid, _ = info
redis_set = f"uid:{uid}:sessions"
if r.exists(redis_set):
if session_list := r.zrange(redis_set, page_id * per_page_num, (page_id + 1) * per_page_num - 1, desc=True):
redis_list = f"uid:{uid}:sessions"
if r.exists(redis_list):
if session_list := r.lrange(redis_list, page_id * per_page_num, (page_id + 1) * per_page_num - 1):
return StandardResponse(
code=0,
status="success",
Expand All @@ -97,7 +95,7 @@ async def list_session(page_id: int = 0, per_page_num: int = 20, info: Tuple[int
conn.query(SessionSchema.session_id, SessionSchema.create_at, SessionSchema.update_at)
.filter(SessionSchema.uid == uid)
.filter(or_(SessionSchema.delete_at.is_(None), datetime.datetime.now() < SessionSchema.delete_at))
.order_by(SessionSchema.create_at.desc())
.order_by(SessionSchema.session_id.desc())
.offset(page_id * per_page_num)
)
result = query.limit(per_page_num).all()
Expand All @@ -111,7 +109,7 @@ async def list_session(page_id: int = 0, per_page_num: int = 20, info: Tuple[int
for session_id, create_at, update_at in result
]
for s in session_list:
r.zadd(redis_set, {dumps(s, ensure_ascii=False): s["session_id"]})
r.lpush(redis_list, dumps(s, ensure_ascii=False))

data = {"session_list": session_list}
return StandardResponse(code=0, status="success", data=data)
Expand Down Expand Up @@ -205,7 +203,7 @@ async def delete_session(session_id: int, uid: int = -1, info: Tuple[int, int] =

if not query.first():
return StandardResponse(code=1, status="error", message="Session not exist")

with session() as conn:
if not conn.is_active:
conn.rollback()
Expand All @@ -223,14 +221,19 @@ async def delete_session(session_id: int, uid: int = -1, info: Tuple[int, int] =

r.delete(f"session:{session_id}")
r.delete(f"uid:{uid}:sessions")

return StandardResponse(code=0, status="success", message="Delete session successfully")


@session_router.post("/{session_id}/chat", dependencies=[Depends(sk_auth)])
async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = Depends(sk_auth)) -> StandardResponse | SSEResponse:
_uid, _ = info

redis_lock = f"chat_lock:uid:{_uid}"
if r.exists(redis_lock):
return StandardResponse(code=1, status="error", message="You are chatting, please wait a moment")
r.set(redis_lock, "lock", ex=30)

with session() as conn:
if not conn.is_active:
conn.rollback()
Expand All @@ -248,6 +251,7 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De

result = query.first()
if not result:
r.delete(redis_lock)
return StandardResponse(code=1, status="error", message="Session not exist")

_, llm_name = result
Expand All @@ -261,6 +265,7 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
)
_llm: LLMSchema | None = query.first()
if not _llm:
r.delete(redis_lock)
return StandardResponse(code=1, status="error", message="LLM not exist")

if not llm_name:
Expand Down Expand Up @@ -337,8 +342,12 @@ async def _sse_response():
logger.debug(f"SSE response: {data}")

yield data
r.delete(redis_lock)
except Exception as e:
logger.error(f"SSE failed: {e}")
yield dumps({"extras": {}, "delta": "", "status": f"exception: {e}"}) + "\n"

r.delete(f"session:{session_id}")
r.delete(f"uid:{_uid}:sessions")

return StreamingResponse(_sse_response(), media_type="text/event-stream")

0 comments on commit 5d98df9

Please sign in to comment.