Pytorch 统计网络参数个数

1 代码

def count_parameters(model):  # 传入的是模型实例对象
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{
      
      item:>16}')   # 参数大于16的展示
    print(f'________\n{
      
      sum(params):>16}')  # 大于16的进行统计,可以自行修改
count_parameters(net)

2 输出

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_44864833/article/details/127710471