model.py
# Model class for the CVPR 23 paper: "Learning Action Changes by Measuring Verb-Adverb Textual Relationships"
import torch
import torch.nn as nn

from attention import SDPAttention
from text_encoder import TextEncoder

def build_mlp(channels_in, channels_out, hidden_units, bn=False, dropout=0.0, act_func='relu',
              add_act_func_to_last=False):
    # Builds the layer list for an MLP head. hidden_units can be a tuple of ints
    # or a comma-separated string such as '512,256'.
    if isinstance(hidden_units, str):
        hidden_units = tuple(int(x) for x in hidden_units.split(',') if x.strip())

    assert not (dropout and bn), 'Choose either BN or dropout'
    assert act_func in ('relu', 'tanh', 'none'), f'Cannot deal with this activation function yet: {act_func}'

    layers = []

    for i in range(len(hidden_units)):
        if i == 0:
            in_ = channels_in
            out_ = hidden_units[i]
            if bn:
                layers.append(nn.BatchNorm1d(in_))
        else:
            in_ = hidden_units[i - 1]
            out_ = hidden_units[i]

        layers.append(nn.Linear(in_, out_))

        if act_func == 'relu':
            layers.append(nn.ReLU())
        elif act_func == 'tanh':
            layers.append(nn.Tanh())
        elif act_func != 'none':
            raise ValueError(f'Unexpected activation function: {act_func}')

        if dropout:
            layers.append(nn.Dropout(dropout))
        elif bn:
            layers.append(nn.BatchNorm1d(out_))

    # Output projection: from the last hidden layer, or straight from the input
    # when no hidden layers are requested.
    if hidden_units:
        layers.append(nn.Linear(hidden_units[-1], channels_out))
    else:
        layers.append(nn.Linear(channels_in, channels_out))

    if add_act_func_to_last:
        if act_func == 'relu':
            layers.append(nn.ReLU())
        elif act_func == 'tanh':
            layers.append(nn.Tanh())
        elif act_func != 'none':
            raise ValueError(f'Unexpected activation function: {act_func}')

    return layers
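

# For reference, build_mlp(512, 20, '256,128', dropout=0.5) yields the list
# [Linear(512, 256), ReLU, Dropout, Linear(256, 128), ReLU, Dropout, Linear(128, 20)],
# i.e. dropout after every hidden activation and a bare linear output layer
# (hypothetical sizes, not this repo's configs; see the smoke test at the
# bottom of the file).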


class RegClsAdverbModel(nn.Module):
    def __init__(self, train_dataset, args, text_emb_dim=512):
        super(RegClsAdverbModel, self).__init__()
        assert not (args.fixed_d and args.cls_variant)
        self.train_dataset = train_dataset
        self.args = args

        # Video features attend to the text embedding of the query verb.
        self.attention = SDPAttention(self.train_dataset.feature_dim, text_emb_dim, text_emb_dim,
                                      text_emb_dim, heads=4, dropout=args.dropout)
        modifier_input = text_emb_dim

        self.n_verbs = len(self.train_dataset.verbs)
        self.n_adverbs = len(self.train_dataset.adverbs)
        self.n_pairs = len(self.train_dataset.pairs)

        # rho maps the attended video embedding to one score per adverb.
        layers = build_mlp(modifier_input, self.n_adverbs, args.hidden_units, dropout=args.dropout)
        self.rho = nn.Sequential(*layers)

        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

        text_embeddings_verbs = TextEncoder.get_text_embeddings(args, self.train_dataset.verbs)
        self.verb_embedding = nn.Embedding.from_pretrained(text_embeddings_verbs, freeze=False)

        # delta_dict holds the per-verb adverb textual deltas used as regression targets.
        _, _, delta_dict, d_dict, _, _, _, _ = TextEncoder.compute_delta(args, train_dataset.dataset_data,
                                                                         train_dataset.antonyms,
                                                                         no_ant=args.no_antonyms)
        self.delta_dict = delta_dict
        self.d_dict = d_dict

    def forward(self, features, labels_tuple, training=True):
        video_features = features['s3d_features']
        adverbs, verbs, neg_adverbs = labels_tuple

        query = self.verb_embedding(verbs)
        padding_mask = video_features == 0  # mask zero-padded feature positions
        video_emb, attention_weights = self.attention(video_features, query, padding_mask=padding_mask)
        o = self.rho(video_emb)

        if self.args.cls_variant:
            # Classification variant: plain cross-entropy over adverb labels.
            target = adverbs
            loss = self.ce(o, target)
        else:
            # Regression variant: regress towards the verb-adverb textual delta
            # (or towards a fixed value of 1 when args.fixed_d is set).
            if self.args.fixed_d:
                d = torch.ones(len(adverbs)).cuda()
            else:
                d = [self.delta_dict[self.train_dataset.idx2verb[v.item()]][self.train_dataset.idx2adverb[a.item()]]
                     for v, a in zip(verbs, adverbs)]
                d = torch.Tensor(d).cuda()

            target = self.create_target_from_delta(d, adverbs, neg_adverbs)
            loss = self.mse(o, target)

        if not training:
            predictions = self.get_predictions(video_features)
            predictions_no_act_gt = predictions
            pred_tuple = (predictions, predictions_no_act_gt)
        else:
            pred_tuple = None

        output = [loss, pred_tuple]
        return output

    def create_target_from_delta(self, delta, adverbs, neg_adverbs):
        assert delta.min() > 0  # the loss assumes this to flip the target for antonyms
        batch_size = len(adverbs)
        target = torch.zeros((batch_size, self.n_adverbs)).cuda()
        # +delta at the ground-truth adverb, -delta at its antonym (unless antonyms are disabled).
        target.scatter_(1, adverbs.unsqueeze(1), delta.unsqueeze(1))
        if not self.args.no_antonyms:
            target.scatter_(1, neg_adverbs.unsqueeze(1), -delta.unsqueeze(1))
        return target
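
    # Worked example (hypothetical indices): with n_adverbs = 4, ground-truth
    # adverb index 1, antonym index 3 and delta = 0.7, a target row becomes
    # [0.0, 0.7, 0.0, -0.7], so the MSE pushes the adverb score up and the
    # antonym score down by the same verb-specific textual margin.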

    def get_predictions(self, video_features, verbs=None):
        assert verbs is None, 'Do not pass verb labels. Prediction scores are later calculated accordingly'
        batch_size = video_features.shape[0]
        pair_scores = torch.zeros((batch_size, self.n_pairs))

        # Score every (verb, adverb) pair: query the attention with each verb
        # embedding in turn, then scatter the predicted adverb scores into the
        # pair matrix.
        for verb_idx, verb_str in self.train_dataset.idx2verb.items():
            emb_idx = torch.LongTensor([verb_idx]).repeat(batch_size).cuda()
            q = self.verb_embedding(emb_idx)
            video_emb, _ = self.attention(video_features, q)
            adverb_pred = self.rho(video_emb)

            for adv_idx, adv_str in self.train_dataset.idx2adverb.items():
                p_idx = self.train_dataset.get_verb_adv_pair_idx(dict(verb=[verb_idx], adverb=[adv_idx]))
                assert len(p_idx) == 1
                p_idx = p_idx[0]
                pair_scores[:, p_idx] = adverb_pred[:, adv_idx]

        return pair_scores
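

if __name__ == '__main__':
    # Minimal smoke test for build_mlp (hypothetical sizes, not this repo's
    # configs). RegClsAdverbModel itself is not constructed here because it
    # needs the repo's dataset/args objects and a CUDA device; its intended
    # use looks roughly like:
    #
    #     model = RegClsAdverbModel(train_dataset, args).cuda()
    #     loss, pred_tuple = model({'s3d_features': feats}, (adverbs, verbs, neg_adverbs))
    #
    head = nn.Sequential(*build_mlp(512, 20, '256,128', dropout=0.5))
    out = head(torch.randn(8, 512))
    assert out.shape == (8, 20)
    print('build_mlp smoke test passed, output shape:', tuple(out.shape))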