Pytorch计算模型的参数量

比如说我计算TransUNet的参数量
(计算之前需要安装thop库 pip install thop)

from thop import profile, clever_format
flops, params = profile(net, inputs=(inputs,))
macs, params = clever_format([flops, params], "%.3f") # 格式化输出
print('flops':, macs) # 计算量
print('params:',params) # 模型参数量

最后输出:
flops:128.677G
params:93.192M

可以发现在UNet中加入transformer之后,参数量还是增加不少的(DUNet的参数量是:19.219M)

猜你喜欢

转载自blog.csdn.net/qq_36321330/article/details/115339778