GELU

GELU is fairly memory-hungry and computationally expensive. On detection tasks it helps and converges faster than ReLU6, but its best accuracy does not reach ReLU6's.

The gradients are largest in the first convolutional layer.
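
For reference, a minimal sketch (not from the original post) contrasting the exact GELU formula with ReLU6 on a few sample values; it only illustrates the shape of the two activations, not the speed or accuracy claims above.

import math
import torch
from torch.nn import functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0, 7.0])

# Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
gelu_out = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
# ReLU6: min(max(x, 0), 6) -- hard-clipped and non-smooth.
relu6_out = F.relu6(x)

print(gelu_out)   # roughly [-0.0455, -0.1543, 0.0, 0.3457, 1.9545, 7.0]
print(relu6_out)  # [0.0, 0.0, 0.0, 0.5, 2.0, 6.0]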

Class implementations:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import math
from torch import nn
from torch.nn import functional as F

class mish(nn.Module):
    def __init__(self):
        super(mish, self).__init__()
    # Mish: x * tanh(softplus(x)); see https://arxiv.org/abs/1908.08681
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Gelu(nn.Module):
    def __init__(self):
        super(Gelu, self).__init__()
    # Exact GELU: x * Phi(x); see https://arxiv.org/abs/1606.08415
    def forward(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class Gelu_new(nn.Module):
    def __init__(self):
        super(Gelu_new, self).__init__()
    # tanh approximation of GELU (as used in OpenAI GPT / BERT); see https://arxiv.org/abs/1606.08415
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class swish(nn.Module):
    def __init__(self):
        super(swish, self).__init__()
    # Swish/SiLU: x * sigmoid(x); see https://arxiv.org/abs/1710.05941
    def forward(self, x):
        return x * torch.sigmoid(x)
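
A minimal usage sketch (not in the original post), assuming the classes above are in scope: each module can be dropped in wherever nn.ReLU would normally go.

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    Gelu(),
    nn.Conv2d(16, 16, kernel_size=3, padding=1),
    mish(),
)
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 16, 32, 32])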

Functional versions, from https://github.com/StateOfTheArt-quant/transformerquant/blob/e6f3ae7135aa5cb581c1a675b3c000b92e4188cc/transformerquant/modules/activation/activations.py:



#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import math


def gelu(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

def swish(x):
    return x * torch.sigmoid(x)
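
As a quick sanity check (not part of the original code), the exact formula above should agree with torch.nn.functional.gelu, and gelu_new should agree with its approximate='tanh' mode (the approximate argument exists in newer PyTorch releases; adjust if yours is older).

from torch.nn import functional as F

x = torch.randn(1000)
print(torch.allclose(gelu(x), F.gelu(x), atol=1e-6))                          # expect True
print(torch.allclose(gelu_new(x), F.gelu(x, approximate='tanh'), atol=1e-4))  # expect True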

Reposted from blog.csdn.net/jacke121/article/details/103767214