-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathvideo_captioning_from_summarizer_mapper.py
261 lines (232 loc) · 11.3 KB
/
video_captioning_from_summarizer_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import copy
from typing import Dict, Optional
from pydantic import PositiveInt
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import AUTOINSTALL
from data_juicer.utils.mm_utils import SpecialTokens, remove_special_tokens
from data_juicer.utils.model_utils import get_model, prepare_model
from ..base_op import OPERATORS, Mapper
# Name under which the operator class below is registered in ``OPERATORS``.
NAME = 'video_captioning_from_summarizer_mapper'
@OPERATORS.register_module(NAME)
class VideoCaptioningFromSummarizerMapper(Mapper):
    """
    Mapper to generate video captions by summarizing several kinds of generated
    texts (captions from video/audio/frames, tags from audio/frames, ...)
    """

    _accelerator = 'cuda'
    _batched_op = True

    def __init__(self,
                 hf_summarizer: Optional[str] = None,
                 trust_remote_code: bool = False,
                 consider_video_caption_from_video: bool = True,
                 consider_video_caption_from_audio: bool = True,
                 consider_video_caption_from_frames: bool = True,
                 consider_video_tags_from_audio: bool = True,
                 consider_video_tags_from_frames: bool = True,
                 vid_cap_from_vid_args: Optional[Dict] = None,
                 vid_cap_from_frm_args: Optional[Dict] = None,
                 vid_tag_from_aud_args: Optional[Dict] = None,
                 vid_tag_from_frm_args: Optional[Dict] = None,
                 keep_tag_num: PositiveInt = 5,
                 keep_original_sample: bool = True,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_summarizer: the summarizer model used to summarize texts
            generated by other methods. If it's None or empty, a default
            flan-t5 based summarizer is used.
        :param trust_remote_code: passed through to the model preparation
            call when loading the summarizer. Default: False.
        :param consider_video_caption_from_video: whether to consider the video
            caption generated from video directly in the summarization process.
            Default: True.
        :param consider_video_caption_from_audio: whether to consider the video
            caption generated from audio streams in the video in the
            summarization process. Default: True.
        :param consider_video_caption_from_frames: whether to consider the
            video caption generated from sampled frames from the video in the
            summarization process. Default: True.
        :param consider_video_tags_from_audio: whether to consider the video
            tags generated from audio streams in the video in the summarization
            process. Default: True.
        :param consider_video_tags_from_frames: whether to consider the video
            tags generated from sampled frames from the video in the
            summarization process. Default: True.
        :param vid_cap_from_vid_args: the arg dict for video captioning from
            video directly with keys are the arg names and values are the arg
            values. Default: None.
        :param vid_cap_from_frm_args: the arg dict for video captioning from
            sampled frames from the video with keys are the arg names and
            values are the arg values. Default: None.
        :param vid_tag_from_aud_args: the arg dict for video tagging from audio
            streams in the video with keys are the arg names and values are the
            arg values. Default: None.
        :param vid_tag_from_frm_args: the arg dict for video tagging from
            sampled frames from the video with keys are the arg names and
            values are the arg values. Default: None.
        :param keep_tag_num: max number N of tags from sampled frames to keep.
            Too many tags might bring negative influence to summarized text, so
            we consider to only keep the N most frequent tags. Default: 5.
        :param keep_original_sample: whether to keep the original sample. If
            it's set to False, there will be only summarized captions in the
            final datasets and the original captions will be removed. It's True
            in default.
        :param args: extra args
        :param kwargs: extra args
        """
        kwargs.setdefault('mem_required', '40GB')
        super().__init__(*args, **kwargs)
        AUTOINSTALL.check([
            'torch',
            'transformers',
            'transformers_stream_generator',
            'einops',
            'accelerate',
            'tiktoken',  # by audio caption
            'torchaudio',  # by audio tag
        ])

        self.keep_original_sample = keep_original_sample
        self.keep_tag_num = keep_tag_num
        self.extra_args = kwargs

        # prepare summarizer
        self._hf_summarizer = hf_summarizer if hf_summarizer else 'mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback'  # noqa: E501
        self.model_key = prepare_model(
            model_type='huggingface',
            pretrained_model_name_or_path=self._hf_summarizer,
            trust_remote_code=trust_remote_code)

        # Args forced onto every sub-op that accepts them; they override any
        # user-supplied value of the same name (see _prepare_op_args).
        self.FIXED_ARGS = {
            'caption_num': 1,
            'keep_candidate_mode': 'random_any',
            'keep_original_sample': False,
        }

        # prepare input texts ops
        # Sub-op modules are imported lazily so that only the enabled ones
        # are loaded.
        self.cap_op_list = []
        self.tag_op_list = []
        if consider_video_caption_from_video:
            from .video_captioning_from_video_mapper import (
                VideoCaptioningFromVideoMapper)
            self.cap_op_list.append(
                VideoCaptioningFromVideoMapper(**self._prepare_op_args(
                    VideoCaptioningFromVideoMapper, vid_cap_from_vid_args)))
        if consider_video_caption_from_audio:
            from .video_captioning_from_audio_mapper import (
                VideoCaptioningFromAudioMapper)
            # there is no user-facing arg dict for this op
            self.cap_op_list.append(
                VideoCaptioningFromAudioMapper(**self._prepare_op_args(
                    VideoCaptioningFromAudioMapper, {})))
        if consider_video_caption_from_frames:
            from .video_captioning_from_frames_mapper import (
                VideoCaptioningFromFramesMapper)
            self.cap_op_list.append(
                VideoCaptioningFromFramesMapper(**self._prepare_op_args(
                    VideoCaptioningFromFramesMapper, vid_cap_from_frm_args)))
        if consider_video_tags_from_audio:
            from .video_tagging_from_audio_mapper import (
                VideoTaggingFromAudioMapper)
            self.tag_op_list.append(
                VideoTaggingFromAudioMapper(**self._prepare_op_args(
                    VideoTaggingFromAudioMapper, vid_tag_from_aud_args)))
        if consider_video_tags_from_frames:
            from .video_tagging_from_frames_mapper import (
                VideoTaggingFromFramesMapper)
            self.tag_op_list.append(
                VideoTaggingFromFramesMapper(**self._prepare_op_args(
                    VideoTaggingFromFramesMapper, vid_tag_from_frm_args)))

    def _prepare_op_args(self, op_class, args_dict):
        """
        Build the keyword arguments used to instantiate a sub-op.

        The user-supplied args are merged with ``self.FIXED_ARGS`` (fixed
        values win), entries whose name does not appear among the variable
        names of ``op_class.__init__`` are dropped, and ``accelerator`` is
        always set to this op's accelerator.

        :param op_class: the sub-op class to be instantiated.
        :param args_dict: user-supplied args for the sub-op, or None. It is
            NOT modified; a new dict is returned.
        :return: a new dict of keyword arguments for ``op_class``.
        """
        # NOTE: co_varnames lists parameters AND local variables of
        # __init__, so this filter is slightly more permissive than the
        # real signature.
        accepted_names = set(op_class.__init__.__code__.co_varnames)
        merged_args = {**(args_dict or {}), **self.FIXED_ARGS}
        op_args = {
            name: value
            for name, value in merged_args.items() if name in accepted_names
        }
        op_args['accelerator'] = self.accelerator
        return op_args

    def _process_single_sample(self, sample, rank=None):
        """
        Generate the summarized-caption version of a single sample.

        The text is split into chunks on the end-of-chunk token. Chunks
        without video tokens are kept unchanged; for every other chunk the
        enabled tag/caption ops are run on its videos and their outputs are
        summarized into one caption by the summarizer model.

        :param sample: a single sample (dict).
        :param rank: rank forwarded to the model getter and to the sub-ops.
        :return: an empty list if the sample has no video or no sub-op is
            enabled, otherwise a one-element list holding a deep copy of the
            sample whose text is replaced by the summarized captions.
        """
        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # there is no activated ops
        if not self.cap_op_list and not self.tag_op_list:
            return []

        # get paths of all video(s)
        loaded_video_keys = sample[self.video_key]

        # get models
        model, tokenizer = get_model(self.model_key, rank, self.use_cuda())

        captioned_sample = copy.deepcopy(sample)

        # generate for each video chunk by chunk
        captioned_chunks = []
        offset = 0  # index of the first video belonging to the current chunk
        for chunk in sample[self.text_key].split(SpecialTokens.eoc):
            # skip empty chunks
            if not chunk.strip():
                continue

            vid_count = chunk.count(SpecialTokens.video)
            if vid_count == 0:
                # keep video-free chunks as they are, add special tokens
                captioned_chunks.append(f'{chunk}{SpecialTokens.eoc}')
                continue

            # make a temporary sample
            temp_sample = {
                self.text_key: chunk,
                self.video_key: loaded_video_keys[offset:offset + vid_count],
                Fields.meta: {},
            }

            captioned_text_list = []

            # tag ops: each op writes its tags into the meta field
            for op in self.tag_op_list:
                temp_sample = op.process(temp_sample, rank=rank)
            # collect the tags once, after all tag ops have run
            chunk_meta = temp_sample[Fields.meta]
            if MetaKeys.video_audio_tags in chunk_meta:
                captioned_text_list.extend(
                    chunk_meta[MetaKeys.video_audio_tags])
            if MetaKeys.video_frame_tags in chunk_meta:
                for tag_list in chunk_meta[MetaKeys.video_frame_tags]:
                    # Keep at most keep_tag_num tags per list. A slice (not
                    # an index) is required: indexing would pick one tag,
                    # which extend() then splits into characters, or raise
                    # IndexError on short lists.
                    # NOTE(review): assumes the tagging op orders tags by
                    # frequency, as the keep_tag_num doc implies -- confirm.
                    captioned_text_list.extend(tag_list[:self.keep_tag_num])

            # cap ops
            for op in self.cap_op_list:
                captioned_text_list.append(
                    remove_special_tokens(
                        op._process_single_sample(temp_sample,
                                                  rank=rank)[0]['text']))

            # summarization
            all_texts = ', '.join(captioned_text_list)
            input_ids = tokenizer(all_texts, return_tensors='pt').input_ids.to(
                model.device)
            outputs = model.generate(input_ids, max_new_tokens=128)
            summarized_text = tokenizer.decode(outputs[0],
                                               skip_special_tokens=True)

            offset += vid_count
            captioned_text = f'{SpecialTokens.video * vid_count} ' \
                             f'{summarized_text}'

            # add special tokens
            captioned_chunks.append(f'{captioned_text}{SpecialTokens.eoc}')

        captioned_sample[self.text_key] = ''.join(captioned_chunks)
        return [captioned_sample]

    def process_batched(self, samples, rank=None):
        """
        Process a batch of samples given as a "dict of lists".

        For every sample in the batch, the original sample is kept when
        ``keep_original_sample`` is True, followed by the samples generated
        for it by ``_process_single_sample``.

        :param samples: the batch, mapping each key to a list of values.
        :param rank: rank forwarded to ``_process_single_sample``.
        :return: the resulting batch as a "dict of lists". If nothing is
            produced, every key of the input batch maps to an empty list.
        """
        # reconstruct samples from "dict of lists" to "list of dicts"
        reconstructed_samples = [{key: samples[key][i]
                                  for key in samples}
                                 for i in range(len(samples[self.text_key]))]

        samples_after_split = []
        # do split for each sample within the batch
        for ori_sample in reconstructed_samples:
            if self.keep_original_sample:
                samples_after_split.append(ori_sample)
            samples_after_split.extend(
                self._process_single_sample(ori_sample, rank=rank))

        # Nothing kept or generated (e.g. empty batch, or originals dropped
        # and no sample had videos): return empty columns instead of failing
        # on samples_after_split[0].
        if not samples_after_split:
            return {key: [] for key in samples}

        # reconstruct samples from "list of dicts" to "dict of lists"
        return {
            key: [s[key] for s in samples_after_split]
            for key in samples_after_split[0]
        }