import torch
import torch.nn as nn
x = torch.randn(batch_size, channel, height, width)
# 初始化缩放参数 gamma 和平移参数 beta
gamma = torch.ones(channel, dtype=torch.float32)
beta = torch.zeros(channel, dtype=torch.float32)
def BN(x, gamma, beta, epsilon = 1e-8):
mean = torch.mean(x, dim = (0, 2, 3), keepdim = True)
var = torch.var(x, dim = (0, 2, 3), keepdim = True)
x = (x - mean) / (torch.sqrt(var + epsilon))
out = x * gamma + beta
return out
手写实现BN前向过程
猜你喜欢
转载自blog.csdn.net/slamer111/article/details/132799892
今日推荐
周排行