"""Example call:
python ../classify_from_model.py --path checkpoint.pth.tar --gpu 1 --dataset cifar10 --data /home/abenjamin/data/ --n-filters 32 --noise-dim 100 --lr 1 --epochs 20 --wd 0.001 --opt sgd --lr-schedule"""
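# Pipeline (summarized from the code below): load a saved Inference/Generator pair
# (the "cortex") from a checkpoint, freeze it, and train a linear or MLP decoder on
# its cached activations to classify the dataset, using k-fold cross-validation over
# the test set. Accuracy is reported for each of the six read-out points
# (input, layers 1-4, and the top/noise layer).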
from torch import optim
from torchvision import datasets, transforms, utils
from torch.utils.data import Subset
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from dcgan_models import Inference, Generator
from decoder_models import LinearDecoder, NonlinearDecoder
from utils import gen_surprisal
import argparse
import os
import random
import warnings
parser = argparse.ArgumentParser(description='Train decoders to classify from the layers of a saved network checkpoint.')
parser.add_argument('--path', metavar='DIR', default="checkpoint.pth.tar",
                    help='path to the saved model checkpoint.')
parser.add_argument('-d', '--data', metavar='DIR', default="../data",
                    help='path to dataset. Loads MNIST if nonexistent.')
parser.add_argument('--dataset', default='imagenet',
                    choices=['imagenet', 'folder', 'lfw', 'cifar10', 'mnist'],
                    help="What dataset are we training on?")
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training.')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N',
                    help='mini-batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=1, type=float,
                    metavar='LRC', help='initial learning rate (default: 1)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run (default: 20)')
parser.add_argument('--opt', default='sgd',
                    choices=['adam', 'sgd'],
                    help="What optimization algorithm to use?")
parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                    metavar='W', help='weight decay (default: 1e-3)',
                    dest='wd')
parser.add_argument('--noise-dim', default=40, type=int, metavar='ND',
                    help='Dimensionality of the top layer of the cortex.')
parser.add_argument('--n-filters', default=64, type=int,
                    help='Number of filters in the first conv layer of the DCGAN (default: 64)')
parser.add_argument('--n-folds', default=10, type=int,
                    help='Number of CV folds used to calculate the accuracy (default: 10).')
parser.add_argument('--nonlinear', action='store_true',
                    help="Don't use logistic regression but rather a 2-layer MLP")
parser.add_argument('--loss-type', default='wasserstein',
                    choices=['BCE', 'wasserstein', 'hinge'],
                    help="The form of the minimax loss function. BCE = binary cross entropy of the original GAN")
parser.add_argument('--hidden-size', default=1000, type=int, metavar='HS',
                    help='Dimensionality of the hidden layer of the nonlinear decoder.')
parser.add_argument('--image-size', default=64, type=int,
                    help='Resize images to this many pixels. (default 64)')
parser.add_argument('--noise-type', default='none',
                    choices=['none', 'fixed', 'learned_by_layer', 'learned_by_channel', 'learned_filter', 'poisson'],
                    help="What variance of Gaussian noise should be applied after all layers in the "
                         "cortex? See docs for details. Default is no noise; fixed has variance 0.01")
# load_cortex reads args.spec_norm, so expose it here; this flag is assumed to mirror
# the one in the training script and must match how the checkpoint was built.
parser.add_argument('--spec-norm', action='store_true',
                    help='Build the inference network with spectral normalization.')
parser.add_argument('--he-initialization', action='store_true',
                    help='As in Progressive GANs. Plays well with divisive normalization')
parser.add_argument('--divisive-normalization', action='store_true',
                    help='Divisive normalization over channels, pixel by pixel. As in Progressive GANs')
parser.add_argument('--lr-schedule', action='store_true',
                    help='Multiply the learning rate by 0.3 every third of training.')
def main(args):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU.')
    inference, generator = load_cortex(args.path, args)
    accuracies, reconstructions = decode_classes_from_layers(args.gpu,
                                                             inference,
                                                             generator,
                                                             args.image_size,
                                                             args.n_filters,
                                                             args.noise_dim,
                                                             args.data,
                                                             args.dataset,
                                                             args.nonlinear,
                                                             args.lr,
                                                             args.n_folds,
                                                             args.epochs,
                                                             args.hidden_size,
                                                             args.wd,
                                                             args.opt,
                                                             args.lr_schedule,
                                                             args.batch_size,
                                                             args.workers)
    for i in range(6):
        print("Layer {}: Accuracy {} +/- {}".format(i, accuracies.mean(dim=0)[i], accuracies.std(dim=0)[i]))
def load_cortex(path, args):
    """Loads the inference and generator networks of a cortex from a checkpoint at `path`."""
    # match the checkpoint: the Wasserstein variant is built without batch norm
    bn = False if args.loss_type == 'wasserstein' else True
    inference = Inference(args.noise_dim, args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          bn=bn, hard_norm=args.divisive_normalization,
                          spec_norm=args.spec_norm, derelu=False)
    generator = Generator(args.noise_dim, args.n_filters,
                          1 if args.dataset == 'mnist' else 3,
                          image_size=args.image_size,
                          hard_norm=args.divisive_normalization)
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # load onto the CPU; networks are moved to the GPU later
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        inference.load_state_dict(checkpoint['inference_state_dict'])
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(path, checkpoint['epoch']))
    else:
        raise IOError("=> no checkpoint found at '{}'".format(path))
    return inference, generator
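# A minimal sketch of the checkpoint layout this script expects, inferred from the
# keys read above (the training script that writes it is not shown here):
#   torch.save({'epoch': epoch,
#               'inference_state_dict': inference.state_dict(),
#               'generator_state_dict': generator.state_dict()},
#              'checkpoint.pth.tar')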
def train(inference, optimizer, decoder, train_loader, gpu):
    """Given some training data, feed it through the cortex, put all activations
    through the decoders, and train the decoders on the supervised task."""
    inference.eval()
    decoder.train()
    loss_fn = nn.CrossEntropyLoss()
    for batch, labels in train_loader:
        optimizer.zero_grad()
        batch, labels = batch.cuda(gpu), labels.cuda(gpu)
        # run through the cortex; activations are cached in intermediate_state_dict
        inference(batch)
        # get one prediction per read-out layer
        predictions = decoder(inference.intermediate_state_dict)
        # sum the losses of all decoders so they train simultaneously
        loss = 0
        for pred in predictions:
            loss = loss + loss_fn(pred, labels)
        loss.backward()
        optimizer.step()
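# Decoder interface assumed above (LinearDecoder/NonlinearDecoder live in
# decoder_models.py and are not shown here): called on the dict of cached
# activations, a decoder is expected to return a list of six class-logit
# tensors, one per read-out point (input, layers 1-4, and the top/noise
# layer), so each head can be trained and scored independently.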
def test(inference, generator, decoder, test_loader, gpu, epoch, n_examples, verbose=False):
    """Returns the classification accuracy on this fold of the test set for each of the
    5 layers + the input, along with the per-layer reconstruction losses."""
    inference.eval()
    decoder.eval()
    generator.eval()
    correct = [0 for _ in range(6)]
    total = 0
    criterion = nn.MSELoss()
    reconstruction_losses = torch.zeros(5)
    for batch, labels in test_loader:
        batch, labels = batch.cuda(gpu), labels.cuda(gpu)
        # run through the cortex; activations are cached in intermediate_state_dict
        inference(batch)
        # get one prediction per read-out layer
        predictions = decoder(inference.intermediate_state_dict)
        # also get the reconstruction losses of the generator at each layer
        reconstructions = gen_surprisal(inference.intermediate_state_dict, generator, 1, criterion,
                                        None, detach_inference=True, as_list=True)
        total += labels.size(0)
        # accumulate the accuracy at each read-out point
        for i in range(6):
            _, predicted = torch.max(predictions[i].data, 1)
            correct[i] += float((predicted == labels).sum().item())
        reconstruction_losses += torch.Tensor(reconstructions) * labels.size(0)
    accuracies = []
    if verbose:
        print('Epoch {}: accuracy on {} test images:'.format(epoch, n_examples))
    for i in range(6):
        accuracy = 100 * correct[i] / total
        accuracies.append(accuracy)
        if verbose:
            print('Layer {}: {}'.format(i, accuracy))
    reconstruction_losses /= total
    return accuracies, reconstruction_losses
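# Note on gen_surprisal (defined in utils.py, not shown here): judging by its use
# above and the torch.zeros(5) accumulator, it is assumed to return a list of five
# per-layer reconstruction losses under `criterion`, one per generator stage,
# computed with the inference activations detached.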
def decode_classes_from_layers(gpu,
                               inference,
                               generator,
                               image_size,
                               n_filters,
                               noise_dim,
                               data_path,
                               dataset,
                               nonlinear=False,
                               lr=0.001,
                               folds=10,
                               epochs=50,
                               hidden_size=1000,
                               wd=1e-4,
                               opt='adam',
                               lr_schedule=False,
                               batch_size=128,
                               workers=4,
                               verbose=True):
    """Trains a linear or nonlinear decoder from every layer of the cortex at once
    (including the inputs), doing k-fold CV on the test set of this dataset with a
    random permutation of the examples.
    Returns a tensor of accuracies on each of the k folds for each of the 6 decoders:
    Input --- Layer1 ... Layer4 --- Noise
    """
    # ----- Get dataset ------ #
    # Data loading code
    valdir = os.path.join(data_path, 'val')
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        all_test_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]))
        nc = 3
        n_classes = 1000
    elif dataset == 'cifar10':
        all_test_dataset = datasets.CIFAR10(root=data_path, download=True, train=False,
                                            transform=transforms.Compose([
                                                transforms.Resize(image_size),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                            ]))
        nc = 3
        n_classes = 10
    elif dataset == 'mnist':
        all_test_dataset = datasets.MNIST(root=data_path, download=True, train=False,
                                          transform=transforms.Compose([
                                              transforms.Resize(image_size),
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.5,), (0.5,)),
                                          ]))
        nc = 1
        n_classes = 10
    assert all_test_dataset
    perm = torch.randperm(len(all_test_dataset))
    n_test_examples = len(all_test_dataset) // folds
    all_accuracies = []
    all_reconstructions = []
    for f in range(folds):
        # ---- Get CV indices ----
        if f == folds - 1:
            # the last fold absorbs the remainder if len(all_test_dataset) % folds != 0;
            # its training split must then stop at f * n_test_examples so the tail
            # examples do not leak into both splits
            test_idx = perm[f * n_test_examples:]
            train_idx = perm[:f * n_test_examples]
        else:
            test_idx = perm[f * n_test_examples: (f + 1) * n_test_examples]
            train_idx = torch.cat((perm[:f * n_test_examples],
                                   perm[(f + 1) * n_test_examples:]))
        # ----- Make loaders -----
        train_dataset = Subset(all_test_dataset, train_idx)
        test_dataset = Subset(all_test_dataset, test_idx)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True,
            num_workers=workers, pin_memory=True)
        # ----- Build decoder ------
        if nonlinear:
            decoder = NonlinearDecoder(image_size, noise_dim, n_classes, nc, n_filters, hidden_size)
        else:
            decoder = LinearDecoder(image_size, noise_dim, n_classes, nc, n_filters)
        # move everything to the proper GPU
        torch.cuda.set_device(gpu)
        inference = inference.cuda(gpu)
        generator = generator.cuda(gpu)
        decoder = decoder.cuda(gpu)
        # ------ Build optimizer ------ #
        if opt == 'adam':
            optimizer = optim.Adam(decoder.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=wd)
        elif opt == 'sgd':
            optimizer = optim.SGD(decoder.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
        else:
            raise NotImplementedError("This optimizer is not implemented yet.")
        for epoch in range(epochs):
            if lr_schedule:
                adjust_lr(epoch, optimizer, epochs)
            train(inference, optimizer, decoder, train_loader, gpu)
            if verbose or (epoch == epochs - 1):
                accuracies, reconstructions = test(inference, generator, decoder, test_loader, gpu,
                                                   epoch, len(test_idx), verbose)
        all_accuracies.append(accuracies)
        all_reconstructions.append(reconstructions)
    return torch.Tensor(all_accuracies), torch.stack(all_reconstructions)
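# A quick programmatic sketch (hypothetical argument values; this mirrors main()):
#   args = parser.parse_args(['--dataset', 'cifar10', '--gpu', '0'])
#   inference, generator = load_cortex(args.path, args)
#   accs, recons = decode_classes_from_layers(args.gpu, inference, generator,
#                                             args.image_size, args.n_filters,
#                                             args.noise_dim, args.data, args.dataset)
#   print(accs.mean(dim=0))  # mean accuracy per read-out point across folds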
def adjust_lr(epoch, optimizer, epochs):
    """Multiplies the learning rate by 0.3 every `epochs // 3` epochs."""
    # skip epoch 0 so the initial learning rate is used as given
    if epoch > 0 and epoch % (epochs // 3) == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= .3
if __name__ == '__main__':
    args = parser.parse_args()
    main(args)