Skip to content

Commit

Permalink
✨ feat(api): support llm binding in api
Browse files Browse the repository at this point in the history
  • Loading branch information
centonhuang committed Feb 8, 2024
1 parent 89d0ab4 commit c246a29
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 51 deletions.
115 changes: 64 additions & 51 deletions internal/api/router/v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,18 @@ async def get_session(session_id: str, info: Tuple[int, int] = Depends(sk_auth))
conn.commit()

query = (
conn.query(SessionSchema.session_id, SessionSchema.create_at, SessionSchema.update_at)
conn.query(SessionSchema.session_id, SessionSchema.create_at, SessionSchema.update_at, LLMSchema.llm_name)
.filter(SessionSchema.session_id == session_id)
.filter(SessionSchema.uid == uid)
.join(LLMSchema, isouter=True)
.filter(or_(SessionSchema.delete_at.is_(None), datetime.datetime.now() < SessionSchema.delete_at))
)
result = query.first()

if not result:
return StandardResponse(code=1, status="error", message="Session not exist")

session_id, create_at, update_at = result
session_id, create_at, update_at, llm_name = result

query = conn.query(MessageSchema.id, MessageSchema.chat_at, MessageSchema.message).filter(MessageSchema.session_id == session_id)
results = query.all()
Expand All @@ -110,6 +111,7 @@ async def get_session(session_id: str, info: Tuple[int, int] = Depends(sk_auth))
"session_id": session_id,
"create_at": create_at,
"update_at": update_at,
"bind_llm": llm_name,
"messages": messages,
}
return StandardResponse(code=0, status="success", message="Get session successfully", data=data)
Expand Down Expand Up @@ -162,15 +164,21 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
conn.commit()

query = (
conn.query(SessionSchema.session_id)
conn.query(SessionSchema.session_id, LLMSchema.llm_name)
.filter(SessionSchema.session_id == session_id)
.filter(SessionSchema.uid == _uid)
.join(LLMSchema, isouter=True)
.filter(or_(SessionSchema.delete_at.is_(None), datetime.datetime.now() < SessionSchema.delete_at))
)

if not query.first():
result = query.first()
if not result:
return StandardResponse(code=1, status="error", message="Session not exist")

_, llm_name = result
if llm_name:
request.llm_name = llm_name
logger.debug(f"Use bind LLM: {llm_name}")
query = (
conn.query(LLMSchema)
.filter(LLMSchema.llm_name == request.llm_name)
Expand All @@ -180,54 +188,59 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
if not _llm:
return StandardResponse(code=1, status="error", message="LLM not exist")

try:
llm: ChatOpenAI = init_llm(
llm_type=_llm.llm_type,
llm_name=_llm.llm_name,
base_url=_llm.base_url,
api_key=_llm.api_key,
temperature=request.temperature,
max_tokens=_llm.max_tokens,
)
if not llm_name:
conn.query(SessionSchema).filter(SessionSchema.session_id == session_id).update({SessionSchema.llm_id: _llm.llm_id})
conn.commit()
logger.debug(f"Bind LLM: {request.llm_name} to Session: {session_id}")

history = init_history(session_id=session_id)

match _llm.request_type:
case "string":
memory = init_str_memory(
history=history,
ai_name=_llm.ai_name,
user_name=_llm.user_name,
k=8,
)
prompt = init_str_prompt(
sys_name=_llm.sys_name,
sys_prompt=_llm.sys_prompt,
user_name=_llm.user_name,
ai_name=_llm.ai_name,
)
case "message":
memory = init_msg_memory(
history=history,
k=8,
)
prompt = init_msg_prompt(
sys_prompt=_llm.sys_prompt,
)
case _:
return StandardResponse(code=1, status="error", message="Invalid request type")

chain = LLMChain(
name="multi_turn_chat_llm_chain",
llm=llm,
prompt=prompt,
memory=memory,
verbose=True,
return_final_only=True,
)
except Exception as e:
logger.error(f"Init langchain modules failed: {e}")
return StandardResponse(code=1, status="error", message="Chat init failed")
try:
llm: ChatOpenAI = init_llm(
llm_type=_llm.llm_type,
llm_name=_llm.llm_name,
base_url=_llm.base_url,
api_key=_llm.api_key,
temperature=request.temperature,
max_tokens=_llm.max_tokens,
)

history = init_history(session_id=session_id)

match _llm.request_type:
case "string":
memory = init_str_memory(
history=history,
ai_name=_llm.ai_name,
user_name=_llm.user_name,
k=8,
)
prompt = init_str_prompt(
sys_name=_llm.sys_name,
sys_prompt=_llm.sys_prompt,
user_name=_llm.user_name,
ai_name=_llm.ai_name,
)
case "message":
memory = init_msg_memory(
history=history,
k=8,
)
prompt = init_msg_prompt(
sys_prompt=_llm.sys_prompt,
)
case _:
return StandardResponse(code=1, status="error", message="Invalid request type")

chain = LLMChain(
name="multi_turn_chat_llm_chain",
llm=llm,
prompt=prompt,
memory=memory,
verbose=True,
return_final_only=True,
)
except Exception as e:
logger.error(f"Init langchain modules failed: {e}")
return StandardResponse(code=1, status="error", message="Chat init failed")

async def _sse_response():
try:
Expand Down
2 changes: 2 additions & 0 deletions internal/middleware/mysql/model/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlalchemy import Column, DateTime, ForeignKey, Integer

from .base import BaseSchema
from .llms import LLMSchema
from .users import UserSchema


Expand All @@ -13,4 +14,5 @@ class SessionSchema(BaseSchema):
create_at: datetime = Column(DateTime, default=datetime.now)
update_at: datetime = Column(DateTime, default=datetime.now, onupdate=datetime.now)
delete_at: datetime = Column(DateTime, nullable=True)
llm_id: int = Column(Integer, ForeignKey(LLMSchema.llm_id, ondelete="CASCADE"), nullable=True)
uid: int = Column(Integer, ForeignKey(UserSchema.uid, ondelete="CASCADE"), nullable=False)

0 comments on commit c246a29

Please sign in to comment.