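I would like to write a proper series of source-analysis articles, but my knowledge isn't there yet, so I'll write things up piece by piece as I learn them.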
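I want to study the MXNet source code, and the first step is to add some print statements and debug. The first question is how to print the values of an NDArray from inside the C++ source.
Today I tried the following, and it does print them.
In the file incubator-mxnet/src/c_api/c_api.cc, modify the function MXNDArraySlice as follows: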
int MXNDArraySlice(NDArrayHandle handle,
                   mx_uint slice_begin,
                   mx_uint slice_end,
                   NDArrayHandle *out) {
  NDArray *ptr = new NDArray();
  API_BEGIN();
  std::cout << "slice_begin:" << slice_begin << std::endl;
  std::cout << "slice_end:" << slice_end << std::endl;
  *ptr = static_cast<NDArray*>(handle)->SliceWithRecord(
      slice_begin, slice_end);
  *out = ptr;
  // Debug print: assumes float32 data and at least two elements in the slice.
  // There is no WaitToRead() here, so the values are only correct if the
  // operator that produced the data has already finished.
  float *p = static_cast<float *>(ptr->data().dptr_);
  std::cout << "p[0] = " << p[0] << std::endl;
  std::cout << "p[1] = " << p[1] << std::endl;
  API_END_HANDLE_ERROR(delete ptr);
}
The Python test code is as follows:
from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)
x = nd.arange(2, 7).reshape((5, 1))
print(x[2:4].asnumpy())
The output is:
#python3 mxnet_test.py
1.5.0
slice_begin:2
slice_end:4
p[0] = 4
p[1] = 5
[[4.]
 [5.]]
Great, this confirms that the actual values live in the buffer pointed to by dptr_, a member of the TBlob returned by NDArray's data() function.
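Note that p[0] = 4 is already the third element of arange(2, 7), which suggests that slicing along the first axis does not copy anything but only shifts the start of the data pointer. I have not traced every code path in MXNet, so treat that as an assumption. The sketch below mimics the idea on a plain float32 buffer with Python's ctypes; slice_rows is a hypothetical helper for illustration, not an MXNet function.

```python
import ctypes

# A contiguous, row-major float32 buffer holding arange(2, 7) with shape (5, 1).
shape = (5, 1)
buf = (ctypes.c_float * 5)(2, 3, 4, 5, 6)
base = ctypes.addressof(buf)


def slice_rows(base_addr, shape, begin, end):
    """Hypothetical helper: view rows [begin, end) of a row-major float32 buffer.

    No data is copied; the slice is just the base address moved forward by
    begin * (elements per row) * sizeof(float), plus a new shape.
    """
    row_elems = 1
    for d in shape[1:]:
        row_elems *= d
    offset_bytes = begin * row_elems * ctypes.sizeof(ctypes.c_float)
    ptr = ctypes.cast(base_addr + offset_bytes, ctypes.POINTER(ctypes.c_float))
    return ptr, (end - begin,) + tuple(shape[1:])


p, new_shape = slice_rows(base, shape, 2, 4)
values = [p[i] for i in range(new_shape[0] * new_shape[1])]
print(new_shape, values)  # (2, 1) [4.0, 5.0]
```

Because the slice only aliases the original memory, writing buf[2] = 40 afterwards would also change p[0].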
____________________________________________
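However, this does not always give the expected result. For example, if you print the gradient in MXNDArrayGetGrad (in the same file) the same way as above, every value comes out as 0. MXNet executes operators asynchronously through its engine, so when this C API function runs, the backward pass may not have written the gradient yet and the buffer still holds the zeros it was initialized with. Adding a call to WaitToRead(), which blocks until pending writes to the array have finished, makes the values print correctly: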
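int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) {
  API_BEGIN();
  NDArray *arr = static_cast<NDArray*>(handle);
  NDArray ret = arr->grad();
  if (ret.is_none()) {
    *out = NULL;
  } else {
    // Shape prints assume a 2-D gradient, as in the test below.
    std::cout << "ret.shape().ndim() = " << ret.shape().ndim() << std::endl;
    std::cout << "ret.shape()[0] = " << ret.shape()[0] << std::endl;
    std::cout << "ret.shape()[1] = " << ret.shape()[1] << std::endl;
    *out = new NDArray(ret);
    // Block until all pending writes to ret (e.g. the backward pass) finish;
    // without this the buffer may still contain its initial zeros.
    ret.WaitToRead();
    // Debug print: assumes float32 data.
    const float *p_float = static_cast<const float *>(ret.data().dptr_);
    // Size() is the total number of elements, so the loop works for any ndim.
    const size_t n = ret.shape().Size();
    for (size_t i = 0; i < n; ++i) {
      std::cout << "p_float[" << i << "] = " << p_float[i] << std::endl;
    }
  }
  API_END();
}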
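Why WaitToRead() matters can be shown with a small analogy. This is not MXNet code: Python's concurrent.futures stands in for MXNet's engine, and grad_buf and backward_op are hypothetical names. Submitting work returns immediately, just as pushing an operator does, so reading the buffer right away will most likely still see zeros, while waiting on the future plays the role of ret.WaitToRead().

```python
from concurrent.futures import ThreadPoolExecutor
import time

# Hypothetical stand-in for a freshly attached gradient buffer: starts as zeros.
grad_buf = [0.0] * 5
x = [2.0, 3.0, 4.0, 5.0, 6.0]


def backward_op(x_vals, out):
    """Hypothetical 'backward' operator: writes d(2 * x^T x)/dx = 4 * x into out."""
    time.sleep(0.1)  # simulate the time the real computation takes
    for i, v in enumerate(x_vals):
        out[i] = 4.0 * v


with ThreadPoolExecutor(max_workers=1) as engine:
    # Like pushing an operator to the engine: submit() returns immediately.
    pending = engine.submit(backward_op, x, grad_buf)

    # Like printing dptr_ without WaitToRead(): the worker has most likely
    # not written anything yet, so this snapshot is probably still all zeros.
    early = list(grad_buf)

    # Like ret.WaitToRead(): block until the pending write has finished.
    pending.result()
    late = list(grad_buf)

print("before wait:", early)
print("after wait: ", late)
```

In C++ the unsynchronized read is a data race rather than merely a stale read, which is one more reason to call WaitToRead() before touching dptr_.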
The Python test code is:
from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)
x = nd.arange(2, 7).reshape((5, 1))
x.attach_grad()
with autograd.record():
    y = 2 * nd.dot(x.T, x)
y.backward()
# assert (x.grad - 4 * x).norm().asscalar() == 0
print(x.grad)
The output is:
# python3 autograd_test.py
1.5.0
ret.shape().ndim() = 2
ret.shape()[0] = 5
ret.shape()[1] = 1
p_float[0] = 8
p_float[1] = 12
p_float[2] = 16
p_float[3] = 20
p_float[4] = 24
[[ 8.]
 [12.]
 [16.]
 [20.]
 [24.]]
<NDArray 5x1 @cpu(0)>
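The printed gradient can be checked by hand: for a 5x1 column vector x, y = 2 * x^T x = 2 * sum(x_i^2), so dy/dx_i = 4 * x_i. With x = [2, 3, 4, 5, 6] that gives [8, 12, 16, 20, 24], matching the output and the commented-out assertion in the test script. A minimal pure-Python check using central differences (f and numerical_grad are hypothetical helpers, not MXNet APIs):

```python
def f(x):
    """y = 2 * dot(x.T, x) for a column vector x, i.e. 2 * sum(x_i^2)."""
    return 2.0 * sum(v * v for v in x)


def numerical_grad(func, x, eps=1e-3):
    """Central-difference estimate of the gradient of func at x."""
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += eps
        xm[i] -= eps
        grad.append((func(xp) - func(xm)) / (2 * eps))
    return grad


x = [2.0, 3.0, 4.0, 5.0, 6.0]
analytic = [4.0 * v for v in x]  # the closed-form gradient 4 * x
numeric = numerical_grad(f, x)
print(analytic)  # [8.0, 12.0, 16.0, 20.0, 24.0]
```

For a quadratic like this the central difference is exact up to floating-point rounding, so the two lists agree to many decimal places.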