DenseFuse
Where to find the DenseFuse source code
GitHub repository
I reproduced it directly on PyTorch 1.12.1 + cu113; exactly what you need to change depends on the errors you get.
You can also refer to the blog post below, though it is not comprehensive. I suggest replacing the files with my code first, then seeing what problems remain.
densefuse-pytorch 图像融合代码复现记录 (notes on reproducing the densefuse-pytorch image-fusion code)
Extra packages that need to be installed
pip install torchfile
pip install scikit-image
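Before touching the code, a quick sanity check that the environment matches is worthwhile. The snippet below just confirms the packages import and prints versions; the version strings in the comments are the ones I used, treat them as a reference, not a requirement:

import torch
import torchfile, skimage, imageio

print(torch.__version__, torch.cuda.is_available())   # I used 1.12.1+cu113, True
print('skimage', skimage.__version__, '| imageio', imageio.__version__)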
Testing
Here are my modified utils.py and test_image.py, so that you can at least get the test running straight away.
utils.py
import os
from os import listdir, mkdir, sep
from os.path import join, exists, splitext
import random
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
import torchfile
from args_fusion import args
# from scipy.misc import imread, imsave, imresize
import matplotlib as mpl
import cv2
from torchvision import datasets, transforms
from skimage.transform import resize as imresize
from imageio import imwrite, imread  # imageio 2.x API; on imageio >= 3 use "from imageio.v2 import imwrite, imread"
def list_images(directory):
    images = []
    names = []
    dir = listdir(directory)
    dir.sort()
    for file in dir:
        name = file.lower()
        if name.endswith('.png'):
            images.append(join(directory, file))
        elif name.endswith('.jpg'):
            images.append(join(directory, file))
        elif name.endswith('.jpeg'):
            images.append(join(directory, file))
        name1 = name.split('.')
        names.append(name1[0])
    return images
def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    img = Image.open(filename).convert('RGB')
    # NOTE: Image.ANTIALIAS was removed in Pillow 10; use Image.Resampling.LANCZOS there
    if size is not None:
        if keep_asp:
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), Image.ANTIALIAS)
        else:
            img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img
def tensor_save_rgbimage(tensor, filename, cuda=True):
    if cuda:
        img = tensor.cpu().clamp(0, 255).data[0].numpy()
    else:
        img = tensor.clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)

def tensor_save_bgrimage(tensor, filename, cuda=False):
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    tensor_save_rgbimage(tensor, filename, cuda)
def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def matSqrt(x):
    U, D, V = torch.svd(x)
    # matrix square root needs matrix products (mm), not elementwise '*'
    return U.mm(D.pow(0.5).diag()).mm(V.t())
# load training images
def load_dataset(image_path, BATCH_SIZE, num_imgs=None):
    if num_imgs is None:
        num_imgs = len(image_path)
    original_imgs_path = image_path[:num_imgs]
    # shuffle before batching
    random.shuffle(original_imgs_path)
    mod = num_imgs % BATCH_SIZE
    print('BATCH SIZE %d.' % BATCH_SIZE)
    print('Train images number %d.' % num_imgs)
    print('Train images samples %s.' % str(num_imgs / BATCH_SIZE))
    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        original_imgs_path = original_imgs_path[:-mod]
    batches = int(len(original_imgs_path) // BATCH_SIZE)
    return original_imgs_path, batches
def get_image(path, height=256, width=256, mode='L'):
    if mode == 'L':
        image = imread(path, pilmode=mode)
    elif mode == 'RGB':
        image = Image.open(path).convert('RGB')
    if height is not None and width is not None:
        # skimage's resize has no `interp` argument: order=0 is nearest-neighbour,
        # and preserve_range keeps uint8 values instead of rescaling to [0, 1]
        image = imresize(np.asarray(image), (height, width), order=0, preserve_range=True).astype(np.uint8)
    return image
def get_train_images_auto(paths, height=256, width=256, mode='RGB'):
    if isinstance(paths, str):
        paths = [paths]
    images = []
    for path in paths:
        image = get_image(path, height, width, mode=mode)
        if mode == 'L':
            image = np.reshape(image, [1, image.shape[0], image.shape[1]])
        else:
            # HWC -> CHW: use transpose, not reshape, so pixels are not scrambled
            image = np.transpose(np.asarray(image), (2, 0, 1))
        images.append(image)
    images = np.stack(images, axis=0)
    images = torch.from_numpy(images).float()
    return images
def get_test_images(paths, height=None, width=None, mode='RGB'):
    ImageToTensor = transforms.Compose([transforms.ToTensor()])
    if isinstance(paths, str):
        paths = [paths]
    images = []
    for path in paths:
        image = get_image(path, height, width, mode=mode)
        if mode == 'L':
            image = np.reshape(image, [1, image.shape[0], image.shape[1]])
        else:
            # ToTensor scales to [0, 1]; scale back to [0, 255] to match the 'L' branch
            image = ImageToTensor(image).float().numpy() * 255
        images.append(image)
    images = np.stack(images, axis=0)
    images = torch.from_numpy(images).float()
    return images
# colormap
def colormap():
    return mpl.colors.LinearSegmentedColormap.from_list(
        'cmap', ['#FFFFFF', '#98F5FF', '#00FF00', '#FFFF00', '#FF0000', '#8B0000'], 256)
def save_images(path, data):
    # collapse a trailing single channel so cv2.imwrite gets H x W for grayscale
    if data.shape[2] == 1:
        data = data.reshape([data.shape[0], data.shape[1]])
    cv2.imwrite(path, data)
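As a quick standalone check of utils.py, you can round-trip one grayscale image through get_test_images and save_images. The input path below is just an example taken from the test layout used later; point it at any PNG you have:

import os
import utils

os.makedirs('./outputs', exist_ok=True)
imgs = utils.get_test_images('images/IV_images/IR1.png', mode='L')
print(imgs.shape)                                          # torch.Size([1, 1, H, W])
img = imgs[0].numpy().transpose(1, 2, 0).astype('uint8')   # back to H x W x 1
utils.save_images('./outputs/roundtrip.png', img)          # save_images collapses the single channel for cv2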
test_image.py
# test phase
import torch
from torch.autograd import Variable
from net import DenseFuse_net
import utils
from args_fusion import args
import numpy as np
import time
import cv2
import os
def load_model(path, input_nc, output_nc):
    nest_model = DenseFuse_net(input_nc, output_nc)
    nest_model.load_state_dict(torch.load(path))
    # rough model size: parameter count x 4 bytes (float32), reported in MB
    para = sum([np.prod(list(p.size())) for p in nest_model.parameters()])
    type_size = 4
    print('Model {} : params: {:4f}M'.format(nest_model._get_name(), para * type_size / 1000 / 1000))
    nest_model.eval()
    if args.cuda:
        nest_model.cuda()
    return nest_model
def _generate_fusion_image(model, strategy_type, img1, img2):
    # encoder
    en_r = model.encoder(img1)
    # vision_features(en_r, 'ir')
    en_v = model.encoder(img2)
    # vision_features(en_v, 'vi')
    # fusion
    f = model.fusion(en_r, en_v, strategy_type=strategy_type)
    # decoder
    img_fusion = model.decoder(f)
    return img_fusion[0]
def run_demo(model, infrared_path, visible_path, output_path_root, index, fusion_type, network_type, strategy_type, ssim_weight_str, mode):
    ir_img = utils.get_test_images(infrared_path, height=None, width=None, mode=mode)
    vis_img = utils.get_test_images(visible_path, height=None, width=None, mode=mode)
    if args.cuda:
        ir_img = ir_img.cuda()
        vis_img = vis_img.cuda()
    ir_img = Variable(ir_img, requires_grad=False)
    vis_img = Variable(vis_img, requires_grad=False)
    img_fusion = _generate_fusion_image(model, strategy_type, ir_img, vis_img)
    ############################ multi outputs ##############################################
    file_name = 'fusion_' + fusion_type + '_' + str(index) + '_network_' + network_type + '_' + strategy_type + '_' + ssim_weight_str + '.png'
    output_path = output_path_root + file_name
    # save images
    if args.cuda:
        img = img_fusion.cpu().clamp(0, 255).data[0].numpy()
    else:
        img = img_fusion.clamp(0, 255).data[0].numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    utils.save_images(output_path, img)
    print(output_path)
def vision_features(feature_maps, img_type):
    # only used if you uncomment the vision_features calls in _generate_fusion_image;
    # note that save_image_test is not defined in the utils.py above
    count = 0
    for features in feature_maps:
        count += 1
        for index in range(features.size(1)):
            file_name = 'feature_maps_' + img_type + '_level_' + str(count) + '_channel_' + str(index) + '.png'
            output_path = 'outputs/feature_maps/' + file_name
            map = features[:, index, :, :].view(1, 1, features.size(2), features.size(3))
            map = map * 255
            # save images
            utils.save_image_test(map, output_path)
def main():
    # run demo
    # test_path = "images/test-RGB/"
    test_path = "images/IV_images/"
    network_type = 'densefuse'
    fusion_type = 'auto'  # auto, fusion_layer, fusion_all
    strategy_type_list = ['addition', 'attention_weight']  # addition, attention_weight, attention_enhance, adain_fusion, channel_fusion, saliency_mask
    output_path = './outputs/'
    strategy_type = strategy_type_list[0]
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    # in_c = 3 for RGB images; in_c = 1 for gray images
    in_c = 1
    if in_c == 1:
        out_c = in_c
        mode = 'L'
        model_path = args.model_path_gray
    else:
        out_c = in_c
        mode = 'RGB'
        model_path = args.model_path_rgb
    with torch.no_grad():
        print('SSIM weight ----- ' + args.ssim_path[2])
        ssim_weight_str = args.ssim_path[2]
        model = load_model(model_path, in_c, out_c)
        for i in range(1):
            index = i + 1
            infrared_path = test_path + 'IR' + str(index) + '.png'
            visible_path = test_path + 'VIS' + str(index) + '.png'
            run_demo(model, infrared_path, visible_path, output_path, index, fusion_type, network_type, strategy_type, ssim_weight_str, mode)
    print('Done......')

if __name__ == '__main__':
    main()
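Before running python test_image.py, it is worth confirming that the files main() expects actually exist. A small pre-flight sketch; the paths mirror the defaults above, so adjust them if your layout or args_fusion.py differs:

import os
from args_fusion import args

# the IR/VIS test pair read by main() plus the pretrained grayscale model
for p in ['images/IV_images/IR1.png', 'images/IV_images/VIS1.png', args.model_path_gray]:
    print(p, '->', 'OK' if os.path.exists(p) else 'MISSING')

If everything prints OK, the fused result lands in ./outputs/ with a file name built from the fusion type, strategy, and SSIM weight string.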