导入包
import torch
import shutil
二、模块
1)保存模型参数,保存模型状态,状态中可以有模型参数,优化器参数,epoch等。如果是在验证集上表现比之前好,那么就是is_best=True,使用shutil.copyfile(src, des)将src文件直接拷贝到des,如果已经存在,就直接覆盖掉。
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):#state是一个字典,包含优化器、网络等参数
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
2)计算相关统计值, 如时间、top1、top5等,需要有四个值,分别是当前的值val、累计值sum(用于求取平均)、所有数量count、平均值avg。同时设置了归零的函数。
class AverageMeter(object):
"""
Keeps track of most recent, average, sum, and count of a metric.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
3)计算准确率, 有top1、以及top5。top1是正常的错误率计算,也就是选取最大的概率的标签作为对应的标签,若与真实标签不同,则error。top5是前五大概率中都没有真实标签,才算错误,相当于放宽了标准。
def accuracy(scores, targets, k):
"""
Computes top-k accuracy, from predicted and true labels.
:param scores: scores from the model
:param targets: true labels
:param k: k in top-k accuracy
:return: top-k accuracy
"""
batch_size = targets.size(0)
_, ind = scores.topk(k, 1, True, True)
correct = ind.eq(targets.view(-1, 1).expand_as(ind)) #每一行最多有一个True或者没有
correct_total = correct.view(-1).float().sum() # 0D tensor
return correct_total.item() * (100.0 / batch_size)
'''
topk(input, dim, replace, p)参数分别表示 批量概率向量、维度、是否有放回(也就是是否可以重复,True可重复)、p为输入向量各个元素的概率
'''