【PyTorch实战】一、双层神经网络的示例
【PyTorch实战】二、利用PyTorch玩个Fizz_Buzz小游戏
1. Fizz_Buzz小游戏的玩法
这个小游戏很简单,从1开始,遇到3输出fizz,遇到5输出buzz,遇到15输出fizz_buzz。
采用类似分段函数的代码写法如下:
def fizz_buzz_encode(i):
if i % 15 == 0: return 3
elif i % 5 == 0: return 2
elif i % 3 == 0: return 1
else: return 0
def fizz_buzz_decode(i, prediction):
return [str(i), "fizz", "buzz", "fizzbuzz"][prediction] #第一个[]代表一个数组,第二个[]代表数组中第几个值, 这个巧妙啊,NB
def helper(i):
print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
for i in range(1, 16):
helper(i)
2. 利用两层神经网络学习一个函数
import torch
import numpy as np
NUM_DIGITS = 10
# create an embedding for number
def binary_encode(i, num_digits):
return np.array([i >> d & 1 for d in range(num_digits)][::-1])
train_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2** NUM_DIGITS)])
train_y = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2** NUM_DIGITS)]) # Long tensor 是整型
# binary_encode(15, NUM_DIGITS)
# train_x.shape
# train_y.shape
NUM_HIDDEN = 100
model = torch.nn.Sequential(
torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
torch.nn.ReLU(),
torch.nn.Linear(NUM_HIDDEN, 4)
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr= 0.05)
BATCH_SIZE = 128
for epoch in range(1000):
for start in range(0, len(train_x), BATCH_SIZE): # from 0 to 923, the number of data is 128
end = start + BATCH_SIZE
batch_x = train_x[start:end]
batch_y = train_y[start:end]
y_pred = model(batch_x)
loss = loss_fn(y_pred, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch", epoch, loss.item())
测试一下:
# generate test data
test_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
# predict the test data based on the model
with torch.no_grad():
test_y = model(test_x)
predicts = zip(range(1, 101), test_y.max(1)[1].data.tolist())
# print([fizz_buzz_decode(i, x) for i, x in predicts])
# 测试一下,输出模型在测试数据上的准确率
count = 0
for i, x in predicts:
if fizz_buzz_decode(i, x) == fizz_buzz_decode(i, fizz_buzz_encode(i)):
count +=1
print("Accuracy:", count)
参考感谢
[1] 七月在线-褚则伟的pytorch实战课程
深度学习与PyTorch实战