AverageMeter类的作用:用来管理一些自定义的变量。
代码如下:
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
使用方法:
@torch.no_grad()
def validate(config, data_loader, model):
criterion = torch.nn.CrossEntropyLoss()
model.eval()
batch_time = AverageMeter()
loss_meter = AverageMeter()
acc1_meter = AverageMeter()
acc5_meter = AverageMeter()
end = time.time()
for idx, (images, target) in enumerate(data_loader):
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
# compute output
output = model(images)
# measure accuracy and record loss
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
loss_meter.update(loss.item(), target.size(0))
acc1_meter.update(acc1.item(), target.size(0))
acc5_meter.update(acc5.item(), target.size(0))
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg