-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
158 lines (119 loc) · 5.48 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import argparse
import numpy as np
import random
import time
from os.path import join
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from dataset import *
try:
    # Prefer ruamel.yaml when it is installed; otherwise fall back to PyYAML.
    # Only a missing package should trigger the fallback -- a bare ``except:``
    # would also swallow KeyboardInterrupt/SystemExit and genuine errors
    # raised while importing ruamel.
    from ruamel import yaml
except ImportError:
    import yaml
from easydict import EasyDict as edict
from PIL import Image, ImageOps
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
import skimage
from skimage import io,transform
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from networks import NetG
def str_to_list(x):
    """Parse a comma-separated string of integers into a list.

    Example: ``"1,2,3"`` -> ``[1, 2, 3]``. Whitespace around each field is
    accepted because ``int()`` ignores it.

    Raises:
        ValueError: if any comma-separated field is not an integer literal
            (including an empty field, e.g. from ``""`` or ``"1,,2"``).
    """
    return [int(xi) for xi in x.split(',')]
def enable_path(path):
    """Create directory ``path``, including any missing parents, if needed.

    An already-existing directory is not an error. Other failures (for
    example permission denied, or ``path`` already existing as a regular
    file) are raised here. The previous ``except OSError: pass`` hid them
    until a later file write into the directory failed with a less obvious
    error.
    """
    os.makedirs(path, exist_ok=True)
def get_config(parser):
    """Parse command-line arguments and merge them over the YAML config file.

    Args:
        parser: an ``argparse.ArgumentParser`` that defines (at least)
            ``--config_file``, the path of the YAML file to load.

    Returns:
        An ``EasyDict`` holding every key from the YAML file, updated with
        every parsed command-line argument. On a key collision the
        command-line value wins.
    """
    args = edict(vars(parser.parse_args()))
    with open(args.config_file, 'r', encoding='utf-8') as stream:
        # A bare ``yaml.load(stream)`` no longer works: PyYAML >= 6 requires
        # an explicit Loader, and ruamel.yaml >= 0.18 removed the
        # module-level ``load``. Use each library's *safe* loader, which also
        # refuses to construct arbitrary Python objects from the file.
        if hasattr(yaml, 'YAML'):  # ruamel.yaml (new-style API)
            loaded = yaml.YAML(typ='safe').load(stream)
        else:  # PyYAML (or an old ruamel.yaml exposing safe_load)
            loaded = yaml.safe_load(stream)
    config = edict(loaded)
    config.update(args)
    return config
def _to_uint8_hwc(batch):
    """Return ``batch[0]`` of an (N, C, H, W) tensor as an (H, W, C) uint8 array.

    Values are expected nominally in [0, 1]. They are scaled by 255, rounded
    half-up and clipped to [0, 255], the same scale/round/clamp that
    ``torchvision.utils.save_image`` applies, so metrics are measured on the
    pixels that end up on disk.
    """
    arr = batch.numpy().transpose((0, 2, 3, 1))[0] * 255.
    # Clip before casting: a float -> uint8 cast of an out-of-range value
    # wraps around (e.g. 256 -> 0, -1 -> 255) instead of saturating. Round
    # rather than truncate: float round-off (e.g. from undoing a Normalize)
    # can leave an exact pixel value k a hair below k, which truncation
    # would turn into k - 1.
    return np.clip(arr + 0.5, 0, 255).astype(np.uint8)


def _ssim(img, gt, multichannel):
    """Return SSIM with ``data_range=255``, across scikit-image API versions."""
    if not multichannel:
        return compare_ssim(img, gt, data_range=255)
    try:
        return compare_ssim(img, gt, data_range=255, channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 has no ``channel_axis``; it used ``multichannel``.
        return compare_ssim(img, gt, data_range=255, multichannel=True)


def compute_metrics(img, gt, multichannel=True):
    """Compute PSNR and SSIM between the first image of two batches.

    Args:
        img: predicted batch, a CPU tensor laid out (N, C, H, W) with values
            nominally in [0, 1]. Only ``img[0]`` is scored.
        gt: ground-truth batch with the same layout as ``img``.
        multichannel: if True, score the images on all channels directly.
            If False, convert both from RGB to YCbCr and score only the
            Y channel.

    Returns:
        ``(psnr, ssim)``, both computed with ``data_range=255``.
    """
    img = _to_uint8_hwc(img)
    gt = _to_uint8_hwc(gt)
    if not multichannel:
        gt = skimage.color.rgb2ycbcr(gt)[:, :, 0]
        img = skimage.color.rgb2ycbcr(img)[:, :, 0]
    cur_psnr = compare_psnr(img, gt, data_range=255)
    cur_ssim = _ssim(img, gt, multichannel)
    return cur_psnr, cur_ssim
def _evaluate_paired(model, config, test_list, num_dense=10):
    """Run ``model`` on paired test images, save its outputs and print PSNR/SSIM.

    Args:
        model: the generator to evaluate; inputs are moved to CUDA before the
            forward pass.
        config: config providing ``testroot`` (input directory) and
            ``save_image_path`` (output directory).
        test_list: file names under ``config.testroot``. Each output is saved
            under the name at the same index.
        num_dense: number of leading entries reported as the "Dense" subset.
            The remaining entries are reported as "Sparse".

    Returns:
        ``(mean_psnr, mean_ssim)`` over all test images.
    """
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    test_set = DoubleImageDataset(
        test_list, config.testroot, crop_height=None, output_height=None,
        is_random_crop=False, is_mirror=False, normalize=normalize)
    # batch_size=1 keeps the loader index equal to the index into test_list,
    # which is what names each saved image.
    test_data_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    psnrs, ssims = [], []
    enable_path(config.save_image_path)
    model.eval()
    with torch.no_grad():
        for iteration, (data, label) in enumerate(test_data_loader):
            if data.dim() == 3:
                data, label = data.unsqueeze(0), label.unsqueeze(0)
            fake = model(data.cuda())
            # Invert the Normalize above (x -> x * 0.5 + 0.5).
            fake = (fake * 0.5 + 0.5).cpu()
            label = label * 0.5 + 0.5
            for i in range(fake.shape[0]):
                psnr, ssim = compute_metrics(fake[i:i + 1], label[i:i + 1])
                psnrs.append(psnr)
                ssims.append(ssim)
            vutils.save_image(fake, os.path.join(config.save_image_path, test_list[iteration]))
    # NOTE(review): assumes the first `num_dense` files (in the order given)
    # form the "Dense" subset -- confirm against the test-set naming.
    print('Dense:\tPSNR: {:.2f}, SSIM: {:.4f}'.format(np.mean(psnrs[:num_dense]), np.mean(ssims[:num_dense])))
    print('Sparse:\tPSNR: {:.2f}, SSIM: {:.4f}'.format(np.mean(psnrs[num_dense:]), np.mean(ssims[num_dense:])))
    print('Average:\tPSNR: {:.2f}, SSIM: {:.4f}'.format(np.mean(psnrs), np.mean(ssims)))
    return np.mean(psnrs), np.mean(ssims)


def _run_real(model, config, test_list):
    """Run ``model`` on unpaired images and save the outputs (no metrics).

    Each input is bicubically resized up to the next multiple of 8 in height
    and width, passed through the model, and resized back to its original
    size before saving to ``config.save_image_path`` under its original name.
    """
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    real_set = SingleImageDataset(
        test_list, config.testroot, crop_height=None, output_height=None,
        is_random_crop=False, is_mirror=False, normalize=normalize)
    real_dataloader = DataLoader(real_set, batch_size=1, shuffle=False)
    enable_path(config.save_image_path)
    model.eval()
    with torch.no_grad():
        for iteration, data in enumerate(real_dataloader):
            if data.dim() == 3:
                data = data.unsqueeze(0)
            data = data.cuda()
            h, w = data.shape[-2:]
            # NOTE(review): presumably the network needs spatial sizes divisible
            # by 8 (e.g. three 2x downsamplings) -- confirm against NetG.
            h1 = math.ceil(h / 8.) * 8
            w1 = math.ceil(w / 8.) * 8
            resized = (h1, w1) != (h, w)
            if resized:
                data = F.interpolate(data, (h1, w1), mode='bicubic', align_corners=False)
            fake = model(data)
            if resized:
                fake = F.interpolate(fake, (h, w), mode='bicubic', align_corners=False)
            vutils.save_image(fake * 0.5 + 0.5, os.path.join(config.save_image_path, test_list[iteration]))


def main(config):
    """Build the generator, load weights, then run evaluation or inference.

    ``config`` is the merged YAML/CLI config. Besides passing it to ``NetG``,
    this reads ``pretrained`` (optional checkpoint path), ``testroot``,
    ``save_image_path`` and ``test_real``. If ``test_real`` is truthy, the
    unpaired real-image inference runs. Otherwise paired evaluation with
    PSNR/SSIM runs.
    """
    model = NetG(config).cuda()
    if config.pretrained:
        # Load onto the CPU first so a checkpoint saved on any GPU index can be
        # read. load_state_dict then copies the weights into the CUDA model.
        state = torch.load(config.pretrained, map_location='cpu')
        model.load_state_dict(state)
    # os.listdir order is arbitrary. Sort it so that the processing order,
    # and the Dense/Sparse split in paired evaluation, are reproducible.
    test_list = sorted(os.listdir(config.testroot))
    if config.test_real:
        _run_real(model, config, test_list)
    else:
        _evaluate_paired(model, config, test_list)
if __name__ == '__main__':
    # Command-line entry point. The only flag is the location of the YAML
    # config; get_config merges the parsed flags over that file's contents.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        '--config_file',
        type=str,
        default="test.yaml",
        help='the path of config file',
    )
    main(get_config(cli_parser))