训练模型时的控制台输出模板（基于 tqdm）：
【数据集加载类（保存为 TestDataset.py，训练代码通过 from TestDataset import TestDataset 导入）】
import torch
from torch.utils.data import Dataset
__all__ = ["TestDataset"]


class TestDataset(Dataset):
    """Synthetic dataset of random frame sequences split into (input, target).

    Every sample is a sequence of ``input_frames + target_frames`` frames, each
    of shape ``(channels, height, width)``, filled with ``torch.rand`` values.
    ``__getitem__`` returns the first ``input_frames`` frames as the input and
    the remaining ``target_frames`` frames as the target.

    Args:
        num_samples: Number of samples in the dataset.
        input_frames: Frames per sample returned as the input.
        target_frames: Frames per sample returned as the target.
        channels: Channels per frame.
        height: Frame height.
        width: Frame width.
    """

    def __init__(self, num_samples: int = 20, input_frames: int = 10,
                 target_frames: int = 10, channels: int = 1,
                 height: int = 200, width: int = 200):
        super().__init__()
        self.input_frames = input_frames
        # Input and target frames share dim 1, so a single slice point at
        # ``input_frames`` separates them. The frame axis must hold both parts,
        # otherwise the target slice ``[input_frames:]`` comes out empty.
        self.dataset = torch.rand(num_samples, input_frames + target_frames,
                                  channels, height, width)

    def __getitem__(self, item):
        """Return ``(input, target)`` frame tensors for sample ``item``."""
        return (self.dataset[item, :self.input_frames],
                self.dataset[item, self.input_frames:])

    def __len__(self):
        """Return the number of samples."""
        return self.dataset.shape[0]
【训练代码】
import time
from random import random
from datetime import datetime
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from TestDataset import TestDataset
# Number of training epochs for the demo run.
max_epochs = 5

# Two independent random datasets, each wrapped in a loader that uses
# DataLoader's default settings (batch_size=1, no shuffling).
train_set, test_set = TestDataset(), TestDataset()
train_loader, test_loader = DataLoader(train_set), DataLoader(test_set)
def training(epoch: int, test_frequency: int = 5):
    """Run a demo training loop that prints coloured tqdm progress bars.

    For each epoch, a tqdm bar over ``train_loader`` shows the epoch number,
    the running mean of the (placeholder) training loss and the elapsed wall
    time. Every ``test_frequency`` epochs, a one-line summary with the
    (placeholder) test loss is printed. ``train_loader`` and ``test_loader``
    are read from module scope.

    Args:
        epoch: Total number of epochs to run.
        test_frequency: Evaluate after every ``test_frequency`` epochs; must
            be >= 1.

    Raises:
        ValueError: If ``test_frequency`` is less than 1.
    """
    # Fail fast. Otherwise ``% 0`` would raise ZeroDivisionError only after
    # the first full epoch had already run.
    if test_frequency < 1:
        raise ValueError(f"test_frequency must be >= 1, got {test_frequency}")

    # ``epoch`` (the caller-facing parameter) is the epoch count. Use a
    # separate name for the loop index instead of shadowing it.
    for epoch_idx in range(epoch):
        with tqdm(
            iterable=train_loader,
            bar_format='{desc} {n_fmt:>4s}/{total_fmt:<4s} {percentage:3.0f}%|{bar}| {postfix}',
        ) as t:
            # The description is constant for the whole epoch, so set it once.
            # "\33[36m" switches the terminal colour to cyan; the postfix ends
            # with "\33[0m", which resets it.
            t.set_description_str(f"\33[36m【Epoch {epoch_idx + 1:04d}】")
            start_time = datetime.now()
            # A running sum and count give the mean loss in O(1) per batch.
            loss_sum = 0.0
            num_batches = 0
            for data in train_loader:
                # Training code goes here (``data`` is one batch from the
                # loader). The sleep stands in for the real work.
                time.sleep(1)
                # Current loss (placeholder random value).
                loss = random()
                loss_sum += loss
                num_batches += 1
                delta_time = datetime.now() - start_time
                t.set_postfix_str(f"train_loss={loss_sum / num_batches:.6f}, 执行时长:{delta_time}\33[0m")
                # The loop iterates train_loader directly, not ``t``, so the
                # bar has to be advanced manually.
                t.update()
        if (epoch_idx + 1) % test_frequency == 0:
            with tqdm(
                iterable=test_loader,
                bar_format='{desc} {postfix}',
            ) as t:
                # Evaluation placeholder: the sleep and the constant loss
                # stand in for running the model over test_loader.
                time.sleep(2)
                test_loss = 3.1415926
                # "\33[35m" = magenta; the postfix resets the colour.
                t.set_description_str(f"\33[35m【测试集】")
                t.set_postfix_str(f"test_loss={test_loss:.6f}\33[0m")
                t.update()
if __name__ == "__main__":
    # Start the demo run only when executed as a script, not when imported.
    training(epoch=max_epochs, test_frequency=1)
【效果图】