PyTorch AverageMeter

A meter for tracking running averages

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

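    def update(self, val, n=1):
        # val: the newest value (e.g. a batch's mean loss); n: how many samples it covers
        self.val = val
        self.sum += val * n                  # accumulate the weighted total
        self.count += n
        self.avg = self.sum / self.count     # sample-weighted running mean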
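A quick check of the class above: with the default n=1, avg is the plain mean of every value passed to update, val holds only the most recent value, and reset() clears all four fields. The numbers are arbitrary illustration values.

```python
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
meter.update(2.0)                          # one value, weight n=1
meter.update(4.0)                          # another value, weight n=1
print(meter.val, meter.avg, meter.count)   # 4.0 3.0 2

meter.reset()                              # all fields back to zero, e.g. before a new epoch
print(meter.val, meter.avg, meter.count)   # 0 0 0
```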

Usage:
During training, before each epoch:
losses = AverageMeter()

At each iteration of the current epoch:

losses.update(to_python_float(reduced_loss.data), input.size(0))
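Passing input.size(0) as n weights each batch's loss by its number of samples, so avg ends up as the per-sample mean over the epoch rather than a mean of batch means. The sketch below simulates one epoch without PyTorch; the batches list of (batch mean loss, batch size) pairs is hypothetical data standing in for reduced_loss and input.size(0), and a plain float stands in for the result of to_python_float(...).

```python
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Hypothetical per-iteration results: (mean loss of the batch, batch size).
# The last batch is smaller, as happens when the dataset size is not a
# multiple of the batch size.
batches = [(1.0, 6), (0.5, 6), (0.25, 4)]

losses = AverageMeter()                  # created fresh before the epoch
for batch_loss, batch_size in batches:
    # batch_size plays the role of input.size(0)
    losses.update(batch_loss, batch_size)

print(losses.avg)                        # 0.625, i.e. 10.0 total loss / 16 samples

# For comparison: averaging the batch means ignores batch sizes.
naive = sum(loss for loss, _ in batches) / len(batches)
print(naive)                             # about 0.583
```

With equal batch sizes the two numbers coincide; they only diverge when batches differ in size, which in practice is usually the last batch of an epoch.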

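# Tensor.item() was introduced in PyTorch 0.4, so this helper keeps older versions working.
def to_python_float(t):
    if hasattr(t, 'item'):
        # PyTorch >= 0.4: convert a one-element tensor to a Python number
        return t.item()
    else:
        # Older PyTorch: losses were 1-element tensors, so index the value out
        return t[0]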
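The fallback in to_python_float is plain duck typing: anything with an item attribute takes the first branch, anything else is indexed. The sketch below exercises both branches without PyTorch, using a hypothetical FakeScalarTensor class in place of a modern 0-dim tensor and a one-element list in place of an old-style 1-element tensor. On current PyTorch versions, calling reduced_loss.item() directly yields the same Python number.

```python
def to_python_float(t):
    # Objects exposing .item() (modern tensors) take this branch.
    if hasattr(t, 'item'):
        return t.item()
    # Everything else is assumed to be indexable (old 1-element tensors).
    else:
        return t[0]


class FakeScalarTensor:
    """Hypothetical stand-in for a 0-dim tensor that provides .item()."""
    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


print(to_python_float(FakeScalarTensor(0.5)))  # 0.5, via .item()
print(to_python_float([0.25]))                 # 0.25, via the t[0] fallback
```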


Reposted from blog.csdn.net/tywwwww/article/details/131080722