-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
33 lines (24 loc) · 1014 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from network import AvatarNet
from utils import imload, imsave, maskload
def network_test(args, save_path='stylized_image.jpg'):
    """Stylize a content image with a trained AvatarNet checkpoint and save it.

    Args:
        args: Parsed option namespace. Attributes read here:
            ``gpu_no`` (a negative value selects the CPU, otherwise CUDA),
            ``check_point`` (path to a saved dict containing ``'state_dict'``),
            ``layers`` (passed to the ``AvatarNet`` constructor),
            ``content`` (content image path), ``style`` (iterable of style
            image paths), ``mask`` (optional iterable of mask paths; falsy
            means no masks), ``imsize``, ``cropsize``, ``cencrop`` (forwarded
            to ``imload``), and ``style_strength``, ``patch_size``,
            ``patch_stride``, ``interpolation_weights`` (forwarded to the
            network call).
        save_path: Destination passed to ``imsave`` for the stylized image.
            Defaults to ``'stylized_image.jpg'``, the previously hard-coded name.

    Returns:
        None. The result is only written out through ``imsave``.
    """
    # NOTE(review): gpu_no acts purely as a CPU/GPU switch here; 'cuda' picks the
    # current default CUDA device, not device number gpu_no.
    device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu')

    # map_location remaps stored tensors onto `device`, so a checkpoint saved on
    # a GPU can still be restored for a CPU run.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    check_point = torch.load(args.check_point, map_location=device)

    network = AvatarNet(args.layers)
    network.load_state_dict(check_point['state_dict'])
    network = network.to(device)
    # Switch modules such as dropout / batch norm (if present) to inference
    # behaviour; torch.no_grad() below only disables gradient tracking.
    network.eval()

    # NOTE(review): the content image is loaded without args.cencrop while the
    # style images pass it; confirm against imload whether this is intentional.
    content_img = imload(args.content, args.imsize, args.cropsize).to(device)
    style_imgs = [
        imload(style, args.imsize, args.cropsize, args.cencrop).to(device)
        for style in args.style
    ]

    masks = None
    if args.mask:
        masks = [maskload(mask).to(device) for mask in args.mask]

    with torch.no_grad():
        # NOTE(review): the trailing positional False is interpreted by
        # AvatarNet.forward, which is not visible here; confirm its meaning.
        stylized_img = network(
            content_img,
            style_imgs,
            args.style_strength,
            args.patch_size,
            args.patch_stride,
            masks,
            args.interpolation_weights,
            False,
        )

    imsave(stylized_img, save_path)
    return None