Unet Project Walkthrough (5): Data Packaging, Data Loading, and Data Display

Project GitHub page: https://github.com/orobix/retina-unet

Reference paper: Retina blood vessel segmentation with a convolution neural network (U-net)


1. Packaging the Data as HDF5

import os
import h5py
import numpy as np
from PIL import Image

def write_hdf5(arr,outfile):  # arr: the data array; outfile: where to save the HDF5 file
  with h5py.File(outfile,"w") as f:
    f.create_dataset("image", data=arr, dtype=arr.dtype)

# training data locations: images, ground truth, border masks
original_imgs_train = "./DRIVE/training/images/"
groundTruth_imgs_train = "./DRIVE/training/1st_manual/"
borderMasks_imgs_train = "./DRIVE/training/mask/"
# test data locations: images, ground truth, border masks
original_imgs_test = "./DRIVE/test/images/"
groundTruth_imgs_test = "./DRIVE/test/1st_manual/"
borderMasks_imgs_test = "./DRIVE/test/mask/"
# where the packaged datasets are saved
dataset_path = "./datasets_training_testing/"

Nimgs = 20     # DRIVE provides 20 training and 20 test images
channels = 3
height = 584   # DRIVE images are 565 pixels wide by 584 pixels tall
width = 565

def get_datasets(imgs_dir,groundTruth_dir,borderMasks_dir,train_test="null"):
    imgs = np.empty((Nimgs,height,width,channels))
    groundTruth = np.empty((Nimgs,height,width))   # binary image, channels=1
    border_masks = np.empty((Nimgs,height,width))  # binary image, channels=1
    for path, subdirs, files in os.walk(imgs_dir):  # path: current dir; subdirs: subfolders; files: all files in the folder
        for i in range(len(files)):  # len(files) = number of images
            print ("original image: " + files[i])
            img = Image.open(imgs_dir+files[i])   # read the image into memory
            imgs[i] = np.asarray(img)             # convert to a numpy array

            groundTruth_name = files[i][0:2] + "_manual1.gif"
            print ("ground truth name: " + groundTruth_name)
            g_truth = Image.open(groundTruth_dir + groundTruth_name)
            groundTruth[i] = np.asarray(g_truth)

            border_masks_name = ""
            if train_test=="train":
                border_masks_name = files[i][0:2] + "_training_mask.gif"
            elif train_test=="test":
                border_masks_name = files[i][0:2] + "_test_mask.gif"
            else:
                print "please specify if train or test!!"
                exit()
            print ("border masks name: " + border_masks_name)
            b_mask = Image.open(borderMasks_dir + border_masks_name)
            border_masks[i] = np.asarray(b_mask)

    print ("imgs max: " +str(np.max(imgs)))
    print ("imgs min: " +str(np.min(imgs)))
    assert(np.max(groundTruth)==255 and np.max(border_masks)==255) # 断言判断
    assert(np.min(groundTruth)==0 and np.min(border_masks)==0)
	# 调整张量格式 [Nimg, channels, height, width]
    imgs = np.transpose(imgs,(0,3,1,2)) 
    groundTruth = np.reshape(groundTruth,(Nimgs,1,height,width))
    border_masks = np.reshape(border_masks,(Nimgs,1,height,width))
	# 检查张量格式
	assert(imgs.shape == (Nimgs,channels,height,width)) 
    assert(groundTruth.shape == (Nimgs,1,height,width))
    assert(border_masks.shape == (Nimgs,1,height,width))
    return imgs, groundTruth, border_masks

if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)
# package the training set
imgs_train, groundTruth_train, border_masks_train = \
        get_datasets(original_imgs_train,groundTruth_imgs_train,borderMasks_imgs_train,"train")
print ("saving train datasets ... ...")
write_hdf5(imgs_train, dataset_path + "imgs_train.hdf5")
write_hdf5(groundTruth_train, dataset_path + "groundTruth_train.hdf5")
write_hdf5(border_masks_train,dataset_path + "borderMasks_train.hdf5")

# package the test set
imgs_test, groundTruth_test, border_masks_test = \
        get_datasets(original_imgs_test,groundTruth_imgs_test,borderMasks_imgs_test,"test")
print ("saving test datasets ... ...")
write_hdf5(imgs_test,dataset_path + "DRIVE_dataset_imgs_test.hdf5")
write_hdf5(groundTruth_test, dataset_path + "DRIVE_dataset_groundTruth_test.hdf5")
write_hdf5(border_masks_test,dataset_path + "DRIVE_dataset_borderMasks_test.hdf5")
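After the script runs, the saved files can be spot-checked by opening one of them with h5py directly. A minimal sketch, reading one of the files written above:

import h5py

# sanity check on one of the packaged files
with h5py.File("./datasets_training_testing/imgs_train.hdf5", "r") as f:
    print(f["image"].shape)   # expected: (20, 3, 584, 565)
    print(f["image"].dtype)   # float64, since np.empty defaults to float64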

2. Writing and Loading HDF5 Files

def write_hdf5(arr,outfile):
  with h5py.File(outfile,"w") as f:
    f.create_dataset("image", data=arr, dtype=arr.dtype)

def load_hdf5(infile):
  with h5py.File(infile,"r") as f:  # "image" is the dataset key chosen at write time (key-value)
    return f["image"][()]           # usage: train_imgs_original = load_hdf5(file_dir)

3. RGB-to-Grayscale Conversion

# convert RGB images to grayscale
def rgb2gray(rgb):
    assert (len(rgb.shape)==4)  # [Nimgs, channels, height, width]
    assert (rgb.shape[1]==3)    # make sure the input really is RGB
    bn_imgs = rgb[:,0,:,:]*0.299 + rgb[:,1,:,:]*0.587 + rgb[:,2,:,:]*0.114
    bn_imgs = np.reshape(bn_imgs,(rgb.shape[0],1,rgb.shape[2],rgb.shape[3]))  # keep the tensor 4-D
    return bn_imgs
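As a usage sketch (assuming the training images were packaged as in section 1), the conversion keeps the tensor 4-D but collapses the channel axis to 1:

train_imgs = load_hdf5(dataset_path + "imgs_train.hdf5")  # (20, 3, 584, 565)
gray_imgs = rgb2gray(train_imgs)
print(gray_imgs.shape)                                    # (20, 1, 584, 565)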

4. Grouped Display of Images

# split the dataset and display it as a grid; totimg is the assembled image array
def group_images(data,per_row):  # data: the images; per_row: number of images per row
    assert data.shape[0]%per_row==0  # data = [Nimgs, channels, height, width]
    assert (data.shape[1]==1 or data.shape[1]==3)
    data = np.transpose(data,(0,2,3,1))  # back to [N, H, W, C] for display
    all_stripe = []
    for i in range(int(data.shape[0]/per_row)):  # data.shape[0]/per_row = number of rows
        stripe = data[i*per_row]  # one image, like data(i*per_row, :, :, :) in Matlab
        for k in range(i*per_row+1, i*per_row+per_row):
            stripe = np.concatenate((stripe,data[k]),axis=1)  # stitch per_row images into one row
        all_stripe.append(stripe)  # collect the rows
    totimg = all_stripe[0]
    for i in range(1,len(all_stripe)):
        totimg = np.concatenate((totimg,all_stripe[i]),axis=0)  # stack the len(all_stripe) rows vertically
    return totimg

def visualize(data,filename):
    assert (len(data.shape)==3)  # height*width*channels
    img = None
    if data.shape[2]==1:  # in case it is black and white
        data = np.reshape(data,(data.shape[0],data.shape[1]))
    if np.max(data)>1:
        img = Image.fromarray(data.astype(np.uint8))   # the image is already 0-255
    else:
        img = Image.fromarray((data*255).astype(np.uint8))  # the image is in [0,1]
    img.save(filename + '.png')  # save as PNG
    return img
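Chaining the helpers gives a quick visual check of the packaged data. A minimal sketch, assuming the HDF5 files from section 1; the output name sample_grid is arbitrary:

train_imgs = load_hdf5(dataset_path + "imgs_train.hdf5")  # (20, 3, 584, 565)
grid = group_images(rgb2gray(train_imgs), per_row=5)      # 4 rows x 5 columns
visualize(grid, "sample_grid")                            # writes sample_grid.png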

Reposted from blog.csdn.net/shenziheng1/article/details/80707618