Easy Deep Learning——激活函数

激活函数是什么?什么神经网络需要激活函数?它的作用是什么?

假设你是一位魔术师,你手里有一个魔法盒子,盒子里面有很多彩色的球。你想要从盒子里面找到所有红色的球,然后将它们放到一个特殊的盒子里。

你开始一个个地从魔法盒子中取出球,然后仔细地检查它们的颜色。如果是红色的球,你就放到特殊的盒子里;如果不是,你就把它放回原来的魔法盒子里,继续找下一个球。

但是你很快发现,这种方法效率非常低。你需要一个更快、更有效的方法来找到所有的红色球。于是你开始思考:有没有一种方法,可以在一次性地查看所有的球之后,直接找到所有的红色球呢?


你想了很久,最终想到了一个办法。你把所有的球都放到一个大盒子里,然后用一个特殊的滤网过滤掉所有不是红色的球,只留下红色的球。这样一来,你就能够一次性地找到所有的红色球了。

激活函数就像是魔法盒子中的滤网一样,可以帮助神经网络找到特定的模式。激活函数通过引入非线性,让神经网络能够识别更为复杂的模式,从而找到特定的目标。激活函数的作用就像是滤网一样,可以帮助神经网络快速、准确地找到我们需要的模式,提高神经网络的性能。 

激活函数(Activation Function)是神经网络中一种非线性的函数,用于将输入信号映射到输出信号。激活函数的作用是增加神经网络的表达能力,使神经网络能够处理非线性的数据。

激活函数对应神经网络有以下作用:

  1. 引入非线性:激活函数通过引入非线性,使得神经网络能够拟合更为复杂的模型。如果没有激活函数,神经网络将退化为线性模型,无法处理非线性数据。

  2. 改善模型输出:激活函数可以将神经网络的输出值映射到一个特定的范围内,例如0到1或-1到1。这种映射可以使神经网络的输出更容易理解和解释。

  3. 防止梯度消失:在反向传播过程中,如果激活函数的导数很小,梯度就会消失,导致神经网络无法更新权重。一些特殊的激活函数,如ReLU(Rectified Linear Unit)和其变种,具有较大的导数,可以避免梯度消失的问题。

常见的激活函数包括Sigmoid函数、Tanh函数、ReLU函数、LeakyReLU函数等。不同的激活函数在不同的场景下表现不同,选择合适的激活函数可以提高神经网络的性能。

Pytorch中常见的激活函数介绍

  1. ReLU(Rectified Linear Unit):是目前使用最广泛的激活函数之一,它将小于零的值设为零,大于零的值不变。可以通过 torch.nn.ReLU() 来使用。

  2. Sigmoid:将实数映射到区间 (0,1) 内,对于二分类问题非常有用。可以通过 torch.nn.Sigmoid() 来使用。

  3. Tanh:将实数映射到区间 (-1,1) 内,比 Sigmoid 函数的输出范围更广。可以通过 torch.nn.Tanh() 来使用。

  4. Softmax:主要用于多分类问题,将实数映射到 (0,1) 区间内的概率值,且所有输出的概率和为1。可以通过 torch.nn.Softmax() 来使用。

 以下使用pytorch的API 来实现这四种函数,代码如下

import  torch
import  torch.nn as nn
import matplotlib.pyplot as plt
import numpy
x = torch.linspace(-6,6,100)
sigmod = nn.Sigmoid()  ##Sigmod激活函数
ysigmod = sigmod(x)

tanh = nn.Tanh() ##Tanh激活函数
ytanh = tanh(x)

relu = nn.ReLU() ##ReLU激活函数
yrelu = relu(x)

softmax = nn.Softmax() ##Softmax函数
ysoftmax = softmax(x)

plt.figure(figsize=(14,3))

plt.subplot(1,4,1)
plt.plot(x.data.numpy(),ysigmod.data.numpy(),"r-")
plt.title('sigmod')
plt.grid()

plt.subplot(1,4,2)
plt.plot(x.data.numpy(),yrelu.data.numpy(),"r-")
plt.title('relu')
plt.grid()



plt.subplot(1,4,3)
plt.plot(x.data.numpy(), ysoftmax.data.numpy(), "r-")
plt.title('softplus')
plt.grid()

plt.subplot(1,4,4)
plt.plot(x.data.numpy(),ytanh.data.numpy(),"r-")
plt.title('tanh')
plt.grid()
plt.show()

  

猜你喜欢

转载自blog.csdn.net/weixin_40582034/article/details/129438065