Commit
Signed-off-by: Hemslo Wang <[email protected]>
Showing 1 changed file with 79 additions and 82 deletions.
```diff
@@ -1,109 +1,106 @@
-from operator import itemgetter
-from typing import List, Tuple
-
-from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, format_document
-from langchain_core.runnables import (
-    RunnableParallel,
+from collections.abc import Sequence
+
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.documents import Document
+from langchain_core.messages import (
+    HumanMessage,
+    AIMessage,
+    SystemMessage,
+    FunctionMessage,
+    ChatMessage,
+    ToolMessage,
+)
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
 )
 from langchain.prompts.prompt import PromptTemplate
 from langserve import CustomUserType
-from pydantic import Field
+from pydantic import Field, BaseModel
 
 from app.dependencies.llm import get_llm
 from app.dependencies.redis import get_redis
 
-llm = get_llm()
-
-retriever = get_redis().as_retriever(
-    search_type="mmr",
-    search_kwargs={
-        "fetch_k": 20,
-        "k": 3,
-        "lambda_mult": 0.5,
-    },
-)
+RETRIEVAL_QA_CHAT_SYSTEM_PROMPT = """\
+Answer any questions based solely on the context below:
+<context>
+{context}
+</context>
+"""
 
-REPHRASE_TEMPLATE = """\
+REPHRASE_PROMPT = """\
 Given the following conversation and a follow up question, \
 rephrase the follow up question to be a standalone question.
 Chat History:
 {chat_history}
-Follow Up Input: {question}
-Standalone question:
+Follow Up Input: {input}
+Standalone Question:
 """
 
-CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
-
-ANSWER_TEMPLATE = """\
-Use the following pieces of context to answer the question at the end. \
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-{context}
-Question: {question}
-Helpful Answer:
-"""
-
-ANSWER_PROMPT = ChatPromptTemplate.from_template(ANSWER_TEMPLATE)
+def build_chat_chain():
+    llm = get_llm()
 
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+    retriever = get_redis().as_retriever(
+        search_type="mmr",
+        search_kwargs={
+            "fetch_k": 20,
+            "k": 3,
+            "lambda_mult": 0.5,
+        },
+    )
+
+    rephrase_prompt = PromptTemplate.from_template(REPHRASE_PROMPT)
 
-def _combine_documents(
-    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
-):
-    doc_strings = [format_document(doc, document_prompt) for doc in docs]
-    return document_separator.join(doc_strings)
+    retriever_chain = create_history_aware_retriever(
+        llm,
+        retriever,
+        rephrase_prompt,
+    )
+
+    retrieval_qa_chat_prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                RETRIEVAL_QA_CHAT_SYSTEM_PROMPT,
+            ),
+            MessagesPlaceholder(variable_name="chat_history"),
+            (
+                "user",
+                "{input}",
+            ),
+        ]
+    )
 
-class ChatHistory(CustomUserType):
-    chat_history: List[Tuple[str, str]] = Field(
+    combine_docs_chain = create_stuff_documents_chain(
+        llm,
+        retrieval_qa_chat_prompt,
+    )
+    return create_retrieval_chain(retriever_chain, combine_docs_chain)
+
+
+class Input(BaseModel):
+    chat_history: Sequence[
+        HumanMessage
+        | AIMessage
+        | SystemMessage
+        | FunctionMessage
+        | ChatMessage
+        | ToolMessage
+    ] = Field(
         ...,
-        examples=[[("human input", "ai response")]],
-        extra={"widget": {"type": "chat", "input": "question", "output": "answer"}},
+        extra={
+            "widget": {"type": "chat", "input": "input", "output": "answer"},
+        },
     )
-    question: str
-
-
-def _format_chat_history(chat_history: ChatHistory) -> List[BaseMessage]:
-    messages = []
-    for human, ai in chat_history.chat_history:
-        messages.append(HumanMessage(content=human))
-        messages.append(AIMessage(content=ai))
-    return messages
-
-
-_inputs = RunnableParallel(
-    standalone_question={
-        "chat_history": _format_chat_history,
-        "question": lambda x: x.question,
-    }
-    | CONDENSE_QUESTION_PROMPT
-    | llm
-    | StrOutputParser(),
-)
-
-_retrieved_documents = {
-    "docs": itemgetter("standalone_question") | retriever,
-    "question": itemgetter("standalone_question"),
-}
-
-_final_inputs = {
-    "context": lambda x: _combine_documents(x["docs"]),
-    "question": itemgetter("question"),
-}
+    input: str
 
-_answer = {
-    "answer": _final_inputs | ANSWER_PROMPT | llm | StrOutputParser(),
-    "docs": itemgetter("docs"),
-}
+
+class Output(BaseModel):
+    answer: str
+    context: Sequence[Document]
 
-_conversational_qa_chain = _inputs | _retrieved_documents | _answer
-
-chat_chain = _conversational_qa_chain.with_types(input_type=ChatHistory)
+chat_chain = build_chat_chain().with_types(input_type=Input, output_type=Output)
```
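For readers unfamiliar with the helper this commit adopts: `create_history_aware_retriever` replaces the hand-rolled condense-question pipeline deleted above. It builds roughly the following branch — a paraphrase of LangChain's implementation, not code from this commit, with `llm`, `retriever`, and `rephrase_prompt` assumed to be the objects constructed in `build_chat_chain`. When there is no chat history the raw input goes straight to the retriever; otherwise the history is first condensed into a standalone question.

```python
# Rough paraphrase of what create_history_aware_retriever wires up internally;
# llm, retriever, and rephrase_prompt are assumed from build_chat_chain above.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch

retriever_chain = RunnableBranch(
    (
        # No chat history yet: retrieve on the raw user input.
        lambda x: not x.get("chat_history", False),
        (lambda x: x["input"]) | retriever,
    ),
    # With history: rephrase into a standalone question, then retrieve.
    rephrase_prompt | llm | StrOutputParser() | retriever,
)
```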
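The chain returned by `create_retrieval_chain` consumes a dict with `input` and (optionally) `chat_history`, and its output dict carries `answer` plus the retrieved `context` — exactly what the new `Input` and `Output` models declare. A minimal invocation sketch (the module path is hypothetical, not taken from the commit):

```python
from langchain_core.messages import AIMessage, HumanMessage

from app.chains.chat import chat_chain  # hypothetical module path for this file

result = chat_chain.invoke(
    {
        "input": "How many documents does it fetch per query?",
        "chat_history": [
            HumanMessage(content="How does retrieval work here?"),
            AIMessage(content="It runs MMR search against a Redis vector store."),
        ],
    }
)

print(result["answer"])   # the generated answer string
print(result["context"])  # the retrieved Document objects
```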
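Given the `langserve` import and the chat-widget `extra` metadata on `Input.chat_history`, the chain is presumably mounted on a FastAPI app elsewhere in this repository. A sketch under that assumption (the route path is made up):

```python
from fastapi import FastAPI
from langserve import add_routes

from app.chains.chat import chat_chain  # hypothetical module path

app = FastAPI()

# add_routes exposes /chat/invoke, /chat/batch, /chat/stream, and a playground UI;
# the with_types(...) call is what gives the playground its Input/Output schemas.
add_routes(app, chat_chain, path="/chat")
```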