-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModel.py
113 lines (87 loc) · 3.55 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import re
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import AzureOpenAIEmbeddings
from DocumentAnalyzer import DocumentAnalyzer
from helper import singleton
def content_spliter(content, n):
    """Split *content* into ``n`` consecutive, whitespace-normalised sections.

    The text is tokenised with ``str.split()`` (so any run of whitespace
    counts as one separator) and the words are dealt out in order into ``n``
    sections whose sizes differ by at most one word. Each section is
    re-joined with single spaces.

    Args:
        content: The text to split.
        n: Number of sections to produce; must be at least 1.

    Returns:
        A list of exactly ``n`` strings. When ``content`` has fewer than ``n``
        words the trailing sections are empty strings.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError(f"n must be a positive integer, got {n!r}")

    words = content.split()
    # The first `extra` sections take one additional word so that the
    # remainder is spread out rather than piled onto the last section.
    base, extra = divmod(len(words), n)

    sections = []
    start = 0
    for index in range(n):
        end = start + base + (1 if index < extra else 0)
        sections.append(" ".join(words[start:end]))
        start = end
    return sections
@singleton
class Model:
    """Bundle the chat model, embeddings and document analyzer used to check
    content against rules/laws.

    The class is wrapped by ``helper.singleton`` (presumably so one shared
    instance is reused -- confirm against ``helper``).
    """

    # Placeholders; ``llm``, ``embeddings`` and ``analyzer`` are assigned in
    # ``__init__`` and ``db`` is assigned on each ``rag`` call.
    llm = None
    embeddings = None
    analyzer = None
    db = None

    # System message sent ahead of the user content in ``query``.
    __system_msg = (
        "system",
        "you are a helpfull assistant that checks wheather the given content obeys the rules/laws or not with the help of context",
    )

    # Compiled once instead of on every loop iteration. Matching is
    # case-insensitive because the model may answer "Yes"/"No" or capitalise
    # the field names.
    _IS_VIOLATION_RE = re.compile(r"is_violation:\s*(yes|no)", re.IGNORECASE)
    _REASON_RE = re.compile(r"reason:\s*(.*)", re.IGNORECASE | re.DOTALL)

    def __init__(self):
        """Create the chat model, embeddings client and document analyzer.

        The chat endpoint and API key are read from the ``EP`` and ``MS``
        environment variables.
        """
        self.llm = ChatMistralAI(
            endpoint=os.getenv("EP"),
            mistral_api_key=os.getenv("MS"),
        )
        self.embeddings = AzureOpenAIEmbeddings(
            model="tuskact2",
        )
        self.analyzer = DocumentAnalyzer()

    def query(self, content):
        """Send *content* to the chat model, preceded by the system message.

        Args:
            content: The user content to check.

        Returns:
            Whatever ``self.llm.invoke`` returns for the two-message input.
        """
        response = self.llm.invoke([self.__system_msg, content])
        return response

    @staticmethod
    def _parse_verdict(answer):
        """Extract ``(status, reason)`` from a model reply.

        Args:
            answer: Raw text returned by the chain.

        Returns:
            A ``(status, reason)`` tuple. ``status`` is ``"yes"`` or ``"no"``
            (lower-cased), or ``None`` when no ``is_violation:`` field is
            found. ``reason`` is everything after ``reason:`` with
            surrounding whitespace stripped, or ``None`` when absent.
        """
        status_match = Model._IS_VIOLATION_RE.search(answer) if False else None
        status_match = re.search(r"is_violation:\s*(yes|no)", answer, re.IGNORECASE)
        reason_match = re.search(r"reason:\s*(.*)", answer, re.IGNORECASE | re.DOTALL)
        # Lower-case so callers keep seeing exactly "yes"/"no" even when the
        # model capitalises its answer.
        status = status_match.group(1).strip().lower() if status_match else None
        reason = reason_match.group(1).strip() if reason_match else None
        return status, reason

    def rag(self, blob_key):
        """Check a stored document for rule/law violations, section by section.

        The document text is obtained from ``self.analyzer.analyze_blob``,
        split into five sections with ``content_spliter``, and each non-empty
        section is run through a retrieval-augmented prompt. The results of
        earlier sections are fed into later prompts as ``q_a_pairs``.

        Args:
            blob_key: Key passed to ``analyze_blob`` to locate the document.

        Returns:
            A list of dicts, one per non-empty section, each with the keys
            ``"title"`` (the section text), ``"status"`` (``"yes"``, ``"no"``
            or ``None``) and ``"reason"`` (str or ``None``).
        """
        # Imported inside the method as in the original code -- presumably to
        # avoid an import cycle with Ingest; confirm before moving it.
        from Ingest import VectorDB

        self.db = VectorDB()
        document_text = self.analyzer.analyze_blob(blob_key, "rag")
        sections = content_spliter(document_text, 5)

        # The template lines are kept flush-left because the text is sent to
        # the model verbatim.
        template = """
Here is the content that needs to be checked for any violations of the rules/laws:
\n --- \n {content} \n --- \n
Here is any available content before this content that is checked for any violations in form of pairs:
\n --- \n {q_a_pairs} \n --- \n
Here is additional context relevant to the question:
\n --- \n {context} \n --- \n
Use the above context and pairs to check if the content violates any rules/laws.\n
Give result in this format: is_violation: (yes/no) \n reason: (violation) \n
No extra information is needed. just one result for entire content. the reason must be unique and valid
"""
        decomposition_prompt = ChatPromptTemplate.from_template(template)

        # The chain does not depend on the section being processed, so it is
        # built once and reused for every section.
        rag_chain = (
            {
                "content": itemgetter("content"),
                "q_a_pairs": itemgetter("q_a_pairs"),
                "context": itemgetter("content") | self.db.retriever,
            }
            | decomposition_prompt
            | self.llm
            | StrOutputParser()
        )

        q_a_pairs = ""
        results = []
        for section in sections:
            if not section:
                # A short document leaves some sections empty; there is
                # nothing to retrieve for or to check in an empty string.
                continue
            answer = rag_chain.invoke({"content": section, "q_a_pairs": q_a_pairs})
            status, reason = self._parse_verdict(answer)
            result = {"title": section, "status": status, "reason": reason}
            # Carry this verdict forward as context for the next section.
            q_a_pairs = q_a_pairs + "\n---\n" + str(result)
            results.append(result)
        return results