Skip to content

Commit

Permalink
Refactor chat chain
Browse files Browse the repository at this point in the history
Signed-off-by: Hemslo Wang <[email protected]>
  • Loading branch information
hemslo committed Feb 19, 2024
1 parent 9834570 commit c072efe
Showing 1 changed file with 79 additions and 82 deletions.
161 changes: 79 additions & 82 deletions app/chains/chat.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,106 @@
from operator import itemgetter
from typing import List, Tuple

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, format_document
from langchain_core.runnables import (
RunnableParallel,
from collections.abc import Sequence

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.messages import (
HumanMessage,
AIMessage,
SystemMessage,
FunctionMessage,
ChatMessage,
ToolMessage,
)
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
)
from langchain.prompts.prompt import PromptTemplate
from langserve import CustomUserType
from pydantic import Field
from pydantic import Field, BaseModel

from app.dependencies.llm import get_llm
from app.dependencies.redis import get_redis

llm = get_llm()

retriever = get_redis().as_retriever(
search_type="mmr",
search_kwargs={
"fetch_k": 20,
"k": 3,
"lambda_mult": 0.5,
},
)
RETRIEVAL_QA_CHAT_SYSTEM_PROMPT = """\
Answer any questions based solely on the context below:
<context>
{context}
</context>
"""

REPHRASE_TEMPLATE = """\
REPHRASE_PROMPT = """\
Given the following conversation and a follow up question, \
rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:
Follow Up Input: {input}
Standalone Question:
"""

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)

ANSWER_TEMPLATE = """\
Use the following pieces of context to answer the question at the end. \
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:
"""

ANSWER_PROMPT = ChatPromptTemplate.from_template(ANSWER_TEMPLATE)
def build_chat_chain():
llm = get_llm()

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
retriever = get_redis().as_retriever(
search_type="mmr",
search_kwargs={
"fetch_k": 20,
"k": 3,
"lambda_mult": 0.5,
},
)

rephrase_prompt = PromptTemplate.from_template(REPHRASE_PROMPT)

def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
retriever_chain = create_history_aware_retriever(
llm,
retriever,
rephrase_prompt,
)

retrieval_qa_chat_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
RETRIEVAL_QA_CHAT_SYSTEM_PROMPT,
),
MessagesPlaceholder(variable_name="chat_history"),
(
"user",
"{input}",
),
]
)

class ChatHistory(CustomUserType):
chat_history: List[Tuple[str, str]] = Field(
combine_docs_chain = create_stuff_documents_chain(
llm,
retrieval_qa_chat_prompt,
)
return create_retrieval_chain(retriever_chain, combine_docs_chain)


class Input(BaseModel):
chat_history: Sequence[
HumanMessage
| AIMessage
| SystemMessage
| FunctionMessage
| ChatMessage
| ToolMessage
] = Field(
...,
examples=[[("human input", "ai response")]],
extra={"widget": {"type": "chat", "input": "question", "output": "answer"}},
extra={
"widget": {"type": "chat", "input": "input", "output": "answer"},
},
)
question: str


def _format_chat_history(chat_history: ChatHistory) -> List[BaseMessage]:
messages = []
for human, ai in chat_history.chat_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
return messages


_inputs = RunnableParallel(
standalone_question={
"chat_history": _format_chat_history,
"question": lambda x: x.question,
}
| CONDENSE_QUESTION_PROMPT
| llm
| StrOutputParser(),
)

_retrieved_documents = {
"docs": itemgetter("standalone_question") | retriever,
"question": itemgetter("standalone_question"),
}

_final_inputs = {
"context": lambda x: _combine_documents(x["docs"]),
"question": itemgetter("question"),
}
input: str

_answer = {
"answer": _final_inputs | ANSWER_PROMPT | llm | StrOutputParser(),
"docs": itemgetter("docs"),
}

class Output(BaseModel):
answer: str
context: Sequence[Document]

_conversational_qa_chain = _inputs | _retrieved_documents | _answer

chat_chain = _conversational_qa_chain.with_types(input_type=ChatHistory)
chat_chain = build_chat_chain().with_types(input_type=Input, output_type=Output)

0 comments on commit c072efe

Please sign in to comment.