【深度学习原理】交叉熵损失函数的实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011240016/article/details/85728682

交叉熵损失函数

一般我们学习交叉熵损失函数是在二元分类情况下:

L = [ y l o g y ^ + ( 1 y ) l o g ( 1 y ^ ) ] L=−[ylog ŷ +(1−y)log (1−ŷ )]

推而广之,我们可以得到下面这个交叉熵损失函数公式:

E = k t k l o g ( y k ) E=-\sum_k{t_k}log(y_k)

从机器学习的角度看,这里的 y k y_k 是神经网络的输出, t k t_k 是正确解的标签。

而分类标签有两种方式:

  • One-Hot编码
  • 非One-Hot编码

One-Hot编码下的损失函数实现

使用One-Hot编码时, t k t_k 中,只有正确解的标签才为1,其他的都是0,所以在相乘时,这项就为0,但是我们知道 l o g ( 0 ) log(0) 是负无穷,显然我们需要特别在代码中处理一下。

先不看负无穷的问题,在One-Hot编码时, t k t_k 中只有为1的这项,才有输出,也就是说,我们计算交叉熵损失函数,只用计算对应正确解的输出的自然对数即可

代码如下:

def cross_entropy_error(y, t):
	delta = 1e-7
	return -np.sum(t * np.log(y + delta))

比如:

t = [0,0,1,0,0,0,0,0,0,0]
y = [0.1,0.05,0.6,0.0,0.05,0.1,0.0,0.1,0.0,0.0]
cross_entropy_error(y,t) # ==> 0.510825...
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
cross_entropy_error(y,t) # ==> 2.30258... 

第一个案例中,正确的标签是2,输出的Softmax概率中2对应的标签的概率最大,为0.6,由此计算出来的损失函数值为0.51;第二个案例,预测的概率最大为0.1,以第一个作为预测结果,即0是预测值,得出损失函数值为2.3,可见预测错了损失函数值偏大。

总之,用One-Hot编码,是将
标签值和预测值的编码一一对应,按照交叉熵的公式处理。

非One-hot编码

如果只有一个值,单个样本的损失函数计算如下:

def cross_entropy_error(y, t):
	delta = 1e-7
	return -np.log(y + delta)

这是从前面的One-Hot编码那里推导来的,我们只需要神经网络在正确标签处的输出,就可以计算交叉熵误差。

如果是Mini-Batch呢,需要做哪些变化?

Mini-Batch下的交叉熵函数

One-Hot编码

def cross_entropy_error(y, t):
	if y.ndim == 1:
		t = t.reshape(1, t.size) # ndarray的size属性是存在的
		y = y.reshape(1, y.size)
	batch_size = y.shape[0]
	return -np.sum(t * np.log(y+ 1e-7)) / batch_size

这里既是y和t是小批量的形式,即二维矩阵,按照Numpy的调性,矩阵直接相乘是按照元素相乘,最后聚和再除以总体个数即可。看起来就除了batch_size,其实是聚和了二维矩阵相乘的结果。

非One-Hot编码

def cross_entropy_error(y, t):
	if y.ndim == 1:
		t = t.reshape(1, t.size) # ndarray的size属性是存在的
		y = y.reshape(1, y.size)
	batch_size = y.shape[0]
	return -np.sum(np.log(y[np.arange(batch_size),t] + 1e-7)) / batch_size

这里还是需要注意这句话:我们只需要神经网络在正确标签处的输出,就可以计算交叉熵误差。所以看起来很复杂的y[np.arange(batch_size),t]目的也是为了获得神经网络的输出,取出的是多行与多列的组合。

END.

参考:
《深度学习入门:基于Python的理论和实现》

https://jamesmccaffrey.wordpress.com/2013/11/05/why-you-should-use-cross-entropy-error-instead-of-classification-error-or-mean-squared-error-for-neural-network-classifier-training/

https://www.jianshu.com/p/474439106874

猜你喜欢

转载自blog.csdn.net/u011240016/article/details/85728682