【3D 图像分割】基于 Pytorch 的 3D 图像分割2(基础数据流篇)

构建pytorch训练模型读取的数据,是有模版可以参考的,是有套路的,这点相信使用过的人都知道。我也会给出一个套路的模版,方便学习和查询。

同时,也可以先去参考学习之前的一篇较为简单的3D分类任务的数据构建方法,链接在这里:【3D图像分类】基于Pytorch的3D立体图像分类1(基础篇)

到了本篇训练的数据构建,相比于上面参考的这篇博客,就多了一丢丢的复杂。那就是有了原始图、mask图后,又多了一个结节目标的中心点坐标和半径。

那就意味着,我们读取到的原始图、mask图三维信息后,不能直接放进去训练,因为尺寸也不一样,背景信息太多了。那就需要根据结节目标的坐标信息,先进行裁剪,裁剪出固定大小区域的目标图,然后再放进去训练。

至此,整个过程基本上就清晰了。

  1. 获取原始图、mask图、结节目标的中心点坐标和半径
  2. 裁剪操作,取出固定大小的输入信息

为了方便调试,和查看裁剪的对不对,那就配合一个查看的操作,这样就完美了

一、训练数据预处理

Luna16的原始数据相对来说是较为复杂的,不够直观,这部分的数据处理如果放到训练阶段来做,将会耗费很多的时间和内存资源。所以,有必要将这部分较为复杂的数据,预先处理成较为简单的、直观的、一一对应的数据关系。

LUNA16数据集中的888个CT图像被均匀的划分为10个子集,分别存放在subset0-subset9这10个文件夹中,该数据集的图像存储方式为MetaImage(mhd/raw)的格式,每个mhd文件都存储着一个单独对应的raw文件。

  • raw二进制文件,存放像素数据;
  • mhd文件存放信息

这两部分是同时存在,表示一个完整的信息。对于上述关于Luna16的处理代码,可以参考【小目标】vnet 肺结节 3d图像分割中作者对这块的处理即可。针对这块,我做个简要的处理步骤介绍:

  1. 根据结节标注信息,获取标注结节的坐标信息,生成mask 3维数据块,尺寸和原始CT图像大小一致;
  2. 根据肺区分割图,和原始的CT图像处理,得到去除肺区外,只留下肺实质的信息,对mask同样处理;
  3. 根据x、y、z三个方向的space信息,进行resample操作,将原本各个维度像素点代表的不同尺度,给resample1mm单位;
  4. 最后再从resample后的mask中,获取结节的坐标信息,保存到csv文件呢。

至此,一个较为复杂的流程下,终于把他们一一的对应关系给整理顺畅了。本篇博文基本上是对作者视频部分的二次整理,和微微的改进与测试、可视化等等的工作。后续等训练了,发现了问题,再解决问题,进行优化。

到这里,训练所需要的文件基本上整理出来了,路径结构如下:

扫描二维码关注公众号,回复: 17218731 查看本文章
sk_output
├── bbox_annos
    ├── bbox_annos.csv

├── bbox_image
    ├── subset0
        ├── source_1.npy
        ├── source_2.npy
        └── ...
    ├── subset1
        ├── source_103.npy
        ├── source_104.npy
        └── ...
    ├── subset2
        ├── source_205.npy
        ├── source_206.npy
        └── ...
    ├── subset3
        ├── source_307.npy
        ├── source_308.npy
        └── ...
    └── ...


├── bbox_mask
    ├── subset0
        ├── source_1.npy
        ├── source_2.npy
        └── ...
    ├── subset1
        ├── source_103.npy
        ├── source_104.npy
        └── ...
    ├── subset2
        ├── source_205.npy
        ├── source_206.npy
        └── ...
    ├── subset3
        ├── source_307.npy
        ├── source_308.npy
        └── ...
    └── ...

其中,

  1. bbox_annos.csv:记录了文件名,及标记结节中心点坐标和半径;
  2. bbox_image.npy的图像信息,元素大小为0-255
  3. bbox_mask.npymask信息,和bbox_image内对应文件数量相等,单个npy文件shape一致。只有结节一个目标,元素值为0 or 1

二、构建myDataset类

构建这个数据集,其实也就是那么几件事:

  1. 读取原始图和mask图;
  2. 获取标记结节的中心点坐标信息,这里是从csv文件中获取的;
  3. 根据结节中心点坐标信息,再根据要裁剪的patch的大小,确定好立体举行的最小、最大坐标;
  4. 裁剪出patch的区域

至此,裁剪下来的patch,就是包含有结节的数组了,包括了图像数组,和标注mask数组,一一对应,用于训练。类中函数:

  1. getAnnotations 函数,需要从csv文件中获取文件名和结节对应坐标,最后存储为一个字典;
  2. getNpyFile_Path 函数,获取imagemask文件路径;
  3. get_annos_label 函数,获取文件对应的结节中心点标注信息。

如下,就是整个代码过程:

import os
import torch
import torch.nn as nn
import torch.utils.data
from torch.utils.data import Dataset
import numpy as np
import cv2
from tqdm import tqdm
import random
import matplotlib.pyplot as plt

def getAnnotations(csv_file):
    content = pd.read_csv(csv_file, delimiter=',', header=None,
                              index_col=False)
    names = content[1].values
    coors = content[2].values

    dict_tp = {
   
    
    }
    for n, c in zip(names, coors):
        c_list = eval(c)
        if c_list:
            print(n, c_list, type(c_list))
            dict_tp[n] = c_list
    return dict_tp

class myDataset(Dataset):
    
    def __init__(self, csv_file, data_path, label_path, crop_size=(16, 96, 96)):
        """
        :param csv_file: 记录文件名和结节标记中心点坐标+半径的信息
        :param data_path: 存储原始CT图像
        :param label_path: 存储mask图像
        :param crop_size:   裁剪的尺寸
        """
        self.annosNameCenter_list = getAnnotations(csv_file)
        self.dataFile_paths = self.getNpyFile_Path(data_path)   # 图的path列表
        self.labelFile_paths = self.getNpyFile_Path(label_path)   # 标签的path列表

        self.annos_img_info =  = self.get_annos_label(self.dataFile_paths)  # 图的位置列表 输入进去  吐出  结节附近的图的【【图片位置,结节中心,半径】列表】

        self.crop_size = crop_size
        self.crop_size_z, self.crop_size_h, self.crop_size_w = crop_size

    def __getitem__(self, index):
        img_all = self.annos_img[index]     # 0 - image_path ; 1 - 结节的中心; 2 - 结节的半径
        label_all = self.annos_label[index]

        path, zyx_centerCoor, r = img_all

        img = np.load

猜你喜欢

转载自blog.csdn.net/wsLJQian/article/details/133966967