1.SeLU(scaled exponential linear units)激活函数计算公式
selu ( x ) = λ { x if x > 0 α e x − α if x ⩽ 0. \text{selu}(x)= \lambda \begin{cases} x& \text{ if } x>0 \\ \alpha e^x-\alpha & \text{ if } x\leqslant 0. \end{cases} selu(x)=λ{ xαex−α if x>0 if x⩽0.
其中 λ = 1.0507009873554804934193349852946 \lambda=1.0507009873554804934193349852946 λ=1.0507009873554804934193349852946, α = 1.6732632423543772848170429916717. \alpha=1.6732632423543772848170429916717. α=1.6732632423543772848170429916717.
2.JAX代码实现
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time : 2022/7/20 13:51
@Author : Albert Darren
@Contact : [email protected]
@File : Program1.1.py
@Version : Version 1.0.0
@Description : TODO 利用jax计算selu函数,详见P12
@Created By : PyCharm
"""
import jax.numpy as jnp # 导入numpy计算包
from jax import random # 导入random随机数包
def selu(x, alpha=1.6732632423543772848170429916717, lmbda=1.0507009873554804934193349852946):
"""
实现selu激活函数
:param x: 输入张量
:param alpha: 预定义参数alpha
:param lmbda: 预定义参数lambda,此处变量名故意拼写错误,避免与关键字lambda命名冲突
:return: selu函数值
"""
return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))
# 产生一个固定数字17作为key
key = random.PRNGKey(17)
# 随机生成一个大小为[1,5]的矩阵
x = random.normal(key, (5,))
print(selu(x))
# [-1.2497659 0.4546819 1.5760192 -0.81573856 0.27510932]