-
Notifications
You must be signed in to change notification settings - Fork 195
/
Copy pathdialog_sentiment_intensity_mapper.py
221 lines (191 loc) · 10.5 KB
/
dialog_sentiment_intensity_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import re
from typing import Dict, Optional
from loguru import logger
from pydantic import NonNegativeInt, PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'dialog_sentiment_intensity_mapper'
# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class DialogSentimentIntensityMapper(Mapper):
"""
Mapper to predict user's sentiment intensity (from -5 to 5 in default
prompt) in dialog. Input from history_key, query_key and
response_key. Output lists of intensities and analysis for queries in
the dialog.
"""
DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
'要求:\n'
'- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,'
'-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n'
'- 只输出当轮对话的分析,不要继续构造对话。\n'
'- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n'
'用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n'
'情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n'
'情绪值:0\n'
'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自'
'身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合'
'理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n'
'用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n'
'情绪分析:对回答感到满意,情绪正面。\n'
'情绪值:1\n'
'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用'
'水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和'
'多用电子账单也是不错的选择。\n'
'用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促'
'进作用吗?\n'
'情绪分析:觉得回答实用且具体,情绪进一步转好。\n'
'情绪值:2\n'
'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的'
'故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n'
'用户:什么吗,根本是答非所问。\n'
'情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n'
'情绪值:-2\n'
'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加'
'环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n'
'用户:还行吧,算你能够掰回来。\n'
'情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n'
'情绪值:-1\n')
DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n'
DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n'
DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n'
DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)'
def __init__(
self,
api_model: str = 'gpt-4o',
max_round: NonNegativeInt = 10,
*,
intensities_key: str = MetaKeys.dialog_sentiment_intensity,
analysis_key: str = MetaKeys.dialog_sentiment_intensity_analysis,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
query_template: Optional[str] = None,
response_template: Optional[str] = None,
analysis_template: Optional[str] = None,
intensity_template: Optional[str] = None,
analysis_pattern: Optional[str] = None,
intensity_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param max_round: The max num of round in the dialog to build the
prompt.
:param intensities_key: The key name in the meta field to store
the output sentiment intensities. It is
'dialog_sentiment_intensity' in default.
:param analysis_key: The key name in the meta field to store the
corresponding analysis. It is
'dialog_sentiment_intensity_analysis' in default.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for the task.
:param query_template: Template for query part to build the input
prompt.
:param response_template: Template for response part to build the
input prompt.
:param analysis_template: Template for analysis part to build the
input prompt.
:param intensity_template: Template for intensity part to build the
input prompt.
:param analysis_pattern: Pattern to parse the return sentiment
analysis.
:param intensity_pattern: Pattern to parse the return sentiment
intensity.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.max_round = max_round
self.intensities_key = intensities_key
self.analysis_key = analysis_key
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
self.response_template = response_template or \
self.DEFAULT_RESPONSE_TEMPLATE
self.analysis_template = analysis_template or \
self.DEFAULT_ANALYSIS_TEMPLATE
self.intensity_template = intensity_template or \
self.DEFAULT_INTENSITY_TEMPLATE
self.analysis_pattern = analysis_pattern or \
self.DEFAULT_ANALYSIS_PATTERN
self.intensity_pattern = intensity_pattern or \
self.DEFAULT_INTENSITY_PATTERN
self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
def build_input(self, history, query):
if self.max_round > 0:
input_prompt = ''.join(history[-self.max_round * 4:])
else:
input_prompt = ''
input_prompt += self.query_template.format(query=query[0])
return input_prompt
def parse_output(self, response):
analysis = ''
intensity = 0
match = re.search(self.analysis_pattern, response)
if match:
analysis = match.group(1)
match = re.search(self.intensity_pattern, response)
if match:
intensity = int(match.group(1))
return analysis, intensity
def process_single(self, sample, rank=None):
meta = sample[Fields.meta]
if self.intensities_key in meta and self.analysis_key in meta:
return sample
client = get_model(self.model_key, rank=rank)
analysis_list = []
intensities = []
history = []
dialog = sample[self.history_key]
if self.query_key in sample and sample[self.query_key]:
if self.response_key in sample and sample[self.response_key]:
dialog.append(
(sample[self.query_key], sample[self.response_key]))
else:
dialog.append((sample[self.query_key], ''))
for qa in dialog:
input_prompt = self.build_input(history, qa)
messages = [{
'role': 'system',
'content': self.system_prompt,
}, {
'role': 'user',
'content': input_prompt,
}]
for _ in range(self.try_num):
try:
response = client(messages, **self.sampling_params)
analysis, intensity = self.parse_output(response)
if len(analysis) > 0:
break
except Exception as e:
logger.warning(f'Exception: {e}')
analysis_list.append(analysis)
intensities.append(intensity)
history.append(self.query_template.format(query=qa[0]))
history.append(self.analysis_template.format(analysis=analysis))
history.append(self.intensity_template.format(intensity=intensity))
history.append(self.response_template.format(response=qa[1]))
meta[self.intensities_key] = intensities
meta[self.analysis_key] = analysis_list
return sample