需求:语义分割有23个类别,需要删除部分类别,删除的类别不会传入loss进行计算,但是默认在loss的计算过程中,是采用one-hot编码进行,删除类别会影响one-hot编码的排序。
依次要做的事儿:将忽略的类别定义为-1,建立新旧label值的映射表,通过映射表修改label的值,对其中 -1 的值进行过滤,然后生成新的one-hot编码。
需要修改的地方:类别数,修改值和mask其实不是针对我们创建的tensor1,而是真实的label tensor(比如维度为 1xn ),根据映射表和mask就可以对 1xn 的label做处理,我这里处理tensor1是为了方便展示。
代码如下:
import torch
import torch.nn.functional as F
if __name__ == '__main__':
# define 23 classes
tensor1 = torch.arange(0,23)
print(tensor1)
del_class = [4, 6, 8]
# edit del_class to -1
for i in range(len(tensor1)):
if tensor1[i] in del_class:
tensor1[i] = -1
# edit mapping dict
value = torch.unique(tensor1)
mapping = {}
for index in range(len(value)):
if value[index] == -1:
continue
old_value = int(value[index])
new_value = index - 1
mapping[old_value] = new_value
print(tensor1)
print(mapping)
# edit tensor value according to mapping_dict
for i in range(len(tensor1)):
if int(tensor1[i]) in mapping.keys():
new_value = mapping[int(tensor1[i])]
tensor1[i] = new_value
print(tensor1)
# mask -1 value
mask = (tensor1 != -1)
# generate one-hot value
tensor1_one = F.one_hot(tensor1[mask])
print(tensor1_one)