
No increase in resolution #4

Open
pra-dan opened this issue Aug 8, 2020 · 2 comments

pra-dan commented Aug 8, 2020

Hi, I ran inference on a test image
Phot

using the pre-trained weights you provided, and got two results:
Phot_bicubic_x3
and

Phot_espcn_x3

But I see no improvement: the resolution remained almost the same for all three images, 579x1032 pixels.

What went wrong?

@WilsonFok2020

test.py is intended for evaluation only. It takes a high-resolution image, down-samples it to low resolution, and uses ESPCN to restore the resolution. For comparison's sake, it computes the peak signal-to-noise ratio (PSNR) against the original input and also shows the effect of plain bicubic interpolation. That is why both the bicubic and ESPCN outputs have the same resolution as the input.
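For reference, the PSNR comparison that test.py performs can be sketched with a minimal pure-Python stand-in (`calc_psnr_simple` is a hypothetical helper for illustration, not the repo's `calc_psnr`, which operates on tensors):

```python
import math

def calc_psnr_simple(img1, img2, max_val=255.0):
    """PSNR between two equal-sized images given as flat lists of pixel values."""
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * math.log10(max_val ** 2 / mse)

original = [10, 20, 30, 40]   # ground-truth high-resolution pixels
restored = [12, 19, 31, 38]   # imperfect restoration
print(round(calc_psnr_simple(original, restored), 2))  # ≈ 44.15
```

A higher PSNR means the restoration is closer to the ground truth, which is only measurable when a ground-truth high-resolution image exists, hence the down-sampling step in test.py.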

To achieve what you are asking, I suggest you try the following code that I have written based on the test.py script.

@WilsonFok2020

import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import ESPCN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnr

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device = ', device)

    model = ESPCN(scale_factor=args.scale).to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    # Unlike test.py, treat the input image itself as the low-resolution input.
    # image = pil_image.open(args.image_file).convert('RGB')
    lr = pil_image.open(args.image_file).convert('RGB')

    # image_width = (image.width // args.scale) * args.scale
    # image_height = (image.height // args.scale) * args.scale

    # hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    # lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
    # bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

    lr, _ = preprocess(lr, device)
    # hr, _ = preprocess(hr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    # PSNR is skipped here: there is no ground-truth hr to compare against.
    # psnr = calc_psnr(hr, preds)
    # print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save(args.image_file.replace('.', '_espcn_x{}.'.format(args.scale)))
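A note on the last few lines: ESPCN super-resolves only the luma (Y) channel, while the Cb and Cr chroma planes are taken from the bicubic upscale before everything is converted back to RGB. The channel stacking can be illustrated in isolation (the shapes below are made up for the demo):

```python
import numpy as np

# Hypothetical output sizes; the model emits a single-channel luma image.
H, W = 6, 8
preds = np.random.rand(H, W) * 255.0      # super-resolved Y channel
ycbcr = np.random.rand(H, W, 3) * 255.0   # YCbCr planes of the bicubic upscale

# Replace only the Y plane; keep the bicubic Cb and Cr planes.
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
print(output.shape)  # (6, 8, 3)
```

Human vision is far more sensitive to luma than chroma, so restoring only Y with the network while borrowing bicubic chroma is a common and cheap design choice in super-resolution pipelines.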
