label_one_hot = F.one_hot(x.to(torch.int64), 40).permute(0, 3, 1, 2)
在对标签进行one-hot编码时候,出现了错误,报错显示:F.one_hot的class参数必须小于真实的类别数。
我用的NYU-depth v2,设的40类,发现没错呀。
然后去搜了搜发现可能出现的问题:
1:x即标签的数据类型不对。我查看了一下,将x设置为torch.float32,运行还是报错。
2:难道class是图片中的类别吗,我使用:
torch.unique(x)
查看发现图片一共有9类:
tensor([ 0., 1., 5., 7., 8., 26., 29., 38., 40.])
将class设置为8,9,都报错。
3:仔细想一想,one-hot应该是对每一张图片的都进行相同的编码操作,那2的思路就是错的,既然40也不对,我想到了在验证的时候,会将标签的像素值减1,是因为在预测时候像素1代表墙,在label中像素值代表空,那么总类别也就是41,上面的unique函数也验证了图片中含有0像素,同时含有40像素,一共有41个类别。
将报错代码改为:
label_one_hot = F.one_hot(x.to(torch.int64), 41).permute(0, 3, 1, 2)
成功运行!