1. 问题
当MXNet模型的Batch Normalization的fix_gamma参数为True时,会导致转ONNX模型失败,此时输出的ONNX参数如下图所示,导致ONNX的推理结果和MXNet不一致。
2. 解决方法
出现MXNet Batch Normalization的fix_gamma参数等于True时,可以手动修改batchnorm_gamma参数值,使ONNX模型输出正常,相关代码如下:
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
reshape_params = {}
for k, v in arg_params.items():
if 'batchnorm_gamma' in k:
v = 1 - v
reshape_params[k] = v
mx.model.save_checkpoint(prefix, epoch, sym, reshape_params, aux_params)
经过修改后ONNX模型结构如下图所示,结果输出正常。