def test(losses=(), optimizer=None, adjust_lr=None, window=40, tolerance=1.01):
    """Decay the learning rate when the windowed training loss stops improving.

    Per-batch losses are summed over consecutive, non-overlapping windows of
    exactly ``window`` batches (batches ``0..window-1``, ``window..2*window-1``,
    ...). When a window closes, its total is compared with the baseline:

    * if a positive baseline exists and the new total exceeds
      ``baseline * tolerance``, the learning rate is decayed via
      ``adjust_lr(optimizer)`` and the baseline is left unchanged;
    * otherwise the new total becomes the baseline.

    A trailing partial window is never compared.

    Args:
        losses: Iterable of per-batch losses. Each item is either a plain
            number or an object with a ``.sum()`` method (e.g. a tensor or
            array), which is reduced to the sum of its elements.
        optimizer: Passed unchanged to the learning-rate adjuster.
        adjust_lr: Callable invoked as ``adjust_lr(optimizer)`` to decay the
            learning rate. When ``None``, the module-level
            ``adjust_learning_rate`` helper is used; it is looked up only when
            a decay actually fires.
        window: Number of batches per comparison window (must be >= 1).
        tolerance: Relative increase over the baseline that triggers a decay.

    Returns:
        List of batch indices at which a decay was triggered.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError("window must be >= 1")

    decayed_at = []
    total_loss = 0
    last_total_loss = 0  # baseline window total; 0 means "no baseline yet"
    for batch_i, loss in enumerate(losses):
        # Reduce tensor/array losses to a scalar; plain numbers pass through.
        total_loss += loss.sum() if hasattr(loss, "sum") else loss
        if batch_i % window != window - 1:
            continue
        # The window is closed and holds exactly `window` batches, including this
        # one, so successive totals are comparable (no 39-vs-40 batch skew).
        if last_total_loss > 0 and total_loss > last_total_loss * tolerance:
            print("total_loss", total_loss)
            decay = adjust_lr if adjust_lr is not None else adjust_learning_rate
            decay(optimizer)
            decayed_at.append(batch_i)
        else:
            print("total_loss", total_loss, last_total_loss)
            last_total_loss = total_loss
        total_loss = 0  # start the next window empty
    return decayed_at


if __name__ == '__main__':
    test()