-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualize.py
58 lines (50 loc) · 1.96 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import numpy as np
import six
from tqdm import tqdm
import time
from model import HumanPartsNet
from data import MiniBatchLoader
import chainer
import os
from os.path import isdir, basename, join
import cv2
# Filesystem locations for this script. Prediction masks are written under
# resultdir; X_dir and y_dir look like dataset paths and are not read by the
# prediction code in this script.
resultdir, X_dir, y_dir = "./result/", "./data/img/", "./data/mask/"
def standardize(image):
    """Map 8-bit pixel values from the range [0, 255] onto [-1, 1].

    The input is first cast to float32, then transformed elementwise as
    (2 * image - 255) / 255, so 0 becomes -1 and 255 becomes 1. The array
    shape is left unchanged.
    """
    pixels = image.astype(np.float32)
    centred = 2. * pixels - 255.
    return centred / 255.
# Command-line interface for mask prediction.
parser = argparse.ArgumentParser(description='Human parts network')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--pretrainedmodel', '-p', default=None,
                    help='Path to pretrained model')
# --file is mandatory: without it the script later dies with an opaque
# TypeError from os.path.isdir(None) instead of a usage message.
parser.add_argument('--file', '-f', type=str, required=True,
                    help='Path to an image, or a directory of images, '
                         'to predict masks for')
# Default to '' (every name ends with ''), so directory mode no longer
# crashes on str.endswith(None) when -e is omitted.
parser.add_argument('--extension', '-e', type=str, default='',
                    help='File-name suffix used to select images when --file '
                         'is a directory (default: match every name)')
args = parser.parse_args()
# Model setup: build the 15-class network and optionally load trained weights.
model = HumanPartsNet(n_class=15)
if args.pretrainedmodel is not None:
    from chainer import serializers
    serializers.load_hdf5(args.pretrainedmodel, model)

# --gpu used to be parsed but ignored; honour it by moving the model to the
# selected device. cuda.to_cpu (used below) returns NumPy arrays unchanged,
# so the same prediction path serves both CPU and GPU runs.
from chainer import cuda
if args.gpu >= 0:
    cuda.get_device(args.gpu).use()
    model.to_gpu()

# np.save does not create missing directories.
if not isdir(resultdir):
    os.makedirs(resultdir)


def _predict_mask(path):
    """Predict a class mask for the image at *path* and save it as .npy.

    The image is resized to 300x300, rescaled with standardize(), and fed to
    the model as a single-image batch transposed from (N, H, W, C) to
    (N, C, H, W). The argmax over axis 0 of the first output is saved to
    ``<resultdir>/<basename(path)>.npy``.

    Returns True on success, or False (saving nothing) when cv2.imread
    cannot decode *path*.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports unreadable / non-image files by returning None.
        return False
    resized = cv2.resize(image, (300, 300)).astype(np.uint8)
    batch = np.transpose(np.expand_dims(standardize(resized), 0),
                         (0, 3, 1, 2)).astype(np.float32)
    if args.gpu >= 0:
        batch = cuda.to_gpu(batch)
    x = chainer.Variable(batch, volatile='on')
    y = model.predict(x)
    # Bring the scores back to host memory before the NumPy argmax.
    mask = np.argmax(cuda.to_cpu(y.data[0]), axis=0)
    np.save(join(resultdir, basename(path) + '.npy'), mask)
    return True


if not isdir(args.file):
    if not _predict_mask(args.file):
        raise SystemExit('Could not read image: %s' % args.file)
else:
    # A missing extension means "accept every file name".
    extension = args.extension or ''
    # Sorted so the processing order is deterministic across platforms.
    for name in tqdm(sorted(os.listdir(args.file))):
        path = join(args.file, name)
        if isdir(path) or not name.endswith(extension):
            continue
        if not _predict_mask(path):
            tqdm.write('Skipping unreadable file: %s' % path)