-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
202 lines (165 loc) · 7.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import datetime
from os import path
import math
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset
import torch.distributed as distributed
from model.model import PropagationModel
from dataset.static_dataset import StaticTransformDataset
from dataset.vos_dataset import VOSDataset
from util.logger import TensorboardLogger
from util.hyper_para import HyperParameters
from util.load_subset import load_sub_davis, load_sub_yv
"""
Initial setup
"""
# Bring up the distributed process group first; the rank and world-size
# queries below are made against it.
distributed.init_process_group(backend="nccl")

# Seed torch, numpy and the stdlib RNG with the same fixed value
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(14159265)

print('CUDA Device count: ', torch.cuda.device_count())

# Parse command line arguments
para = HyperParameters()
para.parse()

# Turn on cuDNN benchmark mode only when requested
if para['benchmark']:
    torch.backends.cudnn.benchmark = True

# The process rank is also used as the CUDA device index of this process
world_size = distributed.get_world_size()
local_rank = distributed.get_rank()
torch.cuda.set_device(local_rank)

print('I am rank %d in this world of size %d!' % (local_rank, world_size))
"""
Model related
"""
if local_rank != 0:
    # Every rank other than 0 builds its model without a logger or save path
    model = PropagationModel(para, local_rank=local_rank, world_size=world_size).train()
else:
    # Rank 0 owns logging; an id of 'null' (any case) means "no long id"
    run_id = para['id']
    if run_id.lower() == 'null':
        long_id = None
    else:
        print('I will take the role of logging!')
        timestamp = datetime.datetime.now().strftime('%b%d_%H.%M.%S')
        long_id = '%s_%s' % (timestamp, run_id)

    logger = TensorboardLogger(run_id, long_id)
    logger.log_string('hyperpara', str(para))

    # Checkpoints go under saves/<long_id>/<long_id>; no path without a long id
    save_path = None if long_id is None else path.join('saves', long_id, long_id)

    # Construct the rank 0 model
    model = PropagationModel(
        para,
        logger=logger,
        save_path=save_path,
        local_rank=local_rank,
        world_size=world_size,
    ).train()

# Load pretrained model if needed; load_model returns the iteration to resume from
total_iter = 0
if para['load_model'] is not None:
    total_iter = model.load_model(para['load_model'])
    print('Previously trained model loaded!')

# Separately, network weights may be loaded without touching total_iter
if para['load_network'] is not None:
    model.load_network(para['load_network'])
    print('Previously trained network loaded!')
"""
Dataloader related
"""
# To re-seed the randomness every time we start a worker
def worker_init_fn(worker_id):
    """Seed numpy's global RNG inside a DataLoader worker.

    The seed combines torch's initial seed (reduced modulo 2**31), the
    worker id, and an offset of 100 per process rank.

    Args:
        worker_id: Index of the worker, as passed in by the DataLoader.

    Returns:
        None.
    """
    base_seed = torch.initial_seed() % (2**31)
    rank_offset = local_rank * 100
    np.random.seed(base_seed + worker_id + rank_offset)
def construct_loader(dataset):
    """Wrap ``dataset`` in a distributed sampler and a DataLoader.

    Args:
        dataset: The dataset to draw training batches from.

    Returns:
        A ``(sampler, loader)`` tuple. The sampler is returned as well so the
        caller can call ``set_epoch`` on it.
    """
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, rank=local_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=para['batch_size'],
        sampler=sampler,
        num_workers=8,
        worker_init_fn=worker_init_fn,
        drop_last=True,
        pin_memory=True,
    )
    return sampler, loader
def renew_vos_loader(max_skip_yt, max_skip_da):
    """Rebuild the YouTube + DAVIS training loader with new skip values.

    Args:
        max_skip_yt: Skip value for the YouTube dataset; it is integer-divided
            by 5 before being handed to ``VOSDataset``.
        max_skip_da: Skip value for the DAVIS dataset, passed through as-is.

    Returns:
        The ``(sampler, loader)`` tuple produced by ``construct_loader``.
    """
    yv_images = path.join(yv_root, 'JPEGImages')
    yv_masks = path.join(yv_root, 'Annotations')
    yv_dataset = VOSDataset(yv_images, yv_masks, max_skip_yt//5,
                            is_bl=False, subset=load_sub_yv())

    davis_images = path.join(davis_root, 'JPEGImages', '480p')
    davis_masks = path.join(davis_root, 'Annotations', '480p')
    davis_dataset = VOSDataset(davis_images, davis_masks, max_skip_da,
                               is_bl=False, subset=load_sub_davis())

    # DAVIS is repeated five times in the concatenation, YouTube appears once
    train_dataset = ConcatDataset([davis_dataset]*5 + [yv_dataset])

    print('YouTube dataset size: ', len(yv_dataset))
    print('DAVIS dataset size: ', len(davis_dataset))
    print('Concat dataset size: ', len(train_dataset))

    return construct_loader(train_dataset)
def renew_bl_loader(max_skip):
    """Rebuild the Blender training loader with a new skip value.

    Args:
        max_skip: Skip value handed to ``VOSDataset``.

    Returns:
        The ``(sampler, loader)`` tuple produced by ``construct_loader``.
    """
    image_dir = path.join(bl_root, 'JPEGImages')
    mask_dir = path.join(bl_root, 'Annotations')
    train_dataset = VOSDataset(image_dir, mask_dir, max_skip, is_bl=True)

    print('Blender dataset size: ', len(train_dataset))
    print('Renewed with skip: ', max_skip)

    return construct_loader(train_dataset)
"""
Dataset related
"""
# Skip schedules; the training loop consumes them from the front, one entry
# per skip increase.
skip_values = [10, 15, 20, 25, 10, 5]
davis_skip = [10, 15, 25, 10, 5, 1]

if para['stage'] == 0:
    # Stage 0: static image datasets, all located under one root directory
    static_root = path.expanduser(para['static_root'])

    fss_dataset = StaticTransformDataset(path.join(static_root, 'fss'), method=0)
    duts_tr_dataset = StaticTransformDataset(path.join(static_root, 'DUTS-TR'), method=1)
    duts_te_dataset = StaticTransformDataset(path.join(static_root, 'DUTS-TE'), method=1)
    ecssd_dataset = StaticTransformDataset(path.join(static_root, 'ecssd'), method=1)
    big_dataset = StaticTransformDataset(path.join(static_root, 'BIG_small'), method=1)
    hrsod_dataset = StaticTransformDataset(path.join(static_root, 'HRSOD_small'), method=1)

    # BIG and HRSOD have higher quality, use them more
    regular_sets = [fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset]
    high_quality_sets = [big_dataset, hrsod_dataset]
    train_dataset = ConcatDataset(regular_sets + high_quality_sets*5)

    train_sampler, train_loader = construct_loader(train_dataset)
    print('Static dataset size: ', len(train_dataset))

elif para['stage'] == 1:
    # Stage 1: Blender data; fractions of the total epoch count at which the
    # skip value is raised.
    increase_skip_fraction = [0.1, 0.2, 0.3, 0.4, 0.8, 1.0]
    bl_root = path.expanduser(para['bl_root'])

    renew_loader = renew_bl_loader
    train_sampler, train_loader = renew_loader(5)

else:
    # Any other stage: VOS dataset, 480p is used for both datasets
    increase_skip_fraction = [0.1, 0.2, 0.3, 0.4, 0.7, 0.9, 1.0]
    yv_root = path.join(path.expanduser(para['yv_root']), 'train_480p')
    davis_root = path.join(path.expanduser(para['davis_root']), '2017', 'trainval')

    renew_loader = renew_vos_loader
    train_sampler, train_loader = renew_loader(5, 5)
"""
Determine current/max epoch
"""
# Epoch counts are derived from the length of the loader built above
iters_per_epoch = len(train_loader)
total_epoch = math.ceil(para['iterations'] / iters_per_epoch)
# When resuming, start from the epoch the restored iteration count falls in
current_epoch = total_iter // iters_per_epoch
print('Number of training epochs (the last epoch might not complete): ', total_epoch)

if para['stage'] != 0:
    # Turn the schedule fractions into epoch indices
    increase_skip_epoch = [round(total_epoch*fraction) for fraction in increase_skip_fraction]
    # Skip will only change after an epoch, not in the middle
    print('The skip value will increase approximately at the following epochs: ', increase_skip_epoch[:-1])
"""
Starts training
"""
# Need this to select random bases in different workers
np.random.seed(np.random.randint(2**30-1) + local_rank*100)

try:
    for e in range(current_epoch, total_epoch):
        print('Epoch %d/%d' % (e, total_epoch))

        # Stages 1 and 2 follow a skip schedule: once the epoch reaches the
        # next scheduled boundary, pop every boundary already passed (several
        # may be due at once, e.g. when resuming) and rebuild the loader with
        # the most recent skip value. Stage 2 additionally tracks a separate
        # skip value for DAVIS.
        if para['stage'] in (1, 2):
            skip_changed = False
            # The emptiness guard avoids an IndexError if the schedule is
            # ever exhausted.
            while increase_skip_epoch and e >= increase_skip_epoch[0]:
                cur_skip = skip_values[0]
                skip_values = skip_values[1:]
                if para['stage'] == 2:
                    cur_skip_davis = davis_skip[0]
                    davis_skip = davis_skip[1:]
                increase_skip_epoch = increase_skip_epoch[1:]
                skip_changed = True

            if skip_changed:
                print('Increasing skip to: ', cur_skip)
                if para['stage'] == 2:
                    train_sampler, train_loader = renew_loader(cur_skip, cur_skip_davis)
                else:
                    train_sampler, train_loader = renew_loader(cur_skip)

        # Crucial for randomness! The sampler is told the epoch number before
        # the loader is iterated.
        train_sampler.set_epoch(e)

        # Train loop
        model.train()
        for data in train_loader:
            model.do_pass(data, total_iter)
            total_iter += 1

            if total_iter >= para['iterations']:
                break

        # The break above only leaves the batch loop. Without this check any
        # remaining epoch would run one more pass past the iteration budget.
        if total_iter >= para['iterations']:
            break
finally:
    # Runs on normal completion and on errors/interrupts alike. Saving is
    # skipped in debug mode, when the model has no logger, and for runs of
    # 5000 iterations or fewer.
    if not para['debug'] and model.logger is not None and total_iter > 5000:
        model.save(total_iter)
    # Clean up
    distributed.destroy_process_group()