import os
import sys
import pdb
import logging
import time
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model
from models.modules import block as B
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Options: parse the test configuration and prepare output directories.
# ---------------------------------------------------------------------------
opt = 'options/test/test_sr.json'
opt = option.parse(opt, is_train=False)
# Create every configured path except the pretrained-generator location,
# which points at an existing file rather than a directory to create.
util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
# Wrap the options so missing keys read as None instead of raising KeyError.
opt = option.dict_to_nonedict(opt)

# Create test dataset(s) and dataloader(s) — one loader per configured phase.
test_loaders = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    print('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
    test_loaders.append(test_loader)

# Create the model; the generator network is reachable as model.netG.
model = create_model(opt)
# Hook factory used to capture feature maps during the forward pass.
def save_feature(name, storage=None):
    """Return a forward hook that stores a module's output under *name*.

    Args:
        name: key under which the module output is recorded.
        storage: mapping to record into; defaults to the module-level
            ``featuremap`` OrderedDict (backward-compatible behavior).

    Returns:
        A callable with the ``(module, input, output)`` forward-hook
        signature expected by ``torch.nn.Module.register_forward_hook``.
    """
    def hook(module, input, output):
        # Resolve the target at call time so the default still picks up the
        # module-level `featuremap` defined later in this script.
        target = featuremap if storage is None else storage
        target[name] = output
    return hook
# Register a forward hook on every matching layer so its output feature map
# is captured into `featuremap` each time the model runs a forward pass.
# (Hooks can be registered for a module's forward or backward pass; we only
# need the forward outputs here.)
conv_idx = 0
featuremap = OrderedDict()
for m in model.netG.module.model.modules():
    # NOTE(review): '********' looks like a redacted placeholder — replace it
    # with the actual layer class name (e.g. 'Conv2d') before running.
    if m._get_name() == '********':
        conv_idx += 1
        m.register_forward_hook(save_feature('conv_' + str(conv_idx)))
# print(conv_idx)
# exit()
# Main loop: run each test set through the model and dump a grid image of the
# captured feature maps for every hooked layer.
for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt['name']
    test_start_time = time.time()
    dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
    util.mkdir(dataset_dir)

    # Metric accumulators (kept for parity with the standard test script,
    # even though this visualization path does not fill them in).
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    for data in test_loader:
        # HR ground truth is only available when a dataroot_HR was configured.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None
        model.feed_data(data, need_HR=need_HR)
        img_path = data['LR_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        model.test()  # forward pass — fills `featuremap` via the hooks
        visuals = model.get_current_visuals(need_HR=need_HR)
        sr_img = util.tensor2img(visuals['SR'])  # uint8

        # Plot each captured feature map as a 2x16 grid of channel images.
        fig_idx = 0
        from mpl_toolkits.axes_grid1 import AxesGrid
        for layer_name, feat in featuremap.items():
            # assumes feat[0] squeezes to (channels, H, W) — TODO confirm
            # against the hooked layer's actual output structure.
            channels = feat[0].squeeze().float().cpu().numpy()
            print(channels.shape)

            fig = plt.figure(figsize=(15, 5))
            grid = AxesGrid(fig, 111,
                            nrows_ncols=(2, 16),
                            axes_pad=0.05,
                            share_all=True,
                            label_mode="L",
                            cbar_location="right",
                            cbar_mode="single",
                            )
            for channel, ax in zip(channels, grid):
                im = ax.imshow(channel)
            # Single shared colorbar, keyed off the last drawn image.
            grid.cbar_axes[0].colorbar(im)
            for cax in grid.cbar_axes:
                cax.toggle_label(True)

            fig_idx += 1
            # NOTE(review): fig_idx restarts at 0 for every input image, so
            # later images overwrite earlier ones' figures — confirm whether
            # img_name should be part of the filename.
            # plt.show()
            fig.savefig(os.path.join(dataset_dir, str(fig_idx) + '.png'), dpi=400, bbox_inches='tight', transparent=True)
# from mpl_toolkits.axes_grid1 import AxesGrid
# for k,v in featuremap.items():
# vals = []
# if isinstance(v, tuple):
# for i in range(v[0].shape[1]):
# vals.append(v[0].squeeze().float().cpu().numpy()[i])
# hr = F.upsample(v[1], scale_factor=2, mode='nearest')
# for i in range(hr.shape[1]):
# vals.append(hr.squeeze().float().cpu().numpy()[i])
# else:
# for i in range(v.shape[1]):
# vals = v.squeeze().float().cpu().numpy()
# fig = plt.figure(figsize=(15,5))
# grid = AxesGrid(fig, 111,
# nrows_ncols=(2, 8),
# axes_pad=0.05,
# share_all=True,
# label_mode="L",
# cbar_location="right",
# cbar_mode="single",
# )
# for val, ax in zip(vals,grid):
# im = ax.imshow(val, vmin=0, vmax=2)
# grid.cbar_axes[0].colorbar(im)
# for cax in grid.cbar_axes:
# cax.toggle_label(True)
# plt.show()
# # fig.savefig(os.path.join(dataset_dir, k.split('/')[1] + '.png'), dpi=400, bbox_inches='tight', transparent=True)
# After running this script you can inspect the feature maps produced by each
# convolutional layer, which enables further analysis of the network.