在使用fvcore模块计算模型的flops时,遇到了标题中的问题,记录一下解决方案。
首先是在jit_analysis.py的589行出错。经过调试发现,op_counts.values()的类型是int32,但是计算要求的类型只能是int、float、np.float64和np.int64,因此需要做如下修改:
inputs, outputs = list(node.inputs()), list(node.outputs())
op_counts = self._op_handles[kind](inputs, outputs)
if isinstance(op_counts, Number):
op_counts = float(op_counts) # 手动进行强制转换
op_counts = Counter({self._simplify_op_name(kind): op_counts})
for v in op_counts.values():
if not isinstance(v, (int, float, np.float64, np.int64)):
raise ValueError(
f"Invalid type {type(v)} for the flop count! "
"Please use a wider type to avoid overflow."
)
添加了注释的部分为我手动修改。
然后,在jit_handles.py处有一个警告,意思是:变量的范围不够大,值出现了溢出(这可能会导致最终计算结果为0)。解决方案在https://github.com/facebookresearch/fvcore/issues/104
如果进不去github,改进代码如下:
try:
from math import prod
except ImportError:
from numpy import prod as prodnp # 修改
def prod(x): # 新增
return int(prodnp(x))