在epoch循环开始之前插入以下代码：
# Resume support -- run this once, before the epoch loop.
# If resume is enabled and ./testweights/last_model.pth exists, restore the
# model and optimizer state from it and continue with the epoch after the one
# that was saved; otherwise training starts at epoch 0.
initepoch = 0
resume = True  # whether to continue training from the last saved state
if resume:
    if os.path.isfile("./testweights/last_model.pth"):
        print("Resume from checkpoint...")
        # map_location="cpu" lets a checkpoint written on a GPU be read on a
        # machine without one; the load_state_dict calls below copy the
        # values onto the device the parameters currently live on.
        checkpoint = torch.load("./testweights/last_model.pth", map_location="cpu")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initepoch = checkpoint['epoch'] + 1  # continue after the saved epoch
        # Restore the best validation accuracy seen so far, if the checkpoint
        # records it (older checkpoints without it still load). Without this,
        # the first resumed epoch beats a freshly initialised best_acc and
        # overwrites best_model.pth even when it is worse than the old best.
        # NOTE(review): best_acc must be initialised *before* this snippet,
        # otherwise the restored value is overwritten -- confirm in caller.
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
    else:
        print("====>no checkpoint found.")
        initepoch = 0  # nothing trained yet: start from epoch 0
将epoch循环改为从initepoch开始（恢复训练时跳过已完成的epoch）：
# Start at initepoch instead of 0: it is 0 on a fresh run and the saved
# epoch + 1 after resuming, so already-finished epochs are not repeated.
for epoch in range(initepoch, args.epochs):
在每个epoch的循环体内、计算出val_acc之后插入以下代码：
# Per-epoch checkpointing -- run this inside the epoch loop, after val_acc has
# been computed for the current epoch.
# torch.save does not create missing directories, so make sure it exists.
os.makedirs("./testweights", exist_ok=True)

# save best epoch: keep the weights with the highest val_acc seen so far
if val_acc > best_acc:
    best_acc = val_acc
    torch.save(model.state_dict(), "testweights/best_model.pth")
    print("!!--Best Model has Update--!!")

# save epoch model: one weights-only snapshot per epoch
torch.save(model.state_dict(), "./testweights/model-{}.pth".format(epoch))

# save last model: everything needed to resume training later.
# best_acc is stored as a plain float so the checkpoint stays loadable with
# torch.load(weights_only=True) even if val_acc is a tensor/NumPy scalar.
checkpoint = {"model_state_dict": model.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch,
              "best_acc": float(best_acc)}
path_checkpoint = "./testweights/last_model.pth"
# Write to a temporary file and atomically swap it in, so a crash or kill in
# the middle of saving cannot leave a truncated last_model.pth behind.
tmp_checkpoint = path_checkpoint + ".tmp"
torch.save(checkpoint, tmp_checkpoint)
os.replace(tmp_checkpoint, path_checkpoint)
print("!!--Last Model has Update(-{})--!!".format(epoch))