-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_train_moments.py
114 lines (96 loc) · 3.93 KB
/
extract_train_moments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pickle
import torch
from data_loader import *
from network import *
import sys
from torch.autograd import Variable
def loadCheckpoints(model, PATH):
    """Load pretrained weights from disk into ``model``.

    Args:
        model: network whose ``load_state_dict`` receives the stored weights.
        PATH: path to a checkpoint file; it must hold a dict with the keys
            ``'state_dict'`` and ``'accuracy'``.

    Returns:
        The same ``model`` object. When ``PATH`` is not an existing file a
        message is printed and the model is returned untouched.
    """
    if os.path.isfile(PATH):
        print("=> loading checkpoint '{}'".format(PATH))
        # Deserialize onto the CPU so that a checkpoint written from a GPU
        # also loads on a machine without that GPU; load_state_dict copies
        # the values into the model's own parameters afterwards.
        checkpoint = torch.load(PATH, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (accuracy {})"
              .format(PATH, checkpoint['accuracy']))
    else:
        print("=> no checkpoint found at '{}'".format(PATH))
    return model
def loadSavedMoments(filename):
    """Load already extracted moments from disk, if there are any.

    Args:
        filename: path of the pickle file holding a
            ``(moments_list, labels_list, number_of_images_processed)`` tuple.

    Returns:
        The stored ``(moments_list, labels_list, number_of_images_processed)``
        tuple, or ``([], [], 0)`` when ``filename`` is not an existing file.
    """
    if not os.path.isfile(filename):
        return ([], [], 0)
    # NOTE: unpickling can execute arbitrary code; only point this at files
    # written by storeMoments, never at untrusted input.
    with open(filename, 'rb') as fileObject:
        moments_list, labels_list, number_of_images_processed = pickle.load(fileObject)
    return (moments_list, labels_list, number_of_images_processed)
def storeMoments(filename, data):
    """Store newly extracted moments on disk.

    The data is pickled into a temporary sibling file that then replaces
    ``filename``, so an interruption while writing cannot leave a truncated
    file behind at ``filename``.

    Args:
        filename: destination path of the pickle file.
        data: picklable object to store.
    """
    # Local import so this function does not depend on a wildcard import
    # elsewhere in the file to provide ``os``.
    import os

    tmp_filename = '{}.tmp'.format(filename)
    with open(tmp_filename, 'wb') as fileObject:
        pickle.dump(data, fileObject)
    # Swap the finished file into place only after it has been fully written
    # and closed.
    os.replace(tmp_filename, filename)
    print ('Data successfully written on disk')
def obtainDataAsTensors(im_path, im_label):
    """Read one image and its label as torch tensors for phase 2 manipulation.

    Args:
        im_path: path of the image file to open.
        im_label: label of that image; it is reshaped to 1x1.

    Returns:
        Tuple ``(img, label)``: the image converted to RGB and passed through
        ``ToTensor``, and the label as a 1x1 tensor.
    """
    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    rgb_image = Image.open(im_path).convert('RGB')
    label_array = np.asarray(im_label).reshape([1,1])
    return (to_tensor(rgb_image), torch.from_numpy(label_array))
def extractTrainMoments(net):
    """
    Extract moments by using the network trained in phase 1
    and the inputs as M_tr(Phase 2). 4096 moments are extracted
    for each image.

    Images are pushed through ``net`` one at a time. Extraction resumes
    after the images already recorded in ``saved_moments_filename``; the
    accumulated results are written back every 200 images and once more
    after the last image, so a trailing partial batch is not lost.

    Relies on the module-level globals ``QF``, ``device`` and
    ``saved_moments_filename``.

    Args:
        net: phase 1 network, called as ``net(img, phase = 1)``.
    """
    #TODO currently the images in the datasets are of dimensions (batch_size x (3x1000x1000)) whereas in the paper
    #it's mentioned as (batch_size x (1000x1000x3)), check for correctness
    Mtr_dataset = get_Mtr(QF)
    # sys.getsizeof is shallow: these figures do not include the contained objects.
    print ("THE SIZE OF MTR DATASET IS ", sys.getsizeof(Mtr_dataset))
    output_image, output_labels, num_images_already_processed = loadSavedMoments(saved_moments_filename)
    num_of_images_processed = num_images_already_processed
    image_paths = Mtr_dataset[0]
    print ("THE NUMBER OF IMAGES FOR TRAINING ARE ", len(image_paths))
    print ("THE SIZE OF IMAGE PATHS IS ", sys.getsizeof(image_paths))
    image_labels = Mtr_dataset[1]
    print ("THE SIZE OF IMAGE LABELS IS ", sys.getsizeof(image_labels))

    # Count at which the moments file was last brought up to date.
    last_saved_count = num_images_already_processed
    # Inference only: no autograd bookkeeping is needed for the forward passes.
    with torch.no_grad():
        for i in range(num_images_already_processed, len(image_labels)):
            image, label = obtainDataAsTensors(image_paths[i], image_labels[i])
            num_of_images_processed+=1
            # NOTE(review): images whose size(1) and size(2) are both >= 2048
            # are skipped here. This assumes the size check is meant to gate
            # the whole forward pass (not just the progress print) -- confirm
            # against the intended behaviour.
            if (image.size(1)<2048 or image.size(2)<2048):
                print("Image number ", num_of_images_processed)
                # Add the batch dimension and move to the target device;
                # .to() also works when ``device`` is the CPU fallback.
                img = image.unsqueeze(0).to(device)
                #Forward pass to extract moments for phase 2(this will be done one image at a time)
                single_moment = net(img, phase = 1)
                output_image.append(single_moment.detach()[0])
                output_labels.append(label)
            if (num_of_images_processed%200==0):
                storeMoments(saved_moments_filename, (output_image, output_labels, num_of_images_processed))
                last_saved_count = num_of_images_processed
            torch.cuda.empty_cache()

    # Persist whatever was processed after the last periodic save.
    if num_of_images_processed != last_saved_count:
        storeMoments(saved_moments_filename, (output_image, output_labels, num_of_images_processed))
#MAIN
# First command-line argument: used below as the directory prefix of the
# checkpoint and moments files (and read by extractTrainMoments).
QF = sys.argv[1]

# Use the GPU with index 1 when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")
print (device)

# Phase 1 checkpoint to read, and the moments file to write / resume from.
saved_model_filename = '{}/best_model_phase_1.pth'.format(QF)
saved_moments_filename = '{}/train_moments'.format(QF)

# Build the network and restore its phase 1 weights.
net_phase_1 = loadCheckpoints(Net(), saved_model_filename)
#move model to cuda
net_phase_1 = net_phase_1.to(device)
# NOTE(review): the eval() call below is disabled, so the model keeps
# whatever mode Net() starts in -- confirm this is intended for extraction.
#net_phase_1.eval()
extractTrainMoments(net_phase_1)