Showing 3 changed files with 440 additions and 0 deletions.
@@ -0,0 +1,242 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written by Willy

from skimage import io, color
from skimage import transform as sk_transform
from src.lib.network.cn import vgg16
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import warnings
import scipy.io as scio
import os
from src.lib.utils.image_opt import findSigma, genDensity
warnings.filterwarnings("ignore")


def imread(path):
    """Read an image and make sure it has three channels."""
    image = io.imread(path)
    if len(image.shape) == 2:
        image = color.gray2rgb(image)
    return image


def preferred_resize(image, size):
    """Resize so the longer side equals `size`, then pad to a square gray canvas.

    Returns the padded image and the scale ratio applied to the original.
    """
    h, w, c = image.shape
    if h > w:
        new_h, new_w = size, int(size * w / h + 0.5)
        ratio = size / h
    else:
        new_h, new_w = int(size * h / w + 0.5), size
        ratio = size / w
    resized_image = sk_transform.resize(image, (new_h, new_w), preserve_range=True)
    out_image = np.zeros((size, size, c), dtype=np.float32)
    out_image[...] = 127.5  # pad with mid-gray so the padding maps to 0 after normalization
    out_image[:new_h, :new_w, :] = resized_image

    return out_image, ratio


def ToTensor(image):
    # HWC -> CHW, the layout expected by PyTorch
    image = image.transpose((2, 0, 1))
    return image


def normalize(image):
    # scale pixel values from [0, 255] to [-1, 1]
    image = (image - 127.5) / 127.5
    return image


def loadModel(path, branch):
    model = branch(pretrained=False)
    model.create_architecture()
    model.load_state_dict(torch.load(path))
    model.train(False)  # evaluation mode
    return model


def showTest(image, output, gt):
    image = (image * 127.5 + 127.5).astype(np.uint8)
    image = image.transpose((1, 2, 0))
    fig = plt.figure()
    ax = fig.add_subplot(211)
    ax.set_title('original image', fontsize=14)
    ax.imshow(image)
    ax = fig.add_subplot(223)
    ax.set_title('gt count: {:.2f}'.format(np.sum(gt)))
    ax.imshow(gt, cmap='jet', interpolation='bilinear')

    ax = fig.add_subplot(224)
    ax.set_title('predict count: {:.2f}'.format(np.sum(output)))
    ax.imshow(output[0, 0, :], cmap='jet', interpolation='bilinear')
    plt.show()


def density(image, path, ratio=1):
    """Build a ground-truth density map from a ShanghaiTech annotation .mat file."""
    gt = scio.loadmat(path, struct_as_record=False, squeeze_me=True)
    info = gt['image_info']
    dots = info.location  # N x 2 array of annotated head coordinates
    print(len(dots))
    dots = dots * ratio  # rescale annotations if the image was resized
    # sigma = findSigma(dots, 2, 0.3)
    sigma = torch.FloatTensor(len(dots)).fill_(15)  # fixed Gaussian width per head
    densityMap = genDensity(image, dots, sigma, 1001)
    print(np.sum(densityMap))
    if np.sum(densityMap) != 0:
        # renormalize so the map integrates exactly to the head count
        densityMap = densityMap * (len(dots) / np.sum(densityMap))
    return densityMap, len(dots)


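# --- Editor's illustrative sketch (not part of the original commit) ----------
# genDensity is imported from src.lib.utils.image_opt and is not shown in this
# diff. The helper below is a minimal, hypothetical stand-in that splats a
# unit-mass Gaussian at every annotated head location; the real implementation
# and the meaning of genDensity's last argument (1001 above) may well differ,
# so this is only meant to convey the idea.
from scipy.ndimage import gaussian_filter


def gen_density_sketch(image, dots, sigmas):
    h, w = image.shape[:2]
    density_map = np.zeros((h, w), dtype=np.float32)
    for (x, y), sigma in zip(dots, sigmas):
        col, row = int(round(x)), int(round(y))
        if 0 <= col < w and 0 <= row < h:
            impulse = np.zeros((h, w), dtype=np.float32)
            impulse[row, col] = 1.0
            # each annotated point contributes a Gaussian blob of total mass ~1
            density_map += gaussian_filter(impulse, float(sigma))
    return density_map
# ------------------------------------------------------------------------------

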
def test(use_gpu, model_path, image_list):
    model = loadModel(model_path, vgg16)

    for path in image_list:
        print(path)
        # this routine expects bare file names and resolves them against a hard-coded dataset directory
        path = os.path.join('/home/wl/PycharmProjects/0321/CSRNet_shanghaitechb/data/ShanghaiTech/part_B_final/test_data/images', path)
        raw = imread(path)
        # print(type(raw))
        image = normalize(ToTensor(raw))
        gt_path = path.replace('images', 'ground_truth').replace('IMG', 'GT_IMG').replace('.jpg', '.mat')
        dmap, gt_count = density(raw, gt_path)

        if use_gpu:
            torchimage = Variable(torch.FloatTensor(image).unsqueeze(dim=0).cuda())
            model = model.cuda()
        else:
            torchimage = Variable(torch.FloatTensor(image).unsqueeze(dim=0))

        densityMap, feature_map = model(torchimage)
        output = densityMap.data.cpu().numpy()

        count = np.sum(output.reshape((output.shape[0], -1)), 1)
        print('predict count:', count)
        print('gt count:', np.sum(dmap), gt_count)
        if gt_count > 200 and abs(gt_count - count) <= 5:
            showTest(image, output, dmap)


def get_feature_elsewise(feature):
    # element-wise sum over the channel dimension (equivalent to feature.sum(axis=0))
    c, h, w = feature.shape
    elsewise_mat = np.zeros((h, w))
    for i in range(c):
        elsewise_mat = elsewise_mat + feature[i, :, :]

    return elsewise_mat


def data_analysis(use_gpu, model_path, image_list, resize=0):
    model = loadModel(model_path, vgg16)
    running_mse = 0.0
    running_mae = 0.0
    show_result_flag = False

    for path in image_list:
        print(path)

        raw = imread(path)
        if resize:
            resize_img, ratio = preferred_resize(raw, resize)
        else:
            resize_img = raw
            ratio = 1

        image = normalize(ToTensor(resize_img))
        gt_path = path.replace('images', 'ground_truth').replace('IMG', 'GT_IMG').replace('.jpg', '.mat')
        dmap, gt_count = density(resize_img, gt_path, ratio)

        if use_gpu:
            torchimage = Variable(torch.FloatTensor(image).unsqueeze(dim=0).cuda())
            model = model.cuda()
        else:
            torchimage = Variable(torch.FloatTensor(image).unsqueeze(dim=0))
        # print('image type:', torchimage)
        densityMap = model(torchimage)
        output = densityMap.data.cpu().squeeze().numpy()
        print('gt count:{} predict count:{:.4f}'.format(gt_count, np.sum(output)))
        running_mae += np.abs(np.sum(output) - gt_count)
        running_mse += (np.sum(output) - gt_count) ** 2
        # print('outputshape:', output.shape)

        # if gt_count > 200 and abs(gt_count - np.sum(output)) <= 5:
        #     # showTest(image, output, dmap)
        #
        #     fig = plt.figure()
        #     image_show = (image * 127.5 + 127.5).astype(np.uint8)
        #     image_show = image_show.transpose((1, 2, 0))
        #     ax = fig.add_subplot(231)
        #     ax.set_title('original image', fontsize=14)
        #     ax.imshow(image_show)
        #     ax = fig.add_subplot(232)
        #     ax.set_title('gt count: {:.2f}'.format(np.sum(dmap)))
        #     ax.imshow(dmap, cmap='jet', interpolation='bilinear')
        #     ax = fig.add_subplot(233)
        #     ax.set_title('predict count: {:.2f}'.format(np.sum(output)))
        #     ax.imshow(output, cmap='jet', interpolation='bilinear')
        #     # ax = fig.add_subplot(334)
        #     # ax.set_title('feature map 3')
        #     # ax.imshow(get_feature_elsewise(feature_3), cmap='jet', interpolation='bilinear')
        #     #
        #     # ax = fig.add_subplot(335)
        #     # ax.set_title('feature map 8')
        #     # ax.imshow(get_feature_elsewise(feature_8), cmap='jet', interpolation='bilinear')
        #     #
        #     # ax = fig.add_subplot(336)
        #     # ax.set_title('feature map 15')
        #     # ax.imshow(get_feature_elsewise(feature_15), cmap='jet', interpolation='bilinear')
        #     ax = fig.add_subplot(234)
        #     ax.set_title('feature map 15')
        #     ax.imshow(get_feature_elsewise(feature_15), cmap='jet', interpolation='bilinear')
        #     ax = fig.add_subplot(235)
        #     ax.set_title('feature map 22')
        #     ax.imshow(get_feature_elsewise(feature_22), cmap='jet', interpolation='bilinear')
        #     ax = fig.add_subplot(236)
        #     ax.set_title('feature map 29')
        #     ax.imshow(get_feature_elsewise(feature_29), cmap='jet', interpolation='bilinear')
        #
        #     plt.show()

        # print(output)
        # print(dmap)
        # count = np.sum(output.reshape((output.shape[0], -1)), 1)
        # print('predict count:', count)
        # print('gt count:', np.sum(dmap), gt_count)
        # if gt_count > 200 and abs(gt_count - count) <= 5:
        #     showTest(image, output, dmap)
        if show_result_flag:
            fig = plt.figure()
            image_show = (image * 127.5 + 127.5).astype(np.uint8)
            image_show = image_show.transpose((1, 2, 0))
            ax = fig.add_subplot(131)
            ax.set_title('original image', fontsize=14)
            ax.imshow(image_show)
            ax = fig.add_subplot(132)
            ax.set_title('gt count: {:.2f}'.format(np.sum(dmap)))
            ax.imshow(dmap, cmap='jet', interpolation='bilinear')
            ax = fig.add_subplot(133)
            ax.set_title('predict count: {:.2f}'.format(np.sum(output)))
            ax.imshow(output, cmap='jet', interpolation='bilinear')
            plt.show()

    # mean absolute error over the test set, and root mean squared error (despite the variable name)
    epoch_mae = running_mae / len(image_list)
    epoch_mse = np.sqrt(running_mse / len(image_list))
    print('MAE:{:.4f} MSE:{:.4f}'.format(epoch_mae, epoch_mse))


if __name__ == '__main__':
    use_gpu = torch.cuda.is_available()
    model_path = 'exp/ShanghaiTech/ex/CRSNet_VGG/Generator/12600.model'
    image_path = 'data/ShanghaiTech/part_B_final/test_data/images'
    # image_list = ['IMG_105.jpg', 'IMG_153.jpg', 'IMG_23.jpg', 'IMG_283.jpg', 'IMG_306.jpg', 'IMG_78.jpg', 'IMG_93.jpg']
    image_list = []
    for file in os.listdir(image_path):
        image_list.append(os.path.join(image_path, file))

    data_analysis(use_gpu, model_path, image_list, resize=1040)
@@ -0,0 +1,98 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Written by Willy, Weiyuan

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from src.lib.opt.train import TrainModel
# from src.lib.opt.train_adv import TrainModel
from src.lib.network.cn import vgg16
import yaml
from easydict import EasyDict as edict
import getopt


def initGenFeatFromCfg(cfg_file):
    # Load the cfg parameters from the yaml file
    # (safe_load is sufficient for a plain config and avoids PyYAML's missing-Loader warning)
    with open(cfg_file, 'r') as f:
        cfg = edict(yaml.safe_load(f))

    # First read the dataset name
    dataset = cfg.DATASET

    # Paths
    data_path = cfg[dataset].DATA_PATH
    tensor_server_path = cfg[dataset].TENSOR_BOARD_PATH
    pre_trained_path = cfg[dataset].PRE_TRAINED_PATH

    # Training hyper-parameters
    batch_size = cfg[dataset].BATCH_SIZE
    lr = float(cfg[dataset].LEARNING_RATE)
    epoch_num = cfg[dataset].EPOCH_NUM
    steps = cfg[dataset].STEPS
    decay_rate = cfg[dataset].DECAY_RATE
    start_epoch = cfg[dataset].START_EPOCH
    snap_shot = cfg[dataset].SNAP_SHOT
    resize = cfg[dataset].RESIZE
    val_size = cfg[dataset].VAL_SIZE

    return (dataset, data_path, tensor_server_path, pre_trained_path, batch_size, lr, epoch_num,
            steps, decay_rate, start_epoch, snap_shot, resize, val_size)


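# --- Editor's illustrative sketch (not part of the original commit) ----------
# The repository's model/shanghaitech.yml is not included in this diff. Based
# solely on the keys read in initGenFeatFromCfg above, a config file with the
# following shape would be accepted; the dataset name and every value below
# are made-up placeholders, not the project's real settings.
_EXAMPLE_CFG_YAML = """
DATASET: SHANGHAITECH_B
SHANGHAITECH_B:
  DATA_PATH: data/ShanghaiTech/part_B_final
  TENSOR_BOARD_PATH: exp/ShanghaiTech/tensorboard
  PRE_TRAINED_PATH: exp/ShanghaiTech/pretrained/vgg16.pth
  BATCH_SIZE: 8
  LEARNING_RATE: 1e-5
  EPOCH_NUM: 400
  STEPS: [100, 200, 300]
  DECAY_RATE: 0.1
  START_EPOCH: 0
  SNAP_SHOT: 20
  RESIZE: 1040
  VAL_SIZE: 24
"""
# edict(yaml.safe_load(_EXAMPLE_CFG_YAML)) would then expose, for example,
# cfg[cfg.DATASET].BATCH_SIZE == 8, matching the accesses above.
# ------------------------------------------------------------------------------

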
def dispHelp():
    print("======================================================")
    print(" Usage")
    print("======================================================")
    print("\t-h display this message")
    print("\t--cfg <config file yaml>")


def main(argv):
    cfg_file = 'model/shanghaitech.yml'

    # Parse command-line options ("-h" takes no argument)
    try:
        opts, _ = getopt.getopt(argv, "h", ["cfg="])
    except getopt.GetoptError:
        dispHelp()
        return

    for opt, arg in opts:
        if opt == '-h':
            dispHelp()
            return
        elif opt == '--cfg':
            cfg_file = arg

    print("Loading configuration file: ", cfg_file)
    (dataset, data_path, tensor_server_path, pre_trained_path, batch_size, lr, epoch_num, steps,
     decay_rate, start_epoch, snap_shot, resize, val_size) = initGenFeatFromCfg(cfg_file)

    print("Chosen parameters:")
    print("-------------------")
    print("Dataset: ", dataset)
    print("Data location: ", data_path)
    print("Tensorboard server root: ", tensor_server_path)
    print("Pre-trained model path:", pre_trained_path)
    print("Batch size:", batch_size)
    print("Learning rate: ", lr)
    print("Total epoch number: ", epoch_num)
    print("Learning rate steps: ", steps)
    print("Learning rate decay rate: ", decay_rate)
    print("Start epoch: ", start_epoch)
    print("Snap shot: ", snap_shot)
    print("Image resize: ", resize)
    print("Validation size: ", val_size)
    print("")
    print("===================")
    print("")

    tm = TrainModel(data_path=data_path, batchsize=batch_size, lr=lr, epoch=epoch_num, snap_shot=snap_shot,
                    server_root_path=tensor_server_path, start_epoch=start_epoch, steps=steps,
                    decay_rate=decay_rate, branch=vgg16, pre_trained=pre_trained_path, resize=resize,
                    val_size=val_size)

    tm.run()


if __name__ == '__main__':
    main(sys.argv[1:])
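For reference, and assuming this training entry point is saved as src/train.py (the diff does not show the file paths), a typical invocation would be "python src/train.py --cfg model/shanghaitech.yml"; running it with no arguments falls back to the same default config path, and -h prints the usage text above.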