diff --git a/app/chains/chat.py b/app/chains/chat.py
index 07a80e5..e3495be 100644
--- a/app/chains/chat.py
+++ b/app/chains/chat.py
@@ -1,109 +1,106 @@
-from operator import itemgetter
-from typing import List, Tuple
-
-from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, format_document
-from langchain_core.runnables import (
-    RunnableParallel,
+from collections.abc import Sequence
+
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.documents import Document
+from langchain_core.messages import (
+    HumanMessage,
+    AIMessage,
+    SystemMessage,
+    FunctionMessage,
+    ChatMessage,
+    ToolMessage,
+)
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
 )
 from langchain.prompts.prompt import PromptTemplate
-from langserve import CustomUserType
-from pydantic import Field
+from pydantic import Field, BaseModel
 
 from app.dependencies.llm import get_llm
 from app.dependencies.redis import get_redis
 
-llm = get_llm()
-
-retriever = get_redis().as_retriever(
-    search_type="mmr",
-    search_kwargs={
-        "fetch_k": 20,
-        "k": 3,
-        "lambda_mult": 0.5,
-    },
-)
+RETRIEVAL_QA_CHAT_SYSTEM_PROMPT = """\
+Answer any questions based solely on the context below:
+
+{context}
+
+"""
 
-REPHRASE_TEMPLATE = """\
+REPHRASE_PROMPT = """\
 Given the following conversation and a follow up question, \
 rephrase the follow up question to be a standalone question.
 
 Chat History:
 {chat_history}
-Follow Up Input: {question}
-Standalone question:
+Follow Up Input: {input}
+Standalone Question:
 """
-CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
-
-ANSWER_TEMPLATE = """\
-Use the following pieces of context to answer the question at the end. \
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-{context}
-
-Question: {question}
-
-Helpful Answer:
-"""
-ANSWER_PROMPT = ChatPromptTemplate.from_template(ANSWER_TEMPLATE)
 
+def build_chat_chain():
+    llm = get_llm()
 
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+    retriever = get_redis().as_retriever(
+        search_type="mmr",
+        search_kwargs={
+            "fetch_k": 20,
+            "k": 3,
+            "lambda_mult": 0.5,
+        },
+    )
 
+    rephrase_prompt = PromptTemplate.from_template(REPHRASE_PROMPT)
 
-def _combine_documents(
-    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
-):
-    doc_strings = [format_document(doc, document_prompt) for doc in docs]
-    return document_separator.join(doc_strings)
+    retriever_chain = create_history_aware_retriever(
+        llm,
+        retriever,
+        rephrase_prompt,
+    )
 
+    retrieval_qa_chat_prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                RETRIEVAL_QA_CHAT_SYSTEM_PROMPT,
+            ),
+            MessagesPlaceholder(variable_name="chat_history"),
+            (
+                "user",
+                "{input}",
+            ),
+        ]
+    )
 
-class ChatHistory(CustomUserType):
-    chat_history: List[Tuple[str, str]] = Field(
+    combine_docs_chain = create_stuff_documents_chain(
+        llm,
+        retrieval_qa_chat_prompt,
+    )
+    return create_retrieval_chain(retriever_chain, combine_docs_chain)
+
+
+class Input(BaseModel):
+    chat_history: Sequence[
+        HumanMessage
+        | AIMessage
+        | SystemMessage
+        | FunctionMessage
+        | ChatMessage
+        | ToolMessage
+    ] = Field(
         ...,
-        examples=[[("human input", "ai response")]],
-        extra={"widget": {"type": "chat", "input": "question", "output": "answer"}},
+        extra={
+            "widget": {"type": "chat", "input": "input", "output": "answer"},
+        },
     )
-    question: str
-
-
-def _format_chat_history(chat_history: ChatHistory) -> List[BaseMessage]:
-    messages = []
-    for human, ai in chat_history.chat_history:
-        messages.append(HumanMessage(content=human))
-        messages.append(AIMessage(content=ai))
-    return messages
-
-
-_inputs = RunnableParallel(
-    standalone_question={
-        "chat_history": _format_chat_history,
-        "question": lambda x: x.question,
-    }
-    | CONDENSE_QUESTION_PROMPT
-    | llm
-    | StrOutputParser(),
-)
-
-_retrieved_documents = {
-    "docs": itemgetter("standalone_question") | retriever,
-    "question": itemgetter("standalone_question"),
-}
-
-_final_inputs = {
-    "context": lambda x: _combine_documents(x["docs"]),
-    "question": itemgetter("question"),
-}
+    input: str
 
-_answer = {
-    "answer": _final_inputs | ANSWER_PROMPT | llm | StrOutputParser(),
-    "docs": itemgetter("docs"),
-}
+
+class Output(BaseModel):
+    answer: str
+    context: Sequence[Document]
 
-_conversational_qa_chain = _inputs | _retrieved_documents | _answer
-chat_chain = _conversational_qa_chain.with_types(input_type=ChatHistory)
+chat_chain = build_chat_chain().with_types(input_type=Input, output_type=Output)
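For reviewers: a minimal sketch of how the refactored chain is meant to be called. The question and answer strings are made up for illustration, and it assumes the app's get_llm()/get_redis() dependencies are configured as before:

from langchain_core.messages import AIMessage, HumanMessage

from app.chains.chat import chat_chain

# The chain built by create_retrieval_chain takes a dict with "input" and an
# optional "chat_history", and returns a dict that includes the generated
# "answer" plus the retrieved "context" documents, which is what the Input
# and Output models above describe for LangServe.
result = chat_chain.invoke(
    {
        "input": "And how do I rotate it?",  # hypothetical follow-up question
        "chat_history": [
            HumanMessage(content="How do I create an API key?"),    # hypothetical
            AIMessage(content="Use the key management endpoint."),  # hypothetical
        ],
    }
)

print(result["answer"])   # answer produced by the stuff-documents chain
print(result["context"])  # Documents fetched by the history-aware retriever

A side benefit of wrapping construction in build_chat_chain(): the Redis retriever and LLM are created when the chain is built rather than at import time, which should make the module cheaper to import in tests.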