时间:2018/2/7
一般情况下,用一个完整的网络就可以了。但是像我现在要做一个network in network,大网络里要附加一个小网络,而且还想单独训练这个小网络。使用pytorch实现这个想法的时候问题就来了:pytorch只对Variable叶节点有显式的梯度计算,所以任何其他的操作clone等都不能计算梯度的。而又不想训练这个单独的小网络的时候,将梯度传递到主网络里面去,所以有以下两种方法。
方法一:
使用Variable.detach() detach的官网介绍
比如模拟了一个网络的代码如下:
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
a = np.arange(24).reshape(1, 2, 3, 4).astype(np.float32)
b = Variable(torch.from_numpy(a.astype(np.float32)))
d = b.clone().cuda() + 1
b.requires_grad = True
x1 = nn.Conv2d(2, 2, 1).cuda()
x2 = nn.ReLU().cuda()
l = nn.MSELoss()
x = x1(b.cuda())
c = x.detach()
c.requires_grad = True
xc = x1(c)
ls = l(xc, d)
ls.backward()
print b.grad
print '*' * 10
print c.grad
方法二:
使用clone()
这样得到的变量的确不是叶节点,但是小网络里面用到的权值和偏置项是叶节点,pytorch还是能够将接下来的计算用于求算权重的梯度.
哈哈,折腾了将近一小时,不如脑袋短路一秒钟。
However,使用detach可以将graph结构就此断路,不会继续前传,clone则是不行的。