The cause is in how the images are preprocessed:
train_transformer = torchvision.transforms.Compose(
    [
        # Wrong: the second 224 is bound to `interpolation`, not to the width.
        transforms.Resize(224,224),
        # Wrong: the second 192 is bound to `padding`, not to the crop width.
        transforms.RandomCrop(192,192),
        transforms.ColorJitter(brightness=0.4),
        transforms.RandomHorizontalFlip(0.2),
        transforms.ColorJitter(contrast=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
)
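Change the two marked calls, Resize and RandomCrop, so that the size is passed as a single (height, width) tuple: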
import torchvision
from torchvision import transforms

train_transformer = torchvision.transforms.Compose(
    [
        # Size is one (height, width) tuple, so it all goes to `size`.
        transforms.Resize((224, 224)),
        transforms.RandomCrop((192, 192)),
        transforms.ColorJitter(brightness=0.4),
        transforms.RandomHorizontalFlip(0.2),  # flip with probability 0.2
        transforms.ColorJitter(contrast=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]
)
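
Why the tuple matters can be shown with a small sketch of how Python binds positional arguments. The functions `resize` and `random_crop` below are hypothetical stand-ins that only mirror the first two parameters of transforms.Resize and transforms.RandomCrop; they are not the real torchvision classes, and `bound_args` is a helper made up for this illustration.

```python
import inspect

# Hypothetical stand-ins: they copy only the leading parameter names of
# transforms.Resize and transforms.RandomCrop and do no image processing.
def resize(size, interpolation="bilinear"):
    return size, interpolation

def random_crop(size, padding=None):
    return size, padding

def bound_args(func, *args):
    """Return a dict mapping each parameter name to the positional value it receives."""
    return dict(inspect.signature(func).bind(*args).arguments)

# Two separate integers: the second one lands in the wrong parameter.
wrong_resize = bound_args(resize, 224, 224)
wrong_crop = bound_args(random_crop, 192, 192)

# One tuple: the whole (height, width) pair goes to `size`.
right_resize = bound_args(resize, (224, 224))
right_crop = bound_args(random_crop, (192, 192))

print(wrong_resize)  # {'size': 224, 'interpolation': 224}
print(wrong_crop)    # {'size': 192, 'padding': 192}
print(right_resize)  # {'size': (224, 224)}
print(right_crop)    # {'size': (192, 192)}
```

With two bare integers the second value silently fills the next parameter in the signature, which is why wrapping the size in parentheses fixes both calls.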