昨天使用pytorch写一个程序,程序写完之后却一直不能正确运行,今天定位到了代码的问题所在:
我的代码其中有一处逻辑是这样的:
……
get a # 这里的a就是我想反向求导更新的参数
b=torch.nonzero(a) # 得到a里面所有不为0的下标
for i,j in b:
feature_i=feature[i]
feature[j]=feature[j]
……
get c # 通过b中的元素下标得到一个c
c->loss # c经过一系列操作得到了loss
之后loss反向求导发现a
的梯度一直是0,这是因为我只用到了a中不为0元素的下标,根本就没有用到a这个矩阵,此时的计算图应该是这样的:
b->c->loss
因为a->b根本没有产生梯度关系,所以计算图肯定不能反向传导到a。
以后写代码要思考计算图的建立