-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconvert_to_llava_format.py
35 lines (28 loc) · 1.29 KB
/
convert_to_llava_format.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import json
import os
import random
# Load the raw training records. JSON interchange text is UTF-8, so decode it
# explicitly instead of relying on the platform's locale-dependent default.
with open('data/train.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

# Fix the seed so the random.shuffle calls made later in this script produce
# the same ordering on every run.
random.seed(0)
def format_conversation(instruction_key, response, from_human='human', from_gpt='gpt', prepend_image=False):
    """Build one prompt/reply exchange as a pair of conversation turns.

    Args:
        instruction_key: Key of ``response`` under which the prompt is stored.
        response: Mapping that holds the prompt under ``instruction_key`` and
            the reply under ``'response'``.
        from_human: Speaker label written into the prompt turn.
        from_gpt: Speaker label written into the reply turn.
        prepend_image: When true, the prompt is prefixed with ``<image>``.

    Returns:
        A two-element list of ``{"from": ..., "value": ...}`` dicts, prompt
        turn first, reply turn second.

    Raises:
        KeyError: If ``instruction_key`` or ``'response'`` is missing from
            ``response``.
    """
    prompt = response[instruction_key]
    if prepend_image:
        # Only the image-bearing turn carries the placeholder token.
        prompt = f'<image>{prompt}'

    human_turn = {"from": from_human, "value": prompt}
    gpt_turn = {"from": from_gpt, "value": response['response']}
    return [human_turn, gpt_turn]
def _instruction_key(pair):
    """Return the first key of *pair* that is not the reply key ``'response'``.

    Skipping ``'response'`` keeps the result independent of where the reply
    happens to sit in the JSON object's key order.

    Raises:
        ValueError: If *pair* has no key other than ``'response'``.
    """
    for key in pair:
        if key != 'response':
            return key
    raise ValueError(f"instruction-response pair has no instruction key: {pair!r}")


# Convert every record in place: build a flat 'conversations' list of
# alternating human/gpt turns from its 'instr-resp' pairs, then drop the
# original field.
for item_index, item in enumerate(train_data):
    item['conversations'] = []
    instruction_responses = item['instr-resp']
    if not instruction_responses:
        # Identify the offending record instead of failing with a bare
        # IndexError from indexing an empty list.
        raise ValueError(f"record {item_index} has an empty 'instr-resp' list")

    # Randomly shuffle the instruction-response pairs if there are both safe and unsafe
    has_multiple_pairs = len(instruction_responses) > 1
    if has_multiple_pairs:
        random.shuffle(instruction_responses)

    # Emit every pair (not just the first two) so nothing is silently dropped;
    # only the opening human turn carries the <image> placeholder.
    for turn_index, pair in enumerate(instruction_responses):
        if has_multiple_pairs:
            # NOTE(review): presumably each pair names its prompt key
            # differently (e.g. safe vs. unsafe) -- confirm against the data.
            instruction_key = _instruction_key(pair)
        else:
            instruction_key = 'instruction'
        item['conversations'].extend(
            format_conversation(instruction_key, pair, prepend_image=(turn_index == 0))
        )
    item.pop('instr-resp')

# Write the converted records; the explicit encoding keeps the output
# independent of the platform's locale default.
with open('data/train_llava_format.json', 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=2)