forked from guardagent/code
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
145 lines (124 loc) · 4.74 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
def get_input_ehr(data, output_dir=None):
identity = data['identity']
question = data['template']
input_id = data['id']
agent_input = 'Identity: {}\nQuestion: {}'.format(identity, question)
output_path = output_dir + '{}.txt'.format(input_id)
return agent_input, input_id, identity, output_path
def get_output_ehr(output_log_path, idx):
log_file = open(output_log_path, "r").read()
log_file = log_file.split('(END OF EXAMPLES)\n')[1]
knowledge = log_file.split('Question:')[0]
solution = log_file.split('Solution:')[-1]
solution_sections = solution.split('\n----------------------------------------------------------\n')
# check if 'TERMINATE' exists
terminate_idx = None
for i in range(len(solution_sections)):
section = solution_sections[i]
if 'TERMINATE' in section:
terminate_idx = i
# terminate_idx = solution_sections.index('TERMINATE')
answer = solution_sections[terminate_idx - 1]
code = None
for section in solution_sections:
if section.find('\"cell\":') > 0:
code = section.split('\"cell\": \"')[1]
code = code.split('\"\n}')[0]
if code is None:
max_length = 0
section_idx = 0
for i in range(len(solution_sections)):
section = solution_sections[i]
if len(section) > max_length:
max_length = len(section)
section_idx = i
code = solution_sections[section_idx]
agent_output = knowledge + '\nGenerated code:\n' + code + '\nAnswer:\n' + answer
return agent_output
def get_input_seeact(data, output_dir=None):
identity = 'User'
input_id = data['annotation_id']
user_info = "\n ".join([f"{k}: {v}"for k, v in data['user_info'].items()])
agent_input = f"\nTask: {data['confirmed_task']} \n\nuser_info:\n{user_info}"
output_path = output_dir + '/sample_labeled_all.json'
return agent_input, input_id, identity, output_path
def get_output_seeact(output_log_path, idx):
with open(output_log_path, 'r') as f:
dataset = json.load(f)
# task_breakdown = dataset[idx]['prompt'][-2]
action_choices = dataset[idx]['prompt'][-1]
agent_output = f"\n(Next Action Based on Webpage and Analysis)\n{dataset[idx]['gpt_output'][-2]}\n\n{action_choices}\n\n(Final Answer)\n{dataset[idx]['gpt_output'][-1]}"
return agent_output
def check_inaccessibility(dbs1, dbs2):
same = True
for key, value in dbs2.items():
if key not in dbs1:
same = False
else:
column_diff_dir1 = list(set(value) - set(dbs1[key]))
column_diff_dir2 = list(set(dbs1[key]) - set(value))
if len(column_diff_dir1) > 0 or len(column_diff_dir2) > 0:
same = False
for key, value in dbs1.items():
if key not in dbs2:
same = False
else:
column_diff_dir1 = list(set(value) - set(dbs2[key]))
column_diff_dir2 = list(set(dbs2[key]) - set(value))
if len(column_diff_dir1) > 0 or len(column_diff_dir2) > 0:
same = False
return same
def check_violation(rule1: str, violation: str):
rule1 = rule1.strip()
violation = violation.strip()
if rule1 is None or rule1 == "" or violation == "":
return False
violation = violation.replace("\n", "").replace(".", "").lower()
rule1 = rule1.replace("\n", "").replace(".", "").lower()
if rule1 == violation or violation in rule1:
return True
return False
def judge(pred, ans):
old_flag = True
if not ans in pred:
old_flag = False
if "True" in pred:
pred = pred.replace("True", "1")
else:
pred = pred.replace("False", "0")
if ans == "False" or ans == "false":
ans = "0"
if ans == "True" or ans == "true":
ans = "1"
if ans == "No" or ans == "no":
ans = "0"
if ans == "Yes" or ans == "yes":
ans = "1"
if ans == "None" or ans == "none":
ans = "0"
if ", " in ans:
ans = ans.split(', ')
if ans[-2:] == ".0":
ans = ans[:-2]
if not type(ans) == list:
ans = [ans]
new_flag = True
for i in range(len(ans)):
if not ans[i] in pred:
new_flag = False
break
return (old_flag or new_flag)
def judge_seeact(pred, ans):
if pred == ans:
return True
pre_choice = pred.split(".")[0]
pre_operation = pred.split(", ")[1]
gt_choice = pred.split("ELEMENT: ")[1]
gt_choice = gt_choice.split("\nACTION")[0].replace("\n","")
gt_operation = pred.split("\nACTION: ")[1]
gt_choice = gt_operation.split("\nVALUE")[0].replace("\n","")
if pre_choice == gt_choice and pre_operation == gt_operation:
return True
else:
return False