1、自定义数据集
class Dataset(torch.utils.data.Dataset):
def __init__(self):
super(Dataset, self).__init__()
def __len__(self):
return len()
def __getitem__(self, item):
return item
data_loader = torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle)
2、数据、模型需要放在cuda上
3、损失函数
torch.nn
4、优化器
torch.optim
5、
6、