【PyTorch实战】二、利用PyTorch玩个Fizz_Buzz小游戏

【PyTorch实战】一、双层神经网络的示例
【PyTorch实战】二、利用PyTorch玩个Fizz_Buzz小游戏

1. Fizz_Buzz小游戏的玩法

这个小游戏很简单,从1开始,遇到3输出fizz,遇到5输出buzz,遇到15输出fizz_buzz。

采用类似分段函数的代码写法如下:

def fizz_buzz_encode(i):
    if i % 15 == 0: return 3
    elif i % 5 == 0: return 2
    elif i % 3 == 0: return 1
    else: return 0
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction] #第一个[]代表一个数组,第二个[]代表数组中第几个值, 这个巧妙啊,NB
def helper(i):
    print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
for i in range(1, 16):
    helper(i)

2. 利用两层神经网络学习一个函数

import torch
import numpy as np

NUM_DIGITS = 10

# create an embedding for number
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

train_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2** NUM_DIGITS)])
train_y = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2** NUM_DIGITS)]) # Long tensor 是整型

# binary_encode(15, NUM_DIGITS)
# train_x.shape
# train_y.shape

NUM_HIDDEN = 100

model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr= 0.05)

BATCH_SIZE = 128

for epoch in range(1000):
    for start in range(0, len(train_x), BATCH_SIZE): # from 0 to 923, the number of data is 128
        end = start + BATCH_SIZE
        batch_x = train_x[start:end]
        batch_y = train_y[start:end]
        
        y_pred = model(batch_x)
        loss = loss_fn(y_pred, batch_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch", epoch, loss.item())

测试一下:

# generate test data
test_x = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
# predict the test data based on the model
with torch.no_grad():
    test_y = model(test_x)

predicts = zip(range(1, 101), test_y.max(1)[1].data.tolist())
# print([fizz_buzz_decode(i, x) for i, x in predicts])

# 测试一下,输出模型在测试数据上的准确率
count = 0
for i, x in predicts:
    if fizz_buzz_decode(i, x) == fizz_buzz_decode(i, fizz_buzz_encode(i)):
        count +=1
print("Accuracy:", count)

参考感谢

[1] 七月在线-褚则伟的pytorch实战课程

深度学习与PyTorch实战

猜你喜欢

转载自blog.csdn.net/xiangduixuexi/article/details/106737263
今日推荐