在 Python 中读取 tif 格式文件需要安装 libtiff，
此外还需要安装 inferno。
本文的方法适用于读取三维（多页）tif 文件。
from torch.utils.data import DataLoader, Dataset
from inferno.io.transform.base import Transform, Compose
from inferno.io.transform.generic import Normalize, AsTorchBatch
from inferno.io.transform.image import RandomCrop, RandomRotate, RandomFlip
from libtiff import TIFF
import os
import torch
import numpy as np
定义数据集类 MyDataSet：
class MyDataSet(Dataset):
    """Paired data/label dataset built from two directories of TIFF files.

    Every TIFF is read once at construction time (via ``tiff2Stack``, without
    a transform).  The optional ``transform`` is applied lazily in
    ``__getitem__`` as ``transform(data, label)`` so that:

    * random augmentations are re-drawn every time a sample is fetched,
      instead of being frozen once when the dataset is built, and
    * the data volume and its label receive the *same* random parameters
      (rotation/flip), keeping them spatially aligned.

    NOTE(review): the joint two-argument call matches inferno's
    ``Transform``/``Compose`` API (several tensors, shared random variables).
    A custom transform must accept two arrays and return a ``(data, label)``
    pair.

    Args:
        pathLst: ``[dataPath, labelPath]`` -- directories holding the data and
            label TIFF files.  File names are expected to look like
            ``LR_20.tif``: a 3-character prefix, an integer index and a
            4-character extension.  Data and labels are paired by sorted index.
        transform: optional joint transform as described above; ``None``
            returns the raw arrays.

    Raises:
        ValueError: if the two directories contain different numbers of files.
    """

    def __init__(self, pathLst, transform=None):  # Parameters and their form vary according to program needs
        dataPath, labelPath = pathLst
        self.transform = transform
        self.tifStreamData = self._load_dir(dataPath)
        self.tifStreamLabel = self._load_dir(labelPath)
        # Explicit check instead of ``assert``, which ``python -O`` strips.
        if len(self.tifStreamData) != len(self.tifStreamLabel):
            raise ValueError(
                f"data/label count mismatch: {len(self.tifStreamData)} data "
                f"files vs {len(self.tifStreamLabel)} label files"
            )

    @staticmethod
    def _sort_key(fileName):
        """Return the integer index of a name such as ``LR_20.tif`` (-> 20)."""
        return int(fileName[3:-4])

    @classmethod
    def _load_dir(cls, dirPath):
        """Read every file in ``dirPath``, ordered by index, without transforming it."""
        fileNames = sorted(os.listdir(dirPath), key=cls._sort_key)
        return [tiff2Stack(os.path.join(dirPath, name)) for name in fileNames]

    def __len__(self):
        return len(self.tifStreamData)

    def __getitem__(self, idx):
        data, label = self.tifStreamData[idx], self.tifStreamLabel[idx]
        if self.transform is not None:
            # Hand the transform copies so an in-place transform cannot
            # corrupt the cached arrays reused across epochs.
            data, label = self.transform(data.copy(), label.copy())
        return data, label
def tiff2Stack(fileName, transform=None):  # read tif, data transform, output tensor
    """Read every page of a TIFF file into one stacked float64 array.

    For single-channel pages of shape ``(height, width)`` the result has shape
    ``(pages, height, width)``.

    Args:
        fileName: path of the TIFF file to read.
        transform: optional callable applied to the stacked array.  If it is
            truthy, its return value (e.g. a torch tensor when the pipeline
            ends with ``AsTorchBatch``) is returned instead of the array.

    Returns:
        The stacked pages as ``np.float64``, or ``transform(array)``.
    """
    tif = TIFF.open(fileName, mode='r')
    try:
        # Decode the pages once; iterating iter_images() twice reads them twice.
        pages = list(tif.iter_images())
    finally:
        # Always release the libtiff handle, even if decoding fails.
        tif.close()
    # Cast to float64 to avoid "can't convert np.ndarray of type numpy.uint16."
    # in torch; pixel values are unchanged.
    tifArr = np.stack(pages).astype(np.float64)
    if transform:
        tifArr = transform(tifArr)
    return tifArr
调用示例：
def main():
    """Demo entry point: build the paired TIFF dataset and print batch shapes."""
    # Augmentation + normalization pipeline, ending with conversion to torch.
    augment = Compose(RandomRotate(), RandomFlip(), Normalize(), AsTorchBatch(2))
    dataDir = "/your/tif/image/Data/path/"
    labelDir = "/your/tif/image/Label/path/"
    dataset = MyDataSet([dataDir, labelDir], transform=augment)
    loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
    for batchIdx, batch in enumerate(loader):
        print(batchIdx)
        data, label = batch
        print("data.shape", data.shape, "label.shape", label.shape)


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()