RuntimeError: Class values must be smaller than num_classes

label_one_hot = F.one_hot(x.to(torch.int64), 40).permute(0, 3, 1, 2)

在对标签进行one-hot编码时候,出现了错误,报错显示:F.one_hot的class参数必须小于真实的类别数。
我用的NYU-depth v2,设的40类,发现没错呀。
然后去搜了搜发现可能出现的问题:
1:x即标签的数据类型不对。我查看了一下,将x设置为torch.float32,运行还是报错。
2:难道class是图片中的类别吗,我使用:

torch.unique(x)

查看发现图片一共有9类:

tensor([ 0.,  1.,  5.,  7.,  8., 26., 29., 38., 40.])

将class设置为8,9,都报错。

3:仔细想一想,one-hot应该是对每一张图片的都进行相同的编码操作,那2的思路就是错的,既然40也不对,我想到了在验证的时候会将标签的像素值减1,是因为在预测时候像素1代表墙,在label中像素值代表空,那么总类别也就是41,上面的unique函数也验证了图片中含有0像素,同时含有40像素,一共有41个类别。

将报错代码改为:

label_one_hot = F.one_hot(x.to(torch.int64), 41).permute(0, 3, 1, 2)

成功运行!

猜你喜欢

转载自blog.csdn.net/qq_43733107/article/details/129071714