import matplotlib.pylab as plt
import torch
def xyplot(x, y, name, figsize=None):
    """Plot ``y`` against ``x`` and label the axes for the function ``name``.

    Args:
        x: Tensor of sample points; it may require grad (it is detached here).
            Presumably 1-D -- TODO confirm for other call sites.
        y: Tensor of values to plot at ``x`` (e.g. an activation or its gradient).
        name: Function label; the y-axis is labelled ``name + '(x)'``.
        figsize: Optional ``(width, height)`` in inches. When given, a new
            figure of that size is created; when ``None`` (the default) the
            current matplotlib figure is used, as before.

    Displays the result with ``plt.show()``.
    """
    if figsize is not None:
        # Size only this figure instead of mutating global rcParams.
        plt.figure(figsize=figsize)
    # numpy() refuses tensors that require grad or live on a CUDA device, so
    # drop the autograd graph and copy to host memory first.
    plt.plot(x.detach().cpu().numpy(), y.detach().cpu().numpy())
    plt.xlabel('x')
    plt.ylabel(name + '(x)')
    plt.show()
# Sample points on [-8, 8) with step 0.1. requires_grad=True lets backward()
# populate x.grad with each activation's derivative at these points.
x = torch.arange(-8, 8, 0.1, requires_grad=True)

# (label, function) pairs to plot; add entries here to plot more activations.
activations = [
    ('relu', torch.relu),
    ('sigmoid', torch.sigmoid),
    ('tanh', torch.tanh),
]

for name, fn in activations:
    y = fn(x)
    xyplot(x, y, name)
    # backward() accumulates into x.grad, so clear the previous activation's
    # gradient first; x.grad is None before the very first backward() call.
    if x.grad is not None:
        x.grad.zero_()
    # Each activation is element-wise, so d(sum(y))/dx_i == f'(x_i):
    # x.grad holds the derivative curve of the activation.
    y.sum().backward()
    xyplot(x, x.grad, 'grad of ' + name)
# Python activation-function plotting code (python 激活函数图像代码).
# Reprinted from blog.csdn.net/qq_40107571/article/details/131396752