LSTM网络本质还是RNN网络
LSTM是为了解决RNN中的反馈消失问题而被提出的模型,它也可以被视为RNN的一个变种。与RNN相比,增加了3个门(gate):input门,forget门和output门。
import torch.nn as nn
from torch.nn import functional as F
class LSTM(nn.Module):
def __init__(self):
super(LSTM, self).__init__()
self.rnn = nn.LSTM(
input_size=14,
hidden_size=32,
num_layers=2,
batch_first=True
)
self.fc = nn.Linear(32, 2)
def forward(self, x):
out, _ = self.rnn(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
# return F.log_softmax(out)
return out
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.cnn = nn.Sequential(
nn.Conv1d(1, 32, 5),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Conv1d(32, 64, 5),
nn.BatchNorm1d(