一. 封装成类Dataset,再用加载器Dataloader
1.封装成类Dataset:
数据集合转化成Dataset这个类,然后必须有
__init__来加载数据集,
__len__来获取数据集的数据数量,用于for循环的次数,
__getitem__来索引数据集中的某条数据
1.
import torch
import os
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
class GetData(Dataset):
    """Paired-image dataset: item i pairs the i-th file of path0 with the i-th file of path1.

    The two file lists are sorted so the pairing is deterministic —
    os.listdir returns entries in arbitrary, OS-dependent order.
    """

    def __init__(self, path0, path1):
        super(GetData, self).__init__()
        self.path0 = path0
        self.path1 = path1
        # Sorted for a stable, reproducible pairing across runs/platforms.
        self.name0_list = sorted(os.listdir(self.path0))
        self.name1_list = sorted(os.listdir(self.path1))
        self.img2data = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        # Use the shorter list so __getitem__ never indexes past the end
        # when the two directories contain different numbers of files.
        return min(len(self.name0_list), len(self.name1_list))

    def __getitem__(self, index):
        """Return the index-th (tensor0, tensor1) image pair."""
        # Locals instead of self.attrs: storing per-item state on self is
        # unsafe when a DataLoader uses multiple workers.
        name0 = self.name0_list[index]
        name1 = self.name1_list[index]
        # Context managers close the file handles promptly; ToTensor forces
        # the pixel data to load before the files close.
        with Image.open(os.path.join(self.path0, name0)) as img0, \
             Image.open(os.path.join(self.path1, name1)) as img1:
            return self.img2data(img0), self.img2data(img1)
class Trainer(nn.Module):
    """Owns the main network, its two optimizers, and the training loop."""

    def __init__(self):
        super(Trainer, self).__init__()
        self.main_net = MainNet()
        # Moving the wrapper module to CUDA moves all of its submodules too.
        self.main_net.cuda()
        # Two losses -> two optimizers doing separate backward passes:
        # one for the VAE (encoder + decoder), one for the discriminator.
        vae_parameters = []
        vae_parameters.extend(self.main_net.encoder.parameters())
        vae_parameters.extend(self.main_net.decoder.parameters())
        self.opt_dis = torch.optim.Adam(self.main_net.discriminator.parameters(), lr=1e-3)
        self.opt_vae = torch.optim.Adam(vae_parameters, lr=1e-3)

    def train(self):
        # NOTE(review): this overrides nn.Module.train(mode=True); callers
        # expecting the standard train/eval mode switch get a full training
        # run instead — consider renaming (interface kept here).
        #
        # Load checkpoints ONCE, before the epoch loop. The original loaded
        # them at the top of EVERY epoch, which would reset the weights to
        # the saved snapshot each epoch and discard all progress in between.
        if os.path.exists('encoder.pkl'):
            self.main_net.encoder.load_state_dict(torch.load('encoder.pkl'))
        if os.path.exists('decoder.pkl'):
            self.main_net.decoder.load_state_dict(torch.load('decoder.pkl'))
        if os.path.exists('discriminator.pkl'):
            self.main_net.discriminator.load_state_dict(torch.load('discriminator.pkl'))
        # Build the DataLoader once as well; re-creating it per epoch only
        # re-lists the image directories for no benefit (shuffling still
        # happens per epoch, because each iteration restarts the loader).
        self.dataloader = DataLoader(dataset.GetData(path0=r'C:\Users\87419\Desktop\cg1\64',
                                                     path1=r'C:\Users\87419\Desktop\cg1\dama_64'),
                                     batch_size=128, shuffle=True)
        for epoch in range(10000):
            # len(dataloader) == ceil(total_images / batch_size); each loop
            # iteration processes one batch, so count runs once per batch.
            count = 0
            for img0data, img1data in self.dataloader:
                # Move the inputs to CUDA so the whole forward pass runs on GPU.
                img0data = img0data.cuda()
                img1data = img1data.cuda()
                count += 1
                print(count)
2.
import torch
import os
import numpy as np
import cv2
from torch.utils.data import Dataset,DataLoader
class GetData(Dataset):
    """Paired-image dataset driven by a 'label.txt' filename listing in each directory."""

    def __init__(self, path1, path2):
        super(GetData, self).__init__()
        self.path1 = path1
        self.path2 = path2
        # Read both listing files with context managers so the handles are
        # closed; the original leaked them via open(...).readlines().
        with open(os.path.join(self.path1, 'label.txt')) as f:
            self.dataset1 = f.readlines()
        with open(os.path.join(self.path2, 'label.txt')) as f:
            self.dataset2 = f.readlines()

    def __getitem__(self, index):
        """Return the index-th image pair with pixel values mapped into [-0.5, 0.5)."""
        # strip() removes the trailing newline each readlines() entry keeps.
        name1 = self.dataset1[index].strip()
        name2 = self.dataset2[index].strip()
        im1 = cv2.imread(os.path.join(self.path1, name1))
        im2 = cv2.imread(os.path.join(self.path2, name2))
        # NOTE(review): cv2.imread yields HWC/BGR uint8 arrays; the tensors
        # below stay in that layout (no permute to CHW) — confirm downstream
        # code expects this before changing it.
        imgdata1 = torch.Tensor(im1 / 255. - 0.5)
        imgdata2 = torch.Tensor(im2 / 255. - 0.5)
        return imgdata1, imgdata2

    def __len__(self):
        return len(self.dataset1)
2.再用加载器Dataloader
# Build the dataset and loader ONCE, outside the epoch loop; the original
# re-created both every epoch, re-reading the label files for no benefit.
# shuffle=True still reshuffles each epoch because iteration restarts.
dataset = GetData(r'C:\Users\87419\Desktop\VAE1\data\trainB',
                  r'C:\Users\87419\Desktop\VAE1\data\trainA')
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
for i in range(EPOCHES):
    print('epoch:', i)
    for j, (imgdata1, imgdata2) in enumerate(dataloader):
        # Move each batch to the GPU before the forward pass.
        imgdata1_ = imgdata1.cuda()
        imgdata2_ = imgdata2.cuda()
二. 用torchvision
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
# Preprocessing pipeline: resize the shorter side to 256, convert to a
# float tensor in [0, 1], then normalize each channel to [-1, 1].
_steps = [
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform1 = transforms.Compose(_steps)

# ImageFolder treats each subdirectory of root as one class label.
dataset = datasets.ImageFolder(root=r'C:\Users\87419\Desktop\VAE\faces',
                               transform=transform1)
dataset_loader = DataLoader(dataset, batch_size=4, shuffle=True)
或定义数据加载函数:
def sample_data(path, batch_size, size=4):
    """Build a shuffled DataLoader over an ImageFolder rooted at *path*.

    Each image is resized to *size*, converted to a tensor, and
    normalized per channel from [0, 1] to [-1, 1].
    """
    pipeline = [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    folder = datasets.ImageFolder(path, transform=transforms.Compose(pipeline))
    return DataLoader(folder, batch_size=batch_size, shuffle=True)