-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_dataset.py
952 lines (772 loc) · 34.2 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
r"""Generates CLEVR-Dialog dataset.
Needs access to the following files:
synonyms: Contains several synonyms for each word in the question/caption.
caption templates: List of caption templates.
question templates: List of question templates.
metainfo: Meta-information related to attributes and values of CLEVR objects.
Usage:
python -u generate_dataset.py \
--scene_path="data/scenes/CLEVR_train_scenes.json" \
--num_beams=100 \
--num_workers=12 \
--save_path="data/clevr_train_raw.json"
Author: Satwik Kottur
"""
import copy
import collections
import json
import multiprocessing
import os
import random
import re
import time
from absl import flags
from absl import app
import numpy as np
from tqdm import tqdm as progressbar
import clevr_utils as utils
import global_vars as gvars
# import constraints_splitB as constraints
import constraints
# Command-line flags (absl). The path defaults are absolute paths under
# /projects/abdessaied/; override them on the command line when running from
# another location.
FLAGS = flags.FLAGS
# Template and meta-information inputs.
flags.DEFINE_string('synonym_path', '/projects/abdessaied/clevr-dialog/templates/synonyms.json',
'Path to synonyms file')
flags.DEFINE_string('metainfo_path', '/projects/abdessaied/clevr-dialog/templates/metainfo.json',
'Path to meta information file')
flags.DEFINE_string('caption_template_root', '/projects/abdessaied/clevr-dialog/templates/captions/',
'Root to folder with caption templates')
flags.DEFINE_string('question_template_root', '/projects/abdessaied/clevr-dialog/templates/questions/',
'Root to folder with question templates')
# Scene input; the commented-out line is an alternative default kept by the
# author.
flags.DEFINE_string('scene_path',
# '/projects/abdessaied/clevr-dialog/output/result_clevr_oroginal_test.json',
'/projects/abdessaied/clevr-dataset-gen/output_finetune_20_objs_with_masks_many_attr/CLEVR_scenes.json',
'Path to CLEVR scene path json file')
# Optional file with one image index per line; when non-empty, main() only
# generates dialogs for those scenes.
flags.DEFINE_string('scene_id_file', '',
'Path to specific CLEVR scene ids to generate dialogs')
# Output location of the generated dialogs (a single JSON file).
flags.DEFINE_string('save_path', '/projects/abdessaied/clevr-dialog/output/raw_data_modified/dialogs_finetune_20_objects_10_rounds.json',
'Path to save the dataset json')
# Search and sizing parameters.
flags.DEFINE_integer('num_beams', 100, 'Number of beams in dialog search')
flags.DEFINE_integer('num_workers', 64, 'Number of workers to use in search')
flags.DEFINE_integer('captions_per_image', 5, 'Number of captions per image')
flags.DEFINE_integer('num_images', -1,
'Number of images to generate dialogs. -1 for all.')
flags.DEFINE_integer('num_rounds', 10, 'Number of rounds in each dialog')
# Number of beams and distribution of question types.
#
# `ranges` is keyed by round id. Each entry maps a question type to an
# inclusive [min, max] count of questions of that type asked so far;
# generate_dialog_bfs drops every partial dialog whose counts fall outside
# these limits. Rounds without an entry are not pruned. For 'exist+count'
# only the upper bound is checked (sum of 'exist' and 'count' questions).
#
# Original heuristic notes, kept for reference:
# Start cutting down beams after 5th round.
# Heuristics (for round 4):
# A. count <= 2 1 <= seek <= 3 exist <= 2
# B. count + exist <= 3
# C. Independent questions <= 1
# Heuristics (for round 5):
# A. count <= 2 2 <= seek <= 4 exist <= 2
# B. count + exist <= 3
# C. Independent questions <= 1
# NOTE(review): the numbers in the notes above do not match the table below
# (e.g. the entry for key 4 allows 2 <= seek <= 4 and exist <= 1); the table
# is what the code actually uses.
ranges = {3: {'indep': [0, 1], 'seek': [1, 4], 'exist': [0, 1],
'count': [0, 1], 'exist+count': [0, 2]},
4: {'indep': [0, 1], 'seek': [2, 4], 'exist': [0, 1],
'count': [0, 1], 'exist+count': [0, 2]},
5: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
6: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
7: {'indep': [0, 2], 'seek': [3, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
8: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
'count': [0, 3], 'exist+count': [0, 3]},
9: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
'count': [0, 3], 'exist+count': [0, 4]}}
def mapping(tag):
    """Looks up the attribute name a template tag stands for.

    Every '1' character is dropped from the tag before the lookup, so a
    group-1 tag such as '<S1>' resolves through the same entry as '<S>'.

    Args:
      tag: Template tag string

    Returns:
      The entry of gvars.METAINFO['tag_map'] for the stripped tag
    """
    tag_map = gvars.METAINFO['tag_map']
    base_tag = tag.replace('1', '')
    return tag_map[base_tag]
def inv_mapping(attribute, arg_id=0):
    """Builds the template tag for an attribute (inverse of `mapping`).

    Args:
      attribute: Name of the attribute
      arg_id: Argument id; when positive it is inserted just before the last
        character of the tag (e.g. '<S>' becomes '<S1>' for arg_id=1)

    Returns:
      The tag string from gvars.METAINFO['tag_inv_map'], with the argument id
      spliced in when arg_id > 0
    """
    tag = gvars.METAINFO['tag_inv_map'][attribute]
    if arg_id > 0:
        head, closing = tag[:-1], tag[-1]
        return head + str(arg_id) + closing
    return tag
def get_tag_group(tag):
    """Extracts the argument-group id encoded in a tag string.

    Tags of at most three characters (e.g. '<S>') belong to group 0. For
    longer tags the second-to-last character is parsed as the group id
    (e.g. '<S1>' -> 1), so only single-digit ids are supported.

    Args:
      tag: Tag string

    Returns:
      group_id: Integer group id
    """
    if len(tag) > 3:
        return int(tag[-2])
    return 0
def replace_attribute(text, tag, obj_group, eliminate=False):
    """Substitutes one attribute tag in a text template.

    The replacement depends on the attribute the tag maps to:
      * relation: '<R>' gets a random phrase from
        gvars.METAINFO['relation_phrases'] for obj_group['relation']; any
        other relation tag gets obj_group['relation'] verbatim. `eliminate`
        is ignored for relation tags.
      * shape: the object's shape, or 'thing' when eliminated; an 's' is
        appended when obj_group has count > 1 or 'use_plural' set.
      * count: obj_group['count'], or '' when eliminated.
      * anything else: the object's attribute value, or '' when eliminated.

    Args:
      text: The text template to perform replacement on
      tag: The tag to replace in the text
      obj_group: Sampled arguments; its 'objects' list is indexed by the
        group id of the tag
      eliminate: Drop the attribute instead of realizing its value

    Returns:
      replaced_text: The text with every occurrence of tag replaced
    """
    group = get_tag_group(tag)
    attribute = mapping(tag)

    if attribute == 'relation':
        relation = obj_group['relation']
        if tag == '<R>':
            # Actual relation tag: pick one of its surface phrases.
            phrases = gvars.METAINFO['relation_phrases'][relation]
            relation = random.choice(phrases)
        return text.replace(tag, relation)

    if attribute == 'count':
        replacer = '' if eliminate else str(obj_group['count'])
    elif eliminate:
        # An eliminated shape still needs a noun; other attributes vanish.
        replacer = 'thing' if attribute == 'shape' else ''
    else:
        replacer = str(obj_group['objects'][group][attribute])

    if attribute == 'shape':
        # Plural forms for groups.
        plural = (obj_group.get('count', 1) > 1 or
                  obj_group.get('use_plural', False))
        if plural:
            replacer += 's'

    return text.replace(tag, replacer)
def realize_text_and_extract_scene(scene, template, filter_objs):
    """Realizes a caption from a template and builds its initial scene graph.

    One text variant of the template and one candidate from filter_objs are
    sampled at random. Tags are grouped by argument id; tags whose attribute
    is not available for the sampled object are eliminated, all required tags
    are kept, and a random number of optional tags is kept according to
    gvars.METAINFO['probabilities'].

    NOTE: this function mutates its inputs: attributes that were not realized
    are deleted from the sampled candidate's graph objects, and the label of a
    'count-att-no' template is rewritten in place (see below).

    Args:
      scene: Current scene graph (not read by this function)
      template: Caption template to realize
      filter_objs: Candidates satisfying the constraints of the template

    Returns:
      sample: Dict with the realized 'caption', the 'template' label, the
        'template_info' list (template copy without 'text', plus the 'index'
        of the sampled text), the collected 'args', an empty 'dialog' list
        and the merged 'graph'
    """
    def default_list():
        return collections.defaultdict(list)

    graph = {'relationships': collections.defaultdict(default_list),
             'counts': {}, 'exists': {}, 'history': [], 'objects': {}}

    # Number of inputs.
    n_inputs = template.get('inputs', 1)
    # Sample a text template.
    text_sample = random.choice(template['text'])
    text_sample_index = template['text'].index(text_sample)

    # Extract attribute tags and get them into groups. The pattern is a raw
    # string so that \d and \w are not treated as (invalid) string escapes.
    tags = re.findall(r'(<[\d\w]*>)', text_sample)
    tag_groups = collections.defaultdict(list)
    for tag in tags:
        tag_groups[get_tag_group(tag)].append(tag)

    # Sample a random element from filtered.
    arg_sample = random.choice(filter_objs)
    # Scene information obtained from the current round.
    graph_item = arg_sample['graph']

    for arg_ind in range(n_inputs):
        obj_sample = arg_sample['objects'][arg_ind]
        avail_attrs = obj_sample['optional'] + obj_sample['required']

        # Remove tags from text not allowed by filter_objs. Iterate over a
        # reversed copy because the group list is modified in the loop.
        for ii in tag_groups[arg_ind][::-1]:
            if mapping(ii) not in avail_attrs:
                tag_groups[arg_ind].remove(ii)
                text_sample = replace_attribute(
                    text_sample, ii, arg_sample, True)

        # Assert that all required attributes are present as tags.
        for attribute in obj_sample['required']:
            required_tag = inv_mapping(attribute, arg_ind)
            if required_tag not in tag_groups[arg_ind]:
                print("required_tag: {}".format(required_tag))
                print("template: {}".format(template))
            assert required_tag in tag_groups[arg_ind], \
                'A required attribute is missing in template!'

        # Start compiling tags to keep.
        tags_to_keep = [inv_mapping(ii, arg_ind)
                        for ii in obj_sample['required']]
        # Filter out those not present in text template.
        optional_tags = [inv_mapping(ii, arg_ind)
                         for ii in obj_sample['optional']]
        optional_tags = [
            ii for ii in optional_tags if ii in tag_groups[arg_ind]]

        # Pick how many optional tags to keep: 0/1/2 when required tags are
        # already kept, otherwise 1/2/3 so that something is kept.
        if len(optional_tags) > 0:
            if len(tags_to_keep) > 0:
                n_tags_sample = [0, 1, 2]
            else:
                n_tags_sample = [1, 2, 3]
            n_sample = np.random.choice(n_tags_sample, 1,
                                        p=gvars.METAINFO['probabilities'],
                                        replace=False)
            # Cap at the number of optional tags available.
            n_sample = min(n_sample[0], len(optional_tags))
            if n_sample > 0:
                tags_to_keep += random.sample(optional_tags, n_sample)

        # Realize kept tags with attribute values, eliminate the others.
        for tag in tag_groups[arg_ind]:
            remove = tag not in tags_to_keep
            text_sample = replace_attribute(
                text_sample, tag, arg_sample, remove)

            # Remove attributes from objects not included in tags_to_keep.
            # The pruning is idempotent, so repeating it per tag is harmless.
            # NOTE(review): the indentation of this file was lost; confirm
            # that this pruning belongs inside the tag loop (it is skipped
            # when the argument has no tags left) and not after it.
            if 'objects' in graph_item:
                for ii in gvars.METAINFO['attributes']:
                    if inv_mapping(ii, arg_ind) not in tags_to_keep:
                        if ii in graph_item['objects'][arg_ind]:
                            del graph_item['objects'][arg_ind][ii]

    # Collect the argument values of the caption, without duplicates.
    args = []
    for obj in arg_sample['objects']:
        if obj is None:
            continue
        for k in obj['required']:
            arg = obj.get(k, None)
            if arg is not None:
                if arg not in args:
                    args.append(arg)
            else:
                # Fall back to group-level values; only strings are recorded.
                arg = arg_sample.get(k, None)
                if (arg is not None and arg not in args
                        and isinstance(arg, str)):
                    args.append(arg)
        # NOTE(review): placed after the loop over required keys; confirm
        # against the upstream layout.
        arg = obj.get('attribute', None)
        if arg is not None and arg not in args:
            args.append(arg)
    if template['label'] == 'obj-relation':
        args.append(arg_sample['relation'])

    # NOTE(review): this renames the label on the caller's template object,
    # not on a copy, so the change persists for every later use of the same
    # template in this process. Confirm this is intended.
    if template['label'] == "count-att-no":
        template['label'] = "count-att"

    # Record the caption info.
    graph_item['round'] = 0
    template_info = copy.deepcopy(template)
    del template_info['text']
    template_info['index'] = text_sample_index
    sample = {
        'template_info': [template_info],
        'args': args,
        'caption': text_sample,
        'template': template['label'],
        'dialog': [],
    }

    # Append history, update scene graph, and save the new scene graph.
    graph['history'].append(graph_item)
    sample['graph'] = utils.merge_update_scene_graph(graph, graph_item)
    return sample
def realize_question(dialog, template, filter_objs):
    """Realizes a question from a template and appends it to a dialog.

    One text variant of the template and one candidate from filter_objs are
    sampled at random. Tags whose attribute is not available for the sampled
    object are eliminated, all required tags are kept, and a random number of
    optional tags is kept according to gvars.METAINFO['probabilities'].

    The dialog is updated in place: a new turn is appended to
    dialog['dialog'], the template (without 'text', plus the 'index' of the
    sampled text) is appended to dialog['template_info'], and the candidate's
    graph item is merged into dialog['graph'] and appended to its history.
    Objects flagged 'volatile' are removed from the graph item after merging.

    Args:
      dialog: Dialog generated so far (caption, previous turns, scene graph)
      template: Question template to realize
      filter_objs: Candidates satisfying the constraints of the template

    Returns:
      dialog: The same dialog object, extended by one question/answer turn
    """
    # Number of inputs.
    n_inputs = template.get('inputs', 0)
    # Sample a text template.
    text_sample = random.choice(template['text'])
    text_sample_index = template['text'].index(text_sample)

    # Extract attribute tags and get them into groups. The pattern is a raw
    # string so that \d and \w are not treated as (invalid) string escapes.
    tags = re.findall(r'(<[\d\w]*>)', text_sample)
    tag_groups = collections.defaultdict(list)
    for tag in tags:
        tag_groups[get_tag_group(tag)].append(tag)

    # Sample a random element from filtered.
    arg_sample = random.choice(filter_objs)

    for arg_ind in range(n_inputs):
        obj_sample = arg_sample['objects'][arg_ind]
        avail_attrs = obj_sample['optional'] + obj_sample['required']

        # Remove tags from text not allowed by filter_objs. Iterate over a
        # reversed copy because the group list is modified in the loop.
        for ii in tag_groups[arg_ind][::-1]:
            if mapping(ii) not in avail_attrs:
                tag_groups[arg_ind].remove(ii)
                text_sample = replace_attribute(
                    text_sample, ii, arg_sample, True)

        # Assert that all required attributes are present as tags.
        for attribute in obj_sample['required']:
            required_tag = inv_mapping(attribute, arg_ind)
            # Make an exception for <R> and <P>: a position tag can stand in
            # for the relation tag.
            if required_tag == '<R>' and '<P>' in tag_groups[arg_ind]:
                continue
            assert required_tag in tag_groups[arg_ind], \
                'A required attribute is missing in template!'

        # Start compiling tags to keep.
        tags_to_keep = [inv_mapping(ii, arg_ind)
                        for ii in obj_sample['required']]
        # Filter out those not present in text template.
        optional_tags = [inv_mapping(ii, arg_ind)
                         for ii in obj_sample['optional']]
        optional_tags = [
            ii for ii in optional_tags if ii in tag_groups[arg_ind]]

        # Pick how many optional tags to keep: 0/1/2 when required tags are
        # already kept, otherwise 1/2/3 so that something is kept.
        if len(optional_tags) > 0:
            if len(tags_to_keep) > 0:
                n_tags_sample = [0, 1, 2]
            else:
                n_tags_sample = [1, 2, 3]
            n_sample = np.random.choice(n_tags_sample, 1,
                                        p=gvars.METAINFO['probabilities'],
                                        replace=False)
            # Cap at the number of optional tags available.
            n_sample = min(n_sample[0], len(optional_tags))
            if n_sample > 0:
                tags_to_keep += random.sample(optional_tags, n_sample)

        # Realize kept tags with attribute values, eliminate the others.
        for tag in tag_groups[arg_ind]:
            remove = tag not in tags_to_keep
            text_sample = replace_attribute(
                text_sample, tag, arg_sample, remove)

    # Collect the argument values of the question.
    args = []
    for obj in arg_sample['objects']:
        if obj is None:
            continue
        for k in obj['required']:
            arg = obj.get(k, None)
            if arg is not None:
                if arg not in args:
                    args.append(arg)
            else:
                # Fall back to group-level values.
                arg = arg_sample.get(k, None)
                if arg is not None:
                    args.append(arg)
        # NOTE(review): the indentation of this file was lost; this lookup is
        # placed after the loop over required keys. Confirm against upstream.
        arg = obj.get('attribute', None)
        if arg is not None and arg not in args:
            args.append(arg)

    # Record info and merge scene graphs.
    dialog_datum = {'question': text_sample, 'answer': arg_sample['answer'],
                    'template': template['label'], 'args': args}
    # Shallow copy: only the top-level 'text' key is removed from the copy.
    template_info = template.copy()
    del template_info['text']
    template_info['index'] = text_sample_index
    dialog['template_info'].append(template_info)
    dialog['dialog'].append(dialog_datum)

    graph_item = arg_sample['graph']
    # If mergeable, add it to the objects list.
    dialog['graph'] = utils.merge_update_scene_graph(
        dialog['graph'], graph_item)

    # If there are volatile objects in the graph item, remove them. Iterate
    # over a reversed copy because the list is modified in the loop.
    for obj in graph_item['objects'][::-1]:
        if obj.get('volatile', False):
            graph_item['objects'].remove(obj)
    dialog['graph']['history'].append(graph_item)
    return dialog
def clean_text_subroutine(text, thing, suffix):
    """Turns a realized template into a readable sentence.

    Runs synonym/optional-phrase substitution, strips every occurrence of the
    suffix, collapses repeated spaces, swaps in the chosen noun, capitalizes
    the first character and re-appends the suffix.

    Args:
      text: Text string to be cleaned
      thing: Noun to use, either 'thing' or 'object'
      suffix: Either '?' (question) or '.' (caption)

    Returns:
      clean_text: Text string after the cleaning procedure
    """
    # Synonyms + skipping optional parts of the sentence.
    sentence = skip_and_replace_phrases(text)

    # Drop the suffix and surrounding spaces, then squeeze runs of spaces.
    sentence = sentence.replace(suffix, '').strip(' ')
    sentence = re.sub(' +', ' ', sentence)

    # 'a thing' must become 'an object' before the generic substitution.
    if thing == 'object':
        sentence = sentence.replace('a thing', 'an object')
    sentence = sentence.replace('thing', thing)

    first, rest = sentence[0], sentence[1:]
    return first.upper() + rest + suffix
def clean_dialog_text(dialogs):
    """Cleans the caption and every question of the given dialogs in place.

    The noun ('thing' or 'object') is drawn once, with probability 0.5 each,
    and used for all dialogs passed in.

    Args:
      dialogs: Generated dialogs to perform text cleaning on

    Returns:
      dialogs: The same list, with captions and questions cleaned in place
    """
    thing = 'thing' if random.random() > 0.5 else 'object'

    for dialog_datum in dialogs:
        # Clean the caption.
        dialog_datum['caption'] = clean_text_subroutine(
            dialog_datum['caption'], thing, '.')
        # Clean the questions.
        for turn in dialog_datum['dialog']:
            turn['question'] = clean_text_subroutine(
                turn['question'], thing, '?')
    return dialogs
def skip_and_replace_phrases(text):
    """Substitutes synonyms and skips optional parts stochastically.

    Each bracketed phrase '[...]' (word characters and spaces only) is either
    removed or kept without its brackets, with probability 0.5 each. The text
    is then lower-cased and every key of gvars.METAINFO['synonym_keys'] that
    occurs in it is replaced by one of its synonyms chosen uniformly.

    Args:
      text: Text string

    Returns:
      text: Text string with synonyms replaced and optional parts skipped
    """
    # Raw string: \[, \w and \] must reach the regex engine untouched instead
    # of being parsed as (invalid) string escapes.
    matches = re.findall(r'(\[[ \w]*\])', text)
    for match in matches:
        if random.uniform(0, 1) > 0.5:
            text = text.replace(match, '')
        else:
            text = text.replace(match, match[1:-1])

    # Remove empty spaces, if any.
    text = re.sub(' +', ' ', text)

    # Search for synonyms, replace at uniformly random. Keys are matched as
    # plain substrings, in the order stored in 'synonym_keys'.
    text = text.lower()
    for key, values in gvars.METAINFO['synonym_keys']:
        if key in text:
            text = text.replace(key, random.choice(values))
    return text
def generate_captions(scenes, templates):
    """Wrapper generates captions.

    Args:
      scenes: Dict whose 'scenes' entry lists the scene graphs to caption
      templates: List of available caption templates

    Returns:
      generated_content: One dict per scene with 'image_index', 'split',
        'image_filename' and the generated caption samples under 'dialogs'
    """
    template_dictionary = {ii['label']: ii for ii in templates}

    # num_images == -1 means "all scenes" (same convention as in main). A
    # plain [0:-1] slice would silently drop the last scene.
    selected_scenes = scenes['scenes']
    if FLAGS.num_images != -1:
        selected_scenes = selected_scenes[0:FLAGS.num_images]

    generated_content = []
    for scene in selected_scenes:
        # Copy over image_index, split, image_filename from scene.
        content = {key: scene[key]
                   for key in ['image_index', 'split', 'image_filename']}
        content['dialogs'] = []

        # Filter objects based on constraints.
        filter_objs = constraints.caption(scene, templates)
        for filter_obj in filter_objs:
            # Realize the text, and return the partial scene knowledge (q).
            template = template_dictionary[filter_obj[0]['graph']['template']]
            sample = realize_text_and_extract_scene(
                scene, template, filter_obj)
            # Add it to the list of dialogs.
            content['dialogs'].append(sample)
        generated_content.append(content)
    return generated_content
def generate_questions(scenes, dialogs, templates, params):
    """Wrapper extends every caption dialog by one question.

    For each dialog, templates are drawn at random until one yields at least
    one candidate, giving up after 11 draws. Dialogs for which no template
    fits are dropped from the output.

    Args:
      scenes: Dict whose 'scenes' list is aligned with `dialogs`
      dialogs: Contains already generated captions for scenes graphs
      templates: List of available question templates
      params: Beam search parameters for question generation (unused here)

    Returns:
      new_dialogs: Copies of the input entries whose 'dialogs' lists hold the
        dialogs extended by one question
    """
    max_attempts = 11
    new_dialogs = []
    for scene_id, dialog_datum in enumerate(dialogs):
        image_dialogs = copy.deepcopy(dialog_datum)
        extended = []
        image_dialogs['dialogs'] = extended

        for dialog in dialog_datum['dialogs']:
            found = False
            for _ in range(max_attempts):
                # Pick a template at random.
                template = random.choice(templates)
                # Filter objects based on constraints.
                filter_objs = constraints.question(scenes['scenes'][scene_id],
                                                   dialog, template)
                found = len(filter_objs) != 0
                if found:
                    break

            # Extreme case -- no template fits, skip this dialog.
            if not found:
                continue

            # Realize q question on a private copy of the dialog.
            gen_dialog = realize_question(
                copy.deepcopy(dialog), template, filter_objs)
            extended.append(copy.deepcopy(gen_dialog))
        new_dialogs.append(image_dialogs)
    return new_dialogs
def worker(scenes, cap_templates, ques_templates, worker_id, out_q):
    """Worker method generates dialogs (caption + questions) for pool of scenes.

    Scenes whose generation fails are reported on stdout and skipped, so one
    bad scene does not lose the rest of the worker's output.

    Args:
      scenes: List of CLEVR scenes to generate dialogs
      cap_templates: Templates for caption generation
      ques_templates: Templates for question generation
      worker_id: Id for the current worker
      out_q: Output queue to save generated dialogs from different sources

    Returns:
      Nothing; puts {worker_id: dialogs} on the output queue.
    """
    dialogs = []
    for index, scene in enumerate(scenes):
        cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
        print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
              (cur_time, worker_id, index, len(scenes), scene['image_index']))
        try:
            gen_dialog = generate_dialog_bfs(
                scene, cap_templates, ques_templates)
            # Round-trip through JSON so that only serializable content is
            # sent back to the parent process.
            dialogs.append(json.loads(json.dumps(gen_dialog)))
        except Exception as error:
            # Catch Exception rather than everything, so KeyboardInterrupt
            # and SystemExit still stop the worker, and report the cause.
            print('NOTE: Missing data for %d (%s: %s)' %
                  (scene['image_index'], type(error).__name__, error))
    out_q.put({worker_id: dialogs})
def generate_dialog_bfs(scene, cap_templates, ques_templates):
    """Perform approximate breadth-first-search (BFS) to generate dialogs.

    A caption is realized for every candidate returned by
    constraints.caption. Each caption is then extended round by round: every
    dialog in the pool is extended with one template per question group, the
    pool is pruned with the per-round limits in `ranges` and cut down to
    FLAGS.num_beams. Finally one dialog per caption is sampled and the set is
    topped up to FLAGS.captions_per_image from randomly chosen captions.

    Args:
      scene: Scene graph for the CLEVR image
      cap_templates: List of caption templates
      ques_templates: List of question templates

    Returns:
      bundle: Dict with 'image_index', 'split', 'image_filename' and the
        sampled 'dialogs' for the input scene graph

    Raises:
      RuntimeError: If fewer dialogs than FLAGS.captions_per_image can be
        generated for the scene.
    """
    # Copy over image_index, split, image_filename from scene.
    bundle = {key: scene[key]
              for key in ['image_index', 'split', 'image_filename']}

    # Generate captions for the scene, one pool per caption label.
    template_dictionary = {ii['label']: ii for ii in cap_templates}
    content = {}
    # Filter objects based on constraints on captions.
    filter_objs = constraints.caption(scene, cap_templates)
    for filter_obj in filter_objs:
        # Realize the text, and return the partial scene knowledge (q).
        template = template_dictionary[filter_obj[0]['graph']['template']]
        sample = realize_text_and_extract_scene(scene, template, filter_obj)
        content[template['label']] = [sample]

    # Now generate questions.
    # Group templates, exist/count of similar type together: dropping the
    # first four characters maps 'exist-X' and 'count-X' to the same key.
    ques_groups = collections.defaultdict(list)
    for ques_template in ques_templates:
        label = ques_template['label']
        if 'exist' in label or 'count' in label:
            ques_groups[label[4:]].append(ques_template)
        else:
            ques_groups[label].append(ques_template)

    for round_id in range(FLAGS.num_rounds):
        new_content = {}
        # For each caption group.
        for cap_label, cap_dialogs in content.items():
            cur_pool = []
            for dialog_datum in cap_dialogs:
                for group in ques_groups.values():
                    template = random.choice(group)
                    # Work on a copy; realize_question mutates its dialog.
                    datum_copy = copy.deepcopy(dialog_datum)
                    # Filter objects based on constraints.
                    filter_objs = constraints.question(
                        scene, datum_copy, template)
                    if len(filter_objs) == 0:
                        continue
                    # Realize q question.
                    gen_dialog = realize_question(
                        datum_copy, template, filter_objs)
                    cur_pool.append(gen_dialog)

            if round_id in ranges:
                kept_pool = []
                for dialog in cur_pool:
                    n_types = {'indep': 0, 'seek': 0, 'exist': 0, 'count': 0}
                    keep_dialog = True
                    # Skip the caption entry at position 0.
                    ques_labels = [ii['label']
                                   for ii in dialog['template_info'][1:]]
                    for label in ques_labels:
                        if label in gvars.METAINFO['independent_questions']:
                            n_types['indep'] += 1
                        label_type = label.split('-')[0]
                        n_types[label_type] += 1

                    # Heuristic A, C
                    for q_type, count in n_types.items():
                        limit = ranges[round_id][q_type]
                        if limit[0] > count or count > limit[1]:
                            keep_dialog = False
                            break

                    # Heuristic B
                    limit = ranges[round_id]['exist+count']
                    if n_types['count'] + n_types['exist'] > limit[1]:
                        keep_dialog = False

                    if keep_dialog:
                        kept_pool.append(dialog)
                cur_pool = kept_pool

            # Keep limited number of beams (for speed).
            if len(cur_pool) > FLAGS.num_beams:
                cur_pool = sample_beams(cur_pool)[:FLAGS.num_beams]
            new_content[cap_label] = cur_pool
        content = copy.deepcopy(new_content)

    # Rank the dialogs of every caption (sim, imm2, early questions) and
    # drop captions that ended up without any dialog.
    content = {cap_label: sample_beams(cap_dialogs)
               for cap_label, cap_dialogs in content.items()}
    content = {cap_label: cap_dialogs
               for cap_label, cap_dialogs in content.items()
               if len(cap_dialogs) > 0}

    # For each caption, sample one.
    # NOTE(review): sample_beams orders the list best-first, so pop() takes
    # the lowest-ranked dialog of each caption; confirm this is intended.
    sampled_dialogs = [cap_dialogs.pop() for cap_dialogs in content.values()]

    # Top up to the requested number per image, taking from randomly chosen
    # captions. The pick must come from the chosen label's own pool (not from
    # whichever pool a previous loop visited last), and labels whose pool is
    # exhausted can no longer be chosen.
    available = [cap_label for cap_label, cap_dialogs in content.items()
                 if len(cap_dialogs) > 0]
    while len(sampled_dialogs) < FLAGS.captions_per_image:
        if not available:
            raise RuntimeError(
                'Only %d of %d dialogs could be generated for scene %s' %
                (len(sampled_dialogs), FLAGS.captions_per_image,
                 scene['image_index']))
        random_label = random.choice(available)
        sampled_dialogs.append(content[random_label].pop())
        if len(content[random_label]) == 0:
            available.remove(random_label)

    # Finally, make the dialog text readable.
    sampled_dialogs = clean_dialog_text(sampled_dialogs)
    # Generate the coreference chain.
    bundle['dialogs'] = [identify_coref_chains(dialog)
                         for dialog in sampled_dialogs]
    return bundle
def sample_beams(dialogs):
    """Orders dialogs by the number of soft constraints they satisfy.

    The score of a dialog is computed from its question template labels (the
    first 'template_info' entry, the caption, is skipped): one point per
    label containing 'imm2', one per label containing 'sim', and one per
    label containing 'early', the latter capped at 4. Ties are broken at
    random.

    Args:
      dialogs: Generated dialogs to sample beams

    Returns:
      sampled_dialogs: All input dialogs, highest score first
    """
    scored = []
    for position, dialog in enumerate(dialogs):
        labels = [info['label'] for info in dialog['template_info'][1:]]
        imm2_hits = sum(1 for label in labels if 'imm2' in label)
        sim_hits = sum(1 for label in labels if 'sim' in label)
        early_hits = sum(1 for label in labels if 'early' in label)
        # 'early' questions stop counting after the fourth one.
        score = imm2_hits + sim_hits + min(4, early_hits)
        scored.append((score, position))

    # Highest score first; the random component breaks ties.
    ranking = sorted(scored,
                     key=lambda entry: (entry[0], random.random()),
                     reverse=True)
    return [dialogs[position] for _, position in ranking]
def identify_coref_chains(dialog):
    """Annotates each round's graph history entry with its 'dependence'.

    Round r_id is stored at history position r_id + 1 (position 0 is the
    caption). The annotation depends on the question label:
      * independent questions: None
      * the three '*-group' labels and labels containing 'imm': r_id - 1
      * labels containing 'early': hist_id - 1 for the last earlier history
        entry that holds the focus object together with the focus attribute
    Other labels are left without a 'dependence' entry.

    Args:
      dialog: Generated dialog for which coreference chains are identified

    Returns:
      dialog: The same dialog object, annotated in place
    """
    group_labels = ('exist-attribute-group', 'count-attribute-group',
                    'count-all-group')
    history = dialog['graph']['history'] if dialog['dialog'] else None

    for r_id, datum in enumerate(dialog['dialog']):
        label = datum['template']
        current = r_id + 1

        if label in gvars.METAINFO['independent_questions']:
            history[current]['dependence'] = None
        elif label in group_labels or 'imm' in label:
            # Refers to what the immediately preceding round talked about.
            history[current]['dependence'] = r_id - 1
        elif 'early' in label:
            cur_history = history[current]
            assert 'focus_id' in cur_history and 'focus_desc' in cur_history,\
                'More focus objects than one, no focus objects!'
            focus_id = cur_history['focus_id']

            # First known attribute mentioned in the focus description; the
            # loop variable keeps the last attribute if none is mentioned.
            for attr in gvars.METAINFO['attributes']:
                if attr in cur_history['focus_desc']:
                    break

            # Go over previous history; later matches overwrite earlier ones.
            for hist_id, hist_datum in enumerate(history[:current]):
                mentioned = any(obj['id'] == focus_id and attr in obj
                                for obj in hist_datum['objects'])
                if mentioned:
                    cur_history['dependence'] = hist_id - 1
    return dialog
def main(unused_argv):
    """Main method generates the CLEVR-Dialog dataset.

    Loads scenes, synonyms, meta-information and templates from the paths
    given by the flags, selects the scenes to process, generates dialogs for
    them (in-process when num_workers is 1, otherwise with one process per
    worker) and writes the result as JSON to FLAGS.save_path.

    Args:
      unused_argv: Positional command-line arguments passed by absl (ignored)
    """
    def read_json(path):
        with open(path, 'r') as file_id:
            return json.load(file_id)

    def read_templates(root):
        # Every file in the folder is expected to hold a JSON list.
        collected = []
        for file_name in os.listdir(root):
            collected.extend(read_json(os.path.join(root, file_name)))
        return collected

    # Read the scene, synonyms and metainformation files.
    scenes = read_json(FLAGS.scene_path)
    synonyms = read_json(FLAGS.synonym_path)
    gvars.METAINFO = read_json(FLAGS.metainfo_path)

    # Inverse tag map (attribute -> tag); '<P>' is left out of it.
    gvars.METAINFO['tag_inv_map'] = {
        attr: tag for tag, attr in gvars.METAINFO['tag_map'].items()
        if tag != '<P>'}
    # Synonym keys with more words come first.
    gvars.METAINFO['synonym_keys'] = sorted(
        synonyms.items(),
        key=lambda item: len(item[0].split(' ')),
        reverse=True)

    # Add ids to objects.
    scenes = utils.add_object_ids(scenes)
    scenes = utils.clean_object_attributes(scenes)

    # Read the caption and question templates.
    cap_templates = read_templates(FLAGS.caption_template_root)
    ques_templates = read_templates(FLAGS.question_template_root)

    # Select scenes: an explicit id file wins, then num_images (-1 = all).
    if FLAGS.scene_id_file != '':
        with open(FLAGS.scene_id_file, 'r') as file_id:
            missing_ids = [int(line.strip('\n')) for line in file_id]
        print('Dialogs missing for scenes: %d' % len(missing_ids))
        # image_index -> position in the scenes list.
        position_of = {entry['image_index']: position
                       for position, entry in enumerate(scenes['scenes'])}
        scenes_subset = [scenes['scenes'][position_of[scene_id]]
                         for scene_id in missing_ids]
    elif FLAGS.num_images == -1:
        scenes_subset = scenes['scenes']
    else:
        scenes_subset = scenes['scenes'][0: FLAGS.num_images]

    # BFS for each scene.
    if FLAGS.num_workers == 1:
        # Single process version; errors are not caught here.
        dialogs = []
        total = len(scenes_subset)
        for index, scene in enumerate(scenes_subset):
            cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
            print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
                  (cur_time, 0, index, total, scene['image_index']))
            dialogs.append(
                generate_dialog_bfs(scene, cap_templates, ques_templates))
    else:
        # Multiprocess version: worker i gets every num_workers-th scene.
        output_q = multiprocessing.Queue()
        jobs = []
        for worker_id in range(FLAGS.num_workers):
            allotment = scenes_subset[worker_id::FLAGS.num_workers]
            process = multiprocessing.Process(
                target=worker,
                args=(allotment, cap_templates, ques_templates,
                      worker_id, output_q))
            jobs.append(process)
            process.start()

        # Collect one result per job before joining the processes.
        final_results = {}
        for _ in jobs:
            final_results.update(output_q.get())
        for job in jobs:
            job.join()

        # Flatten and sort.
        flattened = []
        for worker_dialogs in final_results.values():
            flattened.extend(worker_dialogs)
        dialogs = sorted(flattened, key=lambda entry: entry['image_index'])

    # Save the dialogs.
    print('Saving dialog at: %s' % FLAGS.save_path)
    with open(FLAGS.save_path, 'w') as file_id:
        json.dump(dialogs, file_id)
# Script entry point: call gvars.initialize() first, then let absl parse the
# command-line flags and invoke main().
if __name__ == '__main__':
    gvars.initialize()
    app.run(main)