本代码用于超分辨率或图像复原任务：比较两个模型的输出，找出并可视化二者 PSNR 差距较大的区域。
import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr, same as MATLAB's ``rgb2ycbcr``.

    Args:
        img: Array whose last axis holds the three B, G, R channels.
            ``uint8`` input is treated as [0, 255]; any other dtype is
            treated as float in [0, 1].
        only_y: If True, return only the Y channel; otherwise return the
            Y, Cb and Cr channels on the last axis.

    Returns:
        The converted array, cast back to the dtype of ``img`` (rounded for
        ``uint8`` input, rescaled to [0, 1] otherwise). ``img`` itself is
        left untouched.
    """
    in_img_type = img.dtype
    # ``astype`` returns a new array. The result used to be discarded, so the
    # scaling below was applied in place to the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two images with range [0, 255].

    Both inputs are promoted to float64 before the mean squared error is
    taken. Identical inputs give ``float('inf')``.
    """
    residual = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(residual ** 2)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))
def mse2psnr(mse):
    """Convert a mean squared error to PSNR in dB for a peak value of 1.0.

    A zero error maps to ``float('inf')``.
    """
    if mse != 0:
        return 20 * math.log10(1.0 / math.sqrt(mse))
    return float('inf')
def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay a min-max normalized heat map on an image.

    Args:
        image: Image passed to ``imshow``; its first two dimensions give the
            size the heat map is resized to. 2-D and 3-D arrays both work.
        heat_map: Array resized to the image height and width with
            ``skimage.transform.resize``.
        alpha: Opacity of the heat-map overlay.
        display: If True, show the overlay interactively with ``plt.show()``.
        save: Output path. Nothing is written when it is None.
        cmap: Colormap used for the overlay.
        axis: Passed to ``plt.axis`` in display mode only; the saved figure
            always has its axes turned off.
        dpi: Resolution of the saved figure. The figure size is W/dpi by
            H/dpi inches, so the saved image matches the input pixel size.
        verbose: If True, print the output path before saving.
    """
    height, width = image.shape[:2]

    # resize heat map
    heat_map_resized = transform.resize(heat_map, (height, width))

    # normalize heat map
    max_value = np.max(heat_map_resized)
    min_value = np.min(heat_map_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heat_map = (heat_map_resized - min_value) / value_range
    else:
        # A constant map would divide by zero and turn the overlay into NaNs.
        normalized_heat_map = np.zeros_like(heat_map_resized, dtype=np.float64)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)

        figsize = width / float(dpi), height / float(dpi)
        fig = plt.figure(figsize=figsize)
        try:
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
            ax.imshow(image)
            ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
            ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
            fig.savefig(save, dpi=dpi, transparent=True)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # a loop over many images accumulates figures in memory.
            plt.close(fig)
def to_bin(img, lower, upper):
    """Return a mask that is True where ``lower < img < upper`` (both strict)."""
    above_lower = lower < img
    below_upper = img < upper
    return above_lower & below_upper
def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Save ``im_BSL`` with a heat-map overlay and PSNR-annotated boxes.

    The heat map is resized to the image size and min-max normalized, then
    thresholded with ``to_bin(map, thres, 1.0)``. Connected regions
    (``connectivity=2``) whose bounding box covers at least 100 pixels are
    outlined, and each is labelled with
    ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` computed inside the box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images. They are multiplied by 255
            before ``calculate_psnr``, so they are presumably in [0, 1].
        heatmap: Array resized to (H, W).
        thres: Lower threshold applied to the normalized heat map.
        alpha: Opacity of the heat-map overlay.
        display: Unused; kept for interface compatibility.
        save: Output path. Nothing is drawn or written when it is None.
        cmap: Colormap of the overlay.
        axis: Passed to ``ax.axis`` after the axes are first turned off.
        dpi: Resolution of the saved figure (figure size is W/dpi by H/dpi).
        verbose: If True, print the output path before saving.
    """
    H, W, _ = im_BSL.shape

    # resize heat map
    heatmap_resized = transform.resize(heatmap, (H, W))

    # normalize heat map
    max_value = np.max(heatmap_resized)
    min_value = np.min(heatmap_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map would divide by zero; treat it as "no hot region".
        normalized_heatmap = np.zeros_like(heatmap_resized, dtype=np.float64)

    # capture regions
    bin_map = to_bin(normalized_heatmap, thres, 1.0)
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    if save is None:
        return
    if verbose:
        print('save image: ' + save)

    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(im_BSL)
        ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
        ax.axis(axis)

        for region in props:
            top, left, bottom, right = region.bbox
            box_h, box_w = bottom - top, right - left
            # Area taken from the bbox itself, so it does not depend on the
            # name of the regionprops bbox-area attribute.
            if box_h * box_w < 100:
                continue

            ax.add_patch(
                patches.Rectangle(
                    (left, top), box_w, box_h,
                    edgecolor='y',
                    linewidth=6,
                    fill=False,
                ))

            rows, cols = slice(top, bottom), slice(left, right)
            gt_patch = im_GT[rows, cols, :] * 255.
            psnr = (calculate_psnr(im_OCT[rows, cols, :] * 255., gt_patch)
                    - calculate_psnr(im_BSL[rows, cols, :] * 255., gt_patch))

            # Keep the label inside the picture near the right and top edges.
            h_aln = 'right' if W - left < 50 else 'left'
            if top < 20:
                text_y, v_aln = bottom, 'top'
            else:
                text_y, v_aln = top, 'bottom'
            ax.text(left, text_y, "{:+.2f}".format(psnr), color='r',
                    verticalalignment=v_aln, horizontalalignment=h_aln, fontsize=26)

        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
    finally:
        # pyplot keeps figures alive until closed; release this one.
        plt.close(fig)
def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_dir='/data1/cropimage/diff6_cvpr'):
    """Save the full difference image and a grid of the top hot-spot patches.

    The heat map is resized to the image size, min-max normalized and
    thresholded with ``to_bin(map, thres, 1.0)``. Connected regions
    (``connectivity=2``) whose bounding box covers at least 100 pixels are
    ranked by their mean normalized heat value. For up to five of them one
    row is drawn showing the GT, BSL, OCT and difference crops, annotated
    with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` and the bounding box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images. They are multiplied by 255
            before ``calculate_psnr``, so they are presumably RGB in [0, 1].
        heatmap: Array resized to (H, W).
        thres: Lower threshold applied to the normalized heat map.
        alpha, display, cmap, axis, dpi, verbose: Unused; kept for interface
            compatibility. The grid is always saved at 300 dpi.
        save: Output path of the patch grid. Its base name is reused for the
            difference image. Nothing is written when it is None.
        diff_dir: Directory that receives the full difference image; created
            if it does not exist.
    """
    H, W, C = im_BSL.shape

    # resize heat map
    heatmap_resized = transform.resize(heatmap, (H, W))

    # normalize heat map
    max_value = np.max(heatmap_resized)
    min_value = np.min(heatmap_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map would divide by zero; treat it as "no hot region".
        normalized_heatmap = np.zeros_like(heatmap_resized, dtype=np.float64)

    # capture regions
    bin_map = to_bin(normalized_heatmap, thres, 1.0)
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    bbox_err = []
    for region in props:
        top, left, bottom, right = region.bbox
        if (bottom - top) * (right - left) < 100:
            continue
        rows, cols = slice(top, bottom), slice(left, right)
        err = np.mean(normalized_heatmap[rows, cols])
        gt_patch = im_GT[rows, cols, :] * 255.
        psnr = (calculate_psnr(im_OCT[rows, cols, :] * 255., gt_patch)
                - calculate_psnr(im_BSL[rows, cols, :] * 255., gt_patch))
        bbox_err.append((region.bbox, err, psnr))
    bbox_err.sort(key=lambda item: item[1], reverse=True)

    # Signed difference shifted so that 0.5 means "no change".
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)

    if save is None:
        return

    # The file name comes from ``save`` rather than a module-level global, so
    # the function works regardless of the caller's variable names.
    os.makedirs(diff_dir, exist_ok=True)
    diff_path = os.path.join(diff_dir, os.path.basename(save))
    # Channels are reversed because cv2.imwrite expects BGR order.
    if not cv2.imwrite(diff_path, (im_diff * 255)[:, :, [2, 1, 0]]):
        print('failed to write diff image: ' + diff_path)

    num_bbox = min(len(bbox_err), 5)
    if num_bbox == 0:
        # plt.subplots rejects nrows=0, so there is no grid to draw.
        return

    # Plot patches. squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(nrows=num_bbox, ncols=4, figsize=(15, 15), squeeze=False)
    try:
        for row_axes, (bbox, _err, psnr) in zip(axes, bbox_err):
            top, left, bottom, right = bbox
            rows, cols = slice(top, bottom), slice(left, right)
            row_axes[0].imshow(im_GT[rows, cols, :])
            row_axes[1].imshow(im_BSL[rows, cols, :])
            row_axes[2].imshow(im_OCT[rows, cols, :])
            row_axes[3].imshow(im_diff[rows, cols, :])
            # NOTE: an earlier revision also dumped each crop, resized to
            # 100x100, into per-kind folders; that code was disabled and has
            # been dropped here.
            row_axes[3].text(right - left, bottom - top,
                             "{:+.2f}".format(psnr), color='r', fontsize=16)
            row_axes[3].text(right - left, 0,
                             "{}".format(bbox), color='r', fontsize=16)
        fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    finally:
        # pyplot keeps figures alive until closed; release this one.
        plt.close(fig)
# ---- Configuration ---------------------------------------------------------
folder_BSL = "/data1/results10.09/(v1)Layer_HRLR_withoutconnection_SRResNet_16B64C_alpha=0.5/DIV2K_VAL/"
folder_OCT = "/data1/results/(ture)_1X1_directshare_SRResNet_44B64C_alpha=0.5/DIV2K_VAL/"
folder_GT = '/data1/data/DIV2K_VAL/DIV2K_valid_HR/'
crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_OCT + '/*'))[:100]
if test_Y:
    print('Testing Y channel.')
else:
    print('Testing RGB channels.')
patch_size = 32
stride = 10

save_dir = '/data1/cropimage/diff_cvpr/'
save_dir6 = '/data1/cropimage/heatdiff_cvpr/'
# Create the output folders once, up front; saving into a missing folder fails.
for _out_dir in (save_dir, save_dir6):
    os.makedirs(_out_dir, exist_ok=True)


def _read_rgb01(path):
    """Read an image with OpenCV and return it as RGB floats scaled by 1/255."""
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals an unreadable file by returning None.
        raise FileNotFoundError('cannot read image: ' + path)
    return bgr[:, :, [2, 1, 0]] / 255.


# ---- Per-image processing --------------------------------------------------
for img_path in img_list:
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    out_name = base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png'

    im_OCT = _read_rgb01(img_path)
    im_BSL = _read_rgb01(os.path.join(
        folder_BSL,
        base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', 'SRResNet_16B64C') + suffix + '.png'))
    im_GT = _read_rgb01(os.path.join(
        folder_GT,
        base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '').replace('_bicLRx4', '') + suffix + '.png'))

    H, W, C = im_OCT.shape
    H_axis = np.arange(0, H - patch_size, stride)
    W_axis = np.arange(0, W - patch_size, stride)
    err_map = np.zeros((len(H_axis), len(W_axis)))
    inv_map = np.zeros((len(H_axis), len(W_axis)))

    # Squared difference computed once and sliced per patch below.
    sq_diff = (im_OCT - im_BSL) ** 2
    total_err = np.mean(sq_diff)
    if total_err == 0:
        # Identical outputs: every map entry would be inf and nothing could
        # be visualised.
        print('skip (no difference): ' + base_name)
        continue

    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            # Contribution of this patch to the whole-image MSE.
            patch_err = np.sum(sq_diff[h:h+patch_size, w:w+patch_size, :]) / (H*W*C)
            err_map[i, j] = mse2psnr(patch_err)
            # PSNR of the image with this patch's error removed. The clamp
            # guards against rounding pushing the remainder below zero.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))

    save_path = os.path.join(save_dir, out_name)
    save_path6 = os.path.join(save_dir6, out_name)
    #plot_heatmap(im_BSL, inv_map, alpha=0.7, save=save_path, axis='off', display=False)
    plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path6, axis='off', display=False)
    plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path, axis='off', display=False)
改进版:
import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr, same as MATLAB's ``rgb2ycbcr``.

    Args:
        img: Array whose last axis holds the three B, G, R channels.
            ``uint8`` input is treated as [0, 255]; any other dtype is
            treated as float in [0, 1].
        only_y: If True, return only the Y channel; otherwise return the
            Y, Cb and Cr channels on the last axis.

    Returns:
        The converted array, cast back to the dtype of ``img`` (rounded for
        ``uint8`` input, rescaled to [0, 1] otherwise). ``img`` itself is
        left untouched.
    """
    in_img_type = img.dtype
    # ``astype`` returns a new array. The result used to be discarded, so the
    # scaling below was applied in place to the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two images with range [0, 255].

    Both inputs are promoted to float64 before the mean squared error is
    taken. Identical inputs give ``float('inf')``.
    """
    residual = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(residual ** 2)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))
def mse2psnr(mse):
    """Convert a mean squared error to PSNR in dB for a peak value of 1.0.

    A zero error maps to ``float('inf')``.
    """
    if mse != 0:
        return 20 * math.log10(1.0 / math.sqrt(mse))
    return float('inf')
def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay a min-max normalized heat map on an image.

    Args:
        image: Image passed to ``imshow``; its first two dimensions give the
            size the heat map is resized to. 2-D and 3-D arrays both work.
        heat_map: Array resized to the image height and width with
            ``skimage.transform.resize``.
        alpha: Opacity of the heat-map overlay.
        display: If True, show the overlay interactively with ``plt.show()``.
        save: Output path. Nothing is written when it is None.
        cmap: Colormap used for the overlay.
        axis: Passed to ``plt.axis`` in display mode only; the saved figure
            always has its axes turned off.
        dpi: Resolution of the saved figure. The figure size is W/dpi by
            H/dpi inches, so the saved image matches the input pixel size.
        verbose: If True, print the output path before saving.
    """
    height, width = image.shape[:2]

    # resize heat map
    heat_map_resized = transform.resize(heat_map, (height, width))

    # normalize heat map
    max_value = np.max(heat_map_resized)
    min_value = np.min(heat_map_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heat_map = (heat_map_resized - min_value) / value_range
    else:
        # A constant map would divide by zero and turn the overlay into NaNs.
        normalized_heat_map = np.zeros_like(heat_map_resized, dtype=np.float64)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)

        figsize = width / float(dpi), height / float(dpi)
        fig = plt.figure(figsize=figsize)
        try:
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
            ax.imshow(image)
            ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
            ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
            fig.savefig(save, dpi=dpi, transparent=True)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # a loop over many images accumulates figures in memory.
            plt.close(fig)
def to_bin(img, lower, upper):
    """Return a mask that is True where ``lower < img < upper`` (both strict)."""
    above_lower = lower < img
    below_upper = img < upper
    return above_lower & below_upper
def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Save ``im_BSL`` with a heat-map overlay and PSNR-annotated boxes.

    The heat map is resized to the image size and min-max normalized, then
    thresholded with ``to_bin(map, thres, 1.0)``. Connected regions
    (``connectivity=2``) whose bounding box covers at least 100 pixels are
    outlined, and each is labelled with
    ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` computed inside the box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images. They are multiplied by 255
            before ``calculate_psnr``, so they are presumably in [0, 1].
        heatmap: Array resized to (H, W).
        thres: Lower threshold applied to the normalized heat map.
        alpha: Opacity of the heat-map overlay.
        display: Unused; kept for interface compatibility.
        save: Output path. Nothing is drawn or written when it is None.
        cmap: Colormap of the overlay.
        axis: Passed to ``ax.axis`` after the axes are first turned off.
        dpi: Resolution of the saved figure (figure size is W/dpi by H/dpi).
        verbose: If True, print the output path before saving.
    """
    H, W, _ = im_BSL.shape

    # resize heat map
    heatmap_resized = transform.resize(heatmap, (H, W))

    # normalize heat map
    max_value = np.max(heatmap_resized)
    min_value = np.min(heatmap_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map would divide by zero; treat it as "no hot region".
        normalized_heatmap = np.zeros_like(heatmap_resized, dtype=np.float64)

    # capture regions
    bin_map = to_bin(normalized_heatmap, thres, 1.0)
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    if save is None:
        return
    if verbose:
        print('save image: ' + save)

    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(im_BSL)
        ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
        ax.axis(axis)

        for region in props:
            top, left, bottom, right = region.bbox
            box_h, box_w = bottom - top, right - left
            # Area taken from the bbox itself, so it does not depend on the
            # name of the regionprops bbox-area attribute.
            if box_h * box_w < 100:
                continue

            ax.add_patch(
                patches.Rectangle(
                    (left, top), box_w, box_h,
                    edgecolor='y',
                    linewidth=6,
                    fill=False,
                ))

            rows, cols = slice(top, bottom), slice(left, right)
            gt_patch = im_GT[rows, cols, :] * 255.
            psnr = (calculate_psnr(im_OCT[rows, cols, :] * 255., gt_patch)
                    - calculate_psnr(im_BSL[rows, cols, :] * 255., gt_patch))

            # Keep the label inside the picture near the right and top edges.
            h_aln = 'right' if W - left < 50 else 'left'
            if top < 20:
                text_y, v_aln = bottom, 'top'
            else:
                text_y, v_aln = top, 'bottom'
            ax.text(left, text_y, "{:+.2f}".format(psnr), color='r',
                    verticalalignment=v_aln, horizontalalignment=h_aln, fontsize=26)

        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
    finally:
        # pyplot keeps figures alive until closed; release this one.
        plt.close(fig)
def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_dir='/data1/cropimage/diff6_cvpr'):
    """Save the full difference image and a grid of the top hot-spot patches.

    The heat map is resized to the image size, min-max normalized and
    thresholded with ``to_bin(map, thres, 1.0)``. Connected regions
    (``connectivity=2``) whose bounding box covers at least 100 pixels are
    ranked by their mean normalized heat value. For up to five of them one
    row is drawn showing the GT, BSL, OCT and difference crops, annotated
    with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` and the bounding box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images. They are multiplied by 255
            before ``calculate_psnr``, so they are presumably RGB in [0, 1].
        heatmap: Array resized to (H, W).
        thres: Lower threshold applied to the normalized heat map.
        alpha, display, cmap, axis, dpi, verbose: Unused; kept for interface
            compatibility. The grid is always saved at 300 dpi.
        save: Output path of the patch grid. Its base name is reused for the
            difference image. Nothing is written when it is None.
        diff_dir: Directory that receives the full difference image; created
            if it does not exist.
    """
    H, W, C = im_BSL.shape

    # resize heat map
    heatmap_resized = transform.resize(heatmap, (H, W))

    # normalize heat map
    max_value = np.max(heatmap_resized)
    min_value = np.min(heatmap_resized)
    value_range = max_value - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map would divide by zero; treat it as "no hot region".
        normalized_heatmap = np.zeros_like(heatmap_resized, dtype=np.float64)

    # capture regions
    bin_map = to_bin(normalized_heatmap, thres, 1.0)
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    bbox_err = []
    for region in props:
        top, left, bottom, right = region.bbox
        if (bottom - top) * (right - left) < 100:
            continue
        rows, cols = slice(top, bottom), slice(left, right)
        err = np.mean(normalized_heatmap[rows, cols])
        gt_patch = im_GT[rows, cols, :] * 255.
        psnr = (calculate_psnr(im_OCT[rows, cols, :] * 255., gt_patch)
                - calculate_psnr(im_BSL[rows, cols, :] * 255., gt_patch))
        bbox_err.append((region.bbox, err, psnr))
    bbox_err.sort(key=lambda item: item[1], reverse=True)

    # Signed difference shifted so that 0.5 means "no change".
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)

    if save is None:
        return

    # The file name comes from ``save`` rather than a module-level global, so
    # the function works regardless of the caller's variable names.
    os.makedirs(diff_dir, exist_ok=True)
    diff_path = os.path.join(diff_dir, os.path.basename(save))
    # Channels are reversed because cv2.imwrite expects BGR order.
    if not cv2.imwrite(diff_path, (im_diff * 255)[:, :, [2, 1, 0]]):
        print('failed to write diff image: ' + diff_path)

    num_bbox = min(len(bbox_err), 5)
    if num_bbox == 0:
        # plt.subplots rejects nrows=0, so there is no grid to draw.
        return

    # Plot patches. squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(nrows=num_bbox, ncols=4, figsize=(15, 15), squeeze=False)
    try:
        for row_axes, (bbox, _err, psnr) in zip(axes, bbox_err):
            top, left, bottom, right = bbox
            rows, cols = slice(top, bottom), slice(left, right)
            row_axes[0].imshow(im_GT[rows, cols, :])
            row_axes[1].imshow(im_BSL[rows, cols, :])
            row_axes[2].imshow(im_OCT[rows, cols, :])
            row_axes[3].imshow(im_diff[rows, cols, :])
            # NOTE: an earlier revision also dumped each crop, resized to
            # 100x100, into per-kind folders; that code was disabled and has
            # been dropped here.
            row_axes[3].text(right - left, bottom - top,
                             "{:+.2f}".format(psnr), color='r', fontsize=16)
            row_axes[3].text(right - left, 0,
                             "{}".format(bbox), color='r', fontsize=16)
        fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    finally:
        # pyplot keeps figures alive until closed; release this one.
        plt.close(fig)
# ---- Configuration ---------------------------------------------------------
folder_BSL = "/data1/results10.09/(multi_scale)_SRResNet_16B64C/DIV2K_VAL0.8/"
folder_OCT ="/data1/results/(multi)1X1_directshare_SRResNet_48B64C_alpha=0.5/DIV2K_VAL0.8/"
folder_GT = "/data1/data/multiscale_dataset/DIV2K_valid_HR_0.8/"
crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_OCT + '/*'))[:100]
if test_Y:
    print('Testing Y channel.')
else:
    print('Testing RGB channels.')
patch_size = 32
stride = 10

save_dir = '/data1/cropimage/heatmap/diff_DIV2K_VAL0.8/'
save_dir6 = '/data1/cropimage/heatmap/heatdiff_DIV2K_VAL0.8/'
# Create the output folders once, up front. makedirs also creates missing
# parent folders, which os.mkdir does not.
for _out_dir in (save_dir, save_dir6):
    os.makedirs(_out_dir, exist_ok=True)


def _read_rgb01(path):
    """Read an image with OpenCV and return it as RGB floats scaled by 1/255."""
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals an unreadable file by returning None.
        raise FileNotFoundError('cannot read image: ' + path)
    return bgr[:, :, [2, 1, 0]] / 255.


# ---- Per-image processing --------------------------------------------------
for img_path in img_list:
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    out_name = base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png'

    im_OCT = _read_rgb01(img_path)
    im_BSL = _read_rgb01(os.path.join(
        folder_BSL,
        base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', 'SRResNet_16B64C') + suffix + '.png'))
    # NOTE(review): the folders are named "0.8" but the GT file name is mapped
    # to "_bicLRx0.6" - confirm this matches the files on disk.
    im_GT = _read_rgb01(os.path.join(
        folder_GT,
        base_name.replace('_bicLRx4', '_bicLRx0.6') + suffix + '.png'))

    H, W, C = im_OCT.shape
    H_axis = np.arange(0, H - patch_size, stride)
    W_axis = np.arange(0, W - patch_size, stride)
    err_map = np.zeros((len(H_axis), len(W_axis)))
    inv_map = np.zeros((len(H_axis), len(W_axis)))

    # Squared difference computed once and sliced per patch below.
    sq_diff = (im_OCT - im_BSL) ** 2
    total_err = np.mean(sq_diff)
    if total_err == 0:
        # Identical outputs: every map entry would be inf and nothing could
        # be visualised.
        print('skip (no difference): ' + base_name)
        continue

    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            # Contribution of this patch to the whole-image MSE.
            patch_err = np.sum(sq_diff[h:h+patch_size, w:w+patch_size, :]) / (H*W*C)
            err_map[i, j] = mse2psnr(patch_err)
            # PSNR of the image with this patch's error removed. The clamp
            # guards against rounding pushing the remainder below zero.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))

    save_path = os.path.join(save_dir, out_name)
    save_path6 = os.path.join(save_dir6, out_name)
    #plot_heatmap(im_BSL, inv_map, alpha=0.7, save=save_path, axis='off', display=False)
    plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path6, axis='off', display=False)
    plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path, axis='off', display=False)