作者:陈诚
链接:https://www.zhihu.com/question/67209417/answer/344752405
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
链接:https://www.zhihu.com/question/67209417/answer/344752405
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。
应该不止我经历过这个吧...
久久不用又会不小心掉到这个坑里去...
for data, label in trainloader:
......
out = model(data)
loss = criterion(out, label)
loss_sum += loss # <--- 这里
......
运行着就发现显存炸了
观察了一下发现随着每个batch显存消耗在不断增大..
参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/~~
loss_sum += loss.data[0]
这是因为输出的loss的数据类型是Variable。
而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。
如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~
总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...
包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~