MXNet中x.grad源码追溯

Python测试代码如https://zh.gluon.ai/chapter_prerequisite/autograd.html

本文追溯x.grad这一行代码的调用

grad调用的是函数MXNDArrayGetGrad,/usr/local/lib/python3.7/dist-packages/mxnet-1.5.0-py3.7.egg/mxnet/ndarray/ndarray.py

MXNDArrayGetGrad的源码依旧是在文件src/c_api/c_api.cc中,

NDArray ret = arr->grad();

ret就是获取到的梯度

这里grad的源码文件为src/ndarray/ndarray.cc,

Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);

return info.out_grads[0];

这里Imperative::AGInfo::Get的源码文件为 include/mxnet/imperative.h

return dmlc::get<AGInfo>(node->info);

这里get的源码文件为3rdparty/dmlc-core/include/dmlc/any.h

return *any::TypeInfo<T>::get_ptr(&(src.data_));

这个get_ptr调用的是同文件中的如下代码:

template<typename T>
class any::TypeOnHeap {
 public:
  inline static T* get_ptr(any::Data* data) {
    return static_cast<T*>(data->pheap);
  }

回到上面的代码,那个entry_是NDArrary类的一个对象:

  /*! \brief node entry for autograd */
  nnvm::NodeEntry entry_;

猜你喜欢

转载自blog.csdn.net/zhqh100/article/details/91490938