main.py
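"""Training script for HyperSR hyperspectral image super-resolution on the CAVE dataset.

Trains the network with an L1 reconstruction loss, logs running loss/PSNR
averages, and periodically writes checkpoints under the log/ directory.
"""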
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
from sewar.full_ref import psnr

from data.dataset import CAVE_dataset
from model import HyperSR
from option import args

# DataLoader variant that prefetches batches in a background thread,
# overlapping data loading with GPU computation
class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())

def save_ckpt(state, save_path='./log', filename='checkpoint.pth'):
    # filename may contain a subdirectory (e.g. '<dataset>/...'), so make
    # sure the target directory exists before saving
    full_path = os.path.join(save_path, filename)
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    torch.save(state, full_path)

def train(train_loader, args):
    net = HyperSR(channels_LSI=3, channels_HSI=31, channels=64, n_endmembers=64).cuda()
    net = nn.DataParallel(net, device_ids=list(range(args.n_GPUs)))
    net.train()

    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.n_steps, gamma=args.gamma)

    loss_list = []
    psnr_list = []
    loss_save = []
    for idx_iter, (HrHSI, HrLSI, LrHSI) in enumerate(train_loader):
        HrHSI = HrHSI.cuda()
        HrLSI = HrLSI.cuda()
        LrHSI = LrHSI.cuda()

        # inference
        rec_HrHSI = net(HrLSI, LrHSI)

        # losses
        loss = criterion(rec_HrHSI, HrHSI)
        loss_list.append(loss.item())
        psnr_list.append(psnr(HrHSI.detach().cpu().numpy(),
                              rec_HrHSI.detach().cpu().numpy(), MAX=1))

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print running averages every 200 iterations
        if idx_iter % 200 == 0:
            print('iteration %5d of total %5d, loss---%f, psnr---%f' %
                  (idx_iter + 1, args.n_iters, float(np.mean(loss_list)), float(np.mean(psnr_list))))
            loss_save.append(np.mean(loss_list))

        # save a checkpoint every 1000 iterations and reset the running statistics
        if idx_iter % 1000 == 0:
            save_ckpt({
                'iter': idx_iter + 1,
                # save the underlying module's weights so the checkpoint can be
                # loaded without the 'module.' prefix nn.DataParallel adds
                'state_dict': net.module.state_dict(),
                'loss': loss_save,
            }, save_path='log/',
                filename=args.dataset + '/' + args.model + '_' + str(args.n_endmembers) + '_iter' + str(idx_iter + 1) + '.pth')
            loss_list = []
            psnr_list = []

        scheduler.step()

if __name__ == '__main__':
    # dataloader
    train_set = CAVE_dataset(args, train=True)
    train_loader = DataLoaderX(dataset=train_set, num_workers=6, batch_size=args.batch_size,
                               shuffle=True, drop_last=True, pin_memory=True)
    # train
    train(train_loader, args)
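
# The commented-out sketch below is an illustrative example (not part of the
# original script) of how a saved checkpoint could be reloaded for evaluation.
# It assumes the checkpoint layout written by save_ckpt() above ('iter',
# 'state_dict', 'loss') and a hypothetical path CKPT_PATH.
#
#   ckpt = torch.load(CKPT_PATH, map_location='cuda')
#   net = HyperSR(channels_LSI=3, channels_HSI=31, channels=64, n_endmembers=64).cuda()
#   net.load_state_dict(ckpt['state_dict'])  # works because save_ckpt stores net.module's weights
#   net.eval()
#   with torch.no_grad():
#       rec_HrHSI = net(HrLSI, LrHSI)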