1.前言
Torch 中提供了很多方便的途径, 同样是神经网络, 能快则快, 我们看看如何用更简单的方式搭建同样的神经网络
2.常规方法
class NetWork(nn.Module):
def __init__(self,n_input,n_hidden,n_output):
super(NetWork,self).__init__()
self.hidden = nn.Linear(n_input, n_hidden)
self.output_for_predict = nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.output_for_predict(x)
return x
network = NetWork(n_input = 1,n_hidden = 8, n_output = 2) #对于二分类问题n_output为2
print(network) #打印网络结构
3.用nn.Sequential搭建
我们用 class 继承了一个 torch 中的神经网络结构, 然后对其进行了修改, 不过还有更快的方法
network_sequential = nn.Sequential(
nn.Linear(1,8),
nn.ReLU(),
nn.Linear(8,1)
)
print(network_sequential)
4.不同方法对比
我们会发现 network_sequential 相比于network多了ReLU, 在 network 中, 激活函数实际上是在 forward() 功能中才被调用的. 相比network_sequential, network 的好处就是, 你可以根据你的个人需要更加个性化你自己的前向传播过程, 比如(RNN). 不过如果你不需要七七八八的过程, 相信 network_sequential 这种形式更适合你