Skip to content

Commit

Permalink
✨ feat(webui): align with api chat
Browse files Browse the repository at this point in the history
  • Loading branch information
centonhuang committed Apr 19, 2024
1 parent deb0bba commit f9849b5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 47 deletions.
32 changes: 10 additions & 22 deletions pages/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import streamlit as st
from streamlit import session_state as cache

from src.webui.utils import chat, get_history, get_llms, get_sessions, get_vector_dbs, new_session, retriever_qa
from src.webui.utils import chat, get_history, get_llms, get_sessions, get_vector_dbs, new_session

ABOUT = """\
### Aris AI is a project of providing private llm api and webui service
Expand Down Expand Up @@ -132,27 +132,15 @@ def body():
with container.chat_message("ai"):
resp = ""
place_holder = st.empty()
if not cache.vector_db_id:
for token in chat(
api_key=cache.api_key,
session_id=cache.session_id,
llm_name=cache.llm,
message=prompt,
temperature=cache.temperature,
):
resp += token
place_holder.markdown(resp)
else:
for token in retriever_qa(
api_key=cache.api_key,
session_id=cache.session_id,
llm_name=cache.llm,
message=prompt,
temperature=cache.temperature,
vector_db_id=cache.vector_db_id,
):
resp += token
place_holder.markdown(resp)
for token in chat(
api_key=cache.api_key,
session_id=cache.session_id,
llm_name=cache.llm,
message=prompt,
temperature=cache.temperature,
):
resp += token
place_holder.markdown(resp)


def main():
Expand Down
29 changes: 4 additions & 25 deletions src/webui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,32 +219,9 @@ def upload_urls(api_key: str, vector_db_id: int, urls: str, chunk_size: int, chu
return data


def chat(api_key: str, session_id: int, message: str, llm_name: str, temperature: float) -> Iterator[str]:
def chat(api_key: str, session_id: int, message: str, llm_name: str, temperature: float, vector_db_id: int | None = None) -> Iterator[str]:
url = urljoin(API_URL, f"v1/session/{session_id}/chat")
headers = {"Authorization": f"Bearer {api_key}"}
data = {
"message": message,
"llm_name": llm_name,
"temperature": temperature,
}

response = requests.post(
url=url,
headers=headers,
json=data,
stream=True,
)

for chunk in response.iter_lines():
if not chunk:
continue
chunk = loads(chunk.decode("utf-8"))
yield chunk.get("delta", "")


def retriever_qa(api_key: str, session_id: int, message: str, llm_name: str, temperature: float, vector_db_id: int):
url = urljoin(API_URL, f"v1/session/{session_id}/retriever-qa")
headers = {"Authorization": f"Bearer {api_key}"}
data = {
"message": message,
"llm_name": llm_name,
Expand All @@ -262,5 +239,7 @@ def retriever_qa(api_key: str, session_id: int, message: str, llm_name: str, tem
for chunk in response.iter_lines():
if not chunk:
continue
if chunk.startswith(b"data:"):
chunk = chunk[5:]
chunk = loads(chunk.decode("utf-8"))
yield chunk.get("delta", "")
yield chunk.get("data", {}).get("chunk", "")

0 comments on commit f9849b5

Please sign in to comment.