Pytorch(十四) —— hook

HOOK——获取神经网络特征和梯度的有效工具

为了更深入地理解神经网络模型,有时候我们需要观察它训练得到的卷积核、特征图或者梯度等信息,这在CNN可视化研究中经常用到。其中,卷积核最易获取,将模型参数保存即可得到;特征图是中间变量,所对应的图像处理完即会被系统清除,否则将严重占用内存;梯度跟特征图类似,除了叶子结点外,其它中间变量的梯度都被会内存释放,因而不能直接获取。
最容易想到的获取方法就是改变模型结构,在forward的最后不但返回模型的预测输出,还返回所需要的特征图等信息。

如何在不改变模型结构的基础上获取特征图、梯度等信息呢?

Pytorch的hook编程可以在不改变网络结构的基础上有效获取、改变模型中间变量以及梯度等信息。

hook可以提取或改变Tensor的梯度,也可以获取nn.Module的输出和梯度(这里不能改变)。

Pytorch在进行完一次反向传播后,出于节省内存的考虑,只会存储叶子节点的梯度信息,并不会存储中间变量的梯度信息。然而有些时候我们又不得不使用中间变量的梯度信息完成某些工作(如获取中间层的梯度,获取中间层的特征图),这时候hook()函数就可以派上用场啦

主要有四种钩子函数:

  • ①torch.Tensor.register_hook
  • ②torch.nn.Module.register_backward_hook
  • ③torch.nn.Module.register_forward_hook
  • ④torch.nn.Module.register_forward_pre_hook,接下来分别对他们进行介绍
     

https://blog.csdn.net/weixin_42075898/article/details/103412227

https://www.jianshu.com/p/69e57e3526b3

猜你喜欢

转载自blog.csdn.net/hxxjxw/article/details/115378293