def acc(output, label):
# output: (batch, num_output) float32 ndarray
# label: (batch, ) int32 ndarray
return (output.argmax(axis=1) == label.astype(‘float32’)).mean().asscalar()
在Gluon文档里有这个计算accuracy的函数,就一行看不懂,分析一下。
首先argmax,argmax的意思是返回最大值的坐标。
- axis缺省为全局最大(直接用报错,可以np.argmax(a) )
- axis = 0 为每列最大
- axis = 1为每行最大
x = nd.array(((1,2,3),(3,4,5)))
>>> x.argmax(axis=1)
[2. 2.]
<NDArray 2 @cpu(0)>
>>> x.argmax(axis=0)
[1. 1. 1.]
<NDArray 3 @cpu(0)>
=======================================================
没学过python诶。为什么函数参数可以这么写
见:http://www.runoob.com/python/python-functions.html 中的关键字参数
=======================================================
所以 output.argmax(axis=1) 返回的是每行最大的index。
然后
output.argmax(axis=1) == label.astype('float32')
是一个0 1 数组
最后计算mean就行了,其中asscalar()把结果变成标量