Merging a BN Layer into a Conv or FC Layer: Principle and Code Implementation
1. Why merge BN layers:
A BN (batch normalization) layer is a technique widely used in deep learning to speed up training. It is usually placed after a convolution (conv) layer or a fully connected (FC) layer, where it normalizes the activations and accelerates convergence. While BN plays a positive role during training, at inference time these extra layers add computation and consume additional memory or GPU memory. It would therefore be ideal to fold each BN layer into the adjacent conv or FC layer, which is exactly the work this article describes.
2. The math behind merging BN:
The typical position of a BN layer in a neural network is shown in the figure below:

[Figure: typical placement of a BN layer in a network]

As the figure shows, a BN layer usually sits after a conv (or FC) layer, though in some networks the BN layer comes before the conv (or FC) layer. We consider the two cases separately.
2.1 The case where BN follows the conv layer
The principle of the merge is illustrated by the two figures below.

[Figure: the data processing performed by a BN layer]

This first figure shows the result of passing an input X through the BN layer's operations.
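Since the figure itself is not reproduced here, the standard per-channel BN computation it depicts can be written out directly. With moving mean $\mu$, moving variance $\sigma^2$, learned scale $\gamma$, learned shift $\beta$, and a small constant $\epsilon$ for numerical stability:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$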
[Figure: folding BN into the conv layer's w and b]

The second figure has two parts: the first shows the combined result of the conv layer followed by the BN operation, and the second shows how the conv layer's w and b change once the BN layer is folded into it.
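Written out (a standard derivation consistent with the figure's description, not a reproduction of it): if the conv layer computes $z = Wx + b$ per output channel and BN follows, then

$$y = \gamma \cdot \frac{(Wx + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = W'x + b'$$

so the merged conv parameters are

$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\,W, \qquad b' = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta,$$

with $\gamma$, $\mu$, $\sigma^2$, and $\beta$ applied per output channel.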
2.2 The case where BN comes first and the conv layer follows
In this case an FC layer can still be merged much as in 2.1. For a conv layer, however, the merge fails when the convolution uses padding: in the folded conv the padded border values stay zero, whereas in the original network even a zero input is mapped by BN to a nonzero value ($\beta - \gamma\mu/\sqrt{\sigma^2+\epsilon}$) before the convolution, so a single folded conv cannot reproduce the computation at the borders.
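To make the conv-then-BN fold of 2.1 concrete, here is a minimal NumPy sketch (my own illustration, independent of the two scripts below) that checks the algebra numerically. It uses a 1x1 convolution, which reduces to a matrix multiply, so the example stays short; the per-channel scaling is identical for larger kernels:

import numpy as np

rng = np.random.RandomState(0)
c_in, c_out, n = 4, 3, 10
W = rng.randn(c_out, c_in)        # conv weight (1x1 kernel => plain matrix)
b = rng.randn(c_out)              # conv bias
gamma = rng.rand(c_out) + 0.5     # BN scale
beta = rng.randn(c_out)           # BN shift
mean = rng.randn(c_out)           # BN moving mean
var = rng.rand(c_out) + 0.1       # BN moving variance
eps = 1e-5

x = rng.randn(n, c_in)

# reference path: conv followed by BN
z = x.dot(W.T) + b
y_ref = gamma * (z - mean) / np.sqrt(var + eps) + beta

# folded path: one conv with rescaled weights and a shifted bias
alpha = gamma / np.sqrt(var + eps)     # per-output-channel scale
W_merged = W * alpha[:, None]
b_merged = alpha * (b - mean) + beta
y_fold = x.dot(W_merged.T) + b_merged

print(np.allclose(y_ref, y_fold))      # True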
3. Code for merging conv and BN
3.1 Caffe version (Python 2; I obtained this version from the web and will remove it if it infringes any rights)
#!/usr/bin/env python
import _init_paths
import numpy as np
import sys
import os
import os.path as osp
import google.protobuf as pb
import google.protobuf.text_format  # makes pb.text_format available
from argparse import ArgumentParser
import caffe


def load_and_fill_biases(src_model, src_weights, dst_model, dst_weights):
    with open(src_model) as f:
        model = caffe.proto.caffe_pb2.NetParameter()
        pb.text_format.Merge(f.read(), model)

    for i, layer in enumerate(model.layer):
        if layer.type == 'Convolution':  # or layer.type == 'Scale':
            # Add bias term if needed
            if layer.convolution_param.bias_term == False:
                layer.convolution_param.bias_term = True
                layer.convolution_param.bias_filler.type = 'constant'
                layer.convolution_param.bias_filler.value = 0.0

    with open(dst_model, 'w') as f:
        f.write(pb.text_format.MessageToString(model))

    caffe.set_mode_cpu()
    net_src = caffe.Net(src_model, src_weights, caffe.TEST)
    net_dst = caffe.Net(dst_model, caffe.TEST)
    for key in net_src.params.keys():
        for i in range(len(net_src.params[key])):
            net_dst.params[key][i].data[:] = net_src.params[key][i].data[:]

    if dst_weights is not None:
        # Store params
        pass

    return net_dst


def merge_conv_and_bn(net, i_conv, i_bn, i_scale):
    # This is based on Kyeheyon's work
    assert(i_conv != None)
    assert(i_bn != None)

    def copy_double(data):
        return np.array(data, copy=True, dtype=np.double)

    key_conv = net._layer_names[i_conv]
    key_bn = net._layer_names[i_bn]
    key_scale = net._layer_names[i_scale] if i_scale else None

    # Copy the BN statistics
    bn_mean = copy_double(net.params[key_bn][0].data)
    bn_variance = copy_double(net.params[key_bn][1].data)
    num_bn_samples = copy_double(net.params[key_bn][2].data)

    # ... and invalidate the BN layer
    net.params[key_bn][0].data[:] = 0
    net.params[key_bn][1].data[:] = 1
    net.params[key_bn][2].data[:] = 1
    if num_bn_samples[0] == 0:
        num_bn_samples[0] = 1

    if net.params.has_key(key_scale):
        print 'Combine {:s} + {:s} + {:s}'.format(key_conv, key_bn, key_scale)
        scale_weight = copy_double(net.params[key_scale][0].data)
        scale_bias = copy_double(net.params[key_scale][1].data)
        net.params[key_scale][0].data[:] = 1
        net.params[key_scale][1].data[:] = 0
    else:
        print 'Combine {:s} + {:s}'.format(key_conv, key_bn)
        scale_weight = 1
        scale_bias = 0

    weight = copy_double(net.params[key_conv][0].data)
    bias = copy_double(net.params[key_conv][1].data)
    alpha = scale_weight / np.sqrt(bn_variance / num_bn_samples[0] + np.finfo(np.double).eps)
    net.params[key_conv][1].data[:] = bias * alpha + (scale_bias - (bn_mean / num_bn_samples[0]) * alpha)
    for i in range(len(alpha)):
        net.params[key_conv][0].data[i] = weight[i] * alpha[i]


def merge_batchnorms_in_net(net):
    # for each BN layer
    for i, layer in enumerate(net.layers):
        if layer.type != 'BatchNorm':
            continue

        l_name = net._layer_names[i]

        l_bottom = net.bottom_names[l_name]
        assert(len(l_bottom) == 1)
        l_bottom = l_bottom[0]
        l_top = net.top_names[l_name]
        assert(len(l_top) == 1)
        l_top = l_top[0]

        can_be_absorbed = True

        # Search all (bottom) layers
        for j in xrange(i - 1, -1, -1):
            tops_of_j = net.top_names[net._layer_names[j]]
            if l_bottom in tops_of_j:
                if net.layers[j].type not in ['Convolution', 'InnerProduct']:
                    can_be_absorbed = False
                else:
                    # There must be only one layer
                    conv_ind = j
                    break

        if not can_be_absorbed:
            continue

        # find the following Scale layer
        scale_ind = None
        for j in xrange(i + 1, len(net.layers)):
            bottoms_of_j = net.bottom_names[net._layer_names[j]]
            if l_top in bottoms_of_j:
                if scale_ind:
                    # Followed by two or more layers
                    scale_ind = None
                    break

                if net.layers[j].type in ['Scale']:
                    scale_ind = j

                    top_of_j = net.top_names[net._layer_names[j]][0]
                    if top_of_j == bottoms_of_j[0]:
                        # On-the-fly (in-place) => can be merged
                        break
                else:
                    # Followed by a layer which is not 'Scale'
                    scale_ind = None
                    break

        merge_conv_and_bn(net, conv_ind, i, scale_ind)

    return net


def process_model(net, src_model, dst_model, func_loop, func_finally):
    with open(src_model) as f:
        model = caffe.proto.caffe_pb2.NetParameter()
        pb.text_format.Merge(f.read(), model)

    for i, layer in enumerate(model.layer):
        map(lambda x: x(layer, net, model, i), func_loop)

    map(lambda x: x(net, model), func_finally)

    with open(dst_model, 'w') as f:
        f.write(pb.text_format.MessageToString(model))


# Functions to remove (redundant) BN and Scale layers
to_delete_empty = []

def pick_empty_layers(layer, net, model, i):
    if layer.type not in ['BatchNorm', 'Scale']:
        return

    bottom = layer.bottom[0]
    top = layer.top[0]

    if (bottom != top):
        # Not supported yet
        return

    if layer.type == 'BatchNorm':
        zero_mean = np.all(net.params[layer.name][0].data == 0)
        one_var = np.all(net.params[layer.name][1].data == 1)
        #length_is_1 = (net.params['conv1_1/bn'][2].data == 1) or (net.params[layer.name][2].data == 0)
        length_is_1 = (net.params[layer.name][2].data == 1)

        if zero_mean and one_var and length_is_1:
            print 'Delete layer: {}'.format(layer.name)
            to_delete_empty.append(layer)

    if layer.type == 'Scale':
        no_scaling = np.all(net.params[layer.name][0].data == 1)
        zero_bias = np.all(net.params[layer.name][1].data == 0)

        if no_scaling and zero_bias:
            print 'Delete layer: {}'.format(layer.name)
            to_delete_empty.append(layer)

def remove_empty_layers(net, model):
    map(model.layer.remove, to_delete_empty)


# A function to add 'engine: CAFFE' param into 1x1 convolutions
def set_engine_caffe(layer, net, model, i):
    if layer.type == 'Convolution':
        if layer.convolution_param.kernel_size == 1\
                or (layer.convolution_param.kernel_h == layer.convolution_param.kernel_w == 1):
            layer.convolution_param.engine = dict(layer.convolution_param.Engine.items())['CAFFE']


def main(args):
    # Set default output file names
    if args.output_model is None:
        file_name = osp.splitext(args.model)[0]
        args.output_model = file_name + '_inference.prototxt'
    if args.output_weights is None:
        file_name = osp.splitext(args.weights)[0]
        args.output_weights = file_name + '_inference.caffemodel'

    net = load_and_fill_biases(args.model, args.weights, args.model + '.temp.pt', None)

    net = merge_batchnorms_in_net(net)

    process_model(net, args.model + '.temp.pt', args.output_model,
                  [pick_empty_layers, set_engine_caffe],
                  [remove_empty_layers])

    # Store params
    net.save(args.output_weights)


if __name__ == '__main__':
    parser = ArgumentParser(
        description="Generate Batch Normalized model for inference")
    parser.add_argument('model', help="The net definition prototxt")
    parser.add_argument('weights', help="The weights caffemodel")
    parser.add_argument('--output_model')
    parser.add_argument('--output_weights')
    args = parser.parse_args()
    main(args)
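A minimal sketch of how the script above would be invoked (the file name merge_bn_caffe.py is my placeholder; the positional and optional arguments come from the ArgumentParser at the bottom of the script):

python2 merge_bn_caffe.py deploy.prototxt weights.caffemodel --output_model deploy_inference.prototxt --output_weights weights_inference.caffemodel

If --output_model or --output_weights is omitted, the script derives the names by appending _inference to the input file names.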
3.2 MXNet version (Python 2; there is still a problem when conv_no_bias=True; I implemented this code myself)
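A note on how the script below is driven: it takes the params file as its first command-line argument and the symbol JSON as its second, and writes "_change" versions next to the inputs. A hypothetical invocation (the script name merge_bn_mxnet.py is my placeholder; the example paths follow the comments in the code):

python2 merge_bn_mxnet.py resnet/base-14-9999.params resnet/base-14.json

This would write resnet/base-14_change.json and resnet/base-14-9999_change.params.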
import sys, argparse
import find_mxnet, find_caffe   # path helpers from the original environment
import mxnet as mx
import caffe
import pdb
import json
import numpy as np
import copy


def merge_bn_into_conv_or_fc(json_str, net_param):
    json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])

    bn_name_list = []    # stores the bn layer names
    conv_name_list = []  # stores the matching conv/fc layer names

    for i in range(len(json_obj['nodes'])):
        # search for batch-norm layers preceded by conv (or fc)
        if json_obj['nodes'][i]['op'] == "BatchNorm":
            may_conv_index = json_obj['nodes'][i]['inputs'][0][0]

            # check whether the node feeding the batchnorm is conv or fc
            if json_obj['nodes'][may_conv_index]['op'] in ["Convolution", "FullyConnected"]:
                bn_name_list.append(json_obj['nodes'][i]['name'])
                conv_name_list.append(json_obj['nodes'][may_conv_index]['name'])

    if len(bn_name_list) != len(conv_name_list) or len(bn_name_list) <= 0:
        print "error, len(bn_name_list) should equal len(conv_name_list)"
        exit()

    for i in range(len(bn_name_list)):
        print i
        json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
        name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])

        # bn name, eps and fix_gamma
        bn_index = name_to_num[bn_name_list[i]]
        bn_name = json_obj['nodes'][bn_index]['name']
        bn_eps = float(json_obj['nodes'][bn_index]['param']['eps'])
        bn_fixgamma = bool(json_obj['nodes'][bn_index]['param']['fix_gamma'])

        # conv name and no_bias flag
        conv_index = name_to_num[conv_name_list[i]]
        conv_name = json_obj['nodes'][conv_index]['name']
        conv_no_bias = bool(json_obj['nodes'][conv_index]['param']['no_bias'])

        # fold this bn into the conv: first the weights, then the symbol json
        net_param = copy.deepcopy(merge_bn_conv_after(net_param=net_param, conv_name=conv_name, bn_name=bn_name, fix_gamma=bn_fixgamma, no_bias=conv_no_bias, eps=bn_eps))
        json_str = copy.deepcopy(merge_bn_conv_after_bn_json(json_str=json_str, conv_name=conv_name, bn_name=bn_name, fix_gamma=bn_fixgamma, no_bias=conv_no_bias, eps=bn_eps))

    return json_str, net_param


def load_json(json_str):
    json_obj = json.loads(json_str)           # dict containing "nodes", "arg_nodes", "heads"
    nodes = json_obj['nodes']                 # a list, len = number of nodes
    names = [node['name'] for node in nodes]
    old_num_to_name = dict(enumerate(names))
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
    inputs = [node['inputs'] for node in nodes]
    return json_obj, nodes, names, old_num_to_name, inputs


def merge_bn_conv_after_bn_json(json_str, conv_name, bn_name, fix_gamma=False, no_bias=False, eps=0.001):
    json_obj, nodes, names, old_num_to_name, inputs = load_json(json_str)
    name_to_num = dict([(v, k) for k, v in old_num_to_name.iteritems()])
    # look up the conv and bn node indices
    conv_index = name_to_num[conv_name]
    bn_index = name_to_num[bn_name]
    for i in range(len(json_obj['nodes'])):
        if len(json_obj['nodes'][i]['inputs']) <= 0:
            continue  # skip nodes with inputs == []

        # rewire consumers of the bn node to read from the conv node
        input_list = json_obj['nodes'][i]['inputs']
        for j in range(len(input_list)):
            if input_list[j][0] == bn_index:
                input_list[j][0] = conv_index
        json_obj['nodes'][i]['inputs'] = input_list

        # turn the bn node into a plain param node rather than an op
        if json_obj['nodes'][i]['name'] == bn_name:
            json_obj['nodes'][i] = copy.deepcopy(json_obj['nodes'][i - 1])
            json_obj['nodes'][i]['name'] = bn_name

    if no_bias == True:
        # the conv had no bias: reuse the bn beta slot as the new conv bias
        json_obj['nodes'][int(bn_index) - 1]['name'] = conv_name + '_bias'
        json_obj['nodes'][conv_index]['param']['no_bias'] = "False"
        list_add = []
        list_add.append(int(bn_index) - 1)
        list_add.append(0)
        json_obj['nodes'][conv_index]['inputs'].append(list_add)

    # rename the bn beta param to the conv bias name
    json_obj['nodes'][int(bn_index) - 1]['name'] = conv_name + '_bias'

    return json.dumps(json_obj, indent=4)


# merge a conv with the bn that follows it (weight update)
def merge_bn_conv_after(net_param, conv_name, bn_name, fix_gamma=False, no_bias=False, eps=0.001):
    gamma = net_param['arg:' + bn_name + '_gamma'].asnumpy()  # bn scale
    if fix_gamma == True:  # fix_gamma = true: gamma is fixed to all ones
        gamma *= 0
        gamma += 1
    beta = net_param['arg:' + bn_name + '_beta'].asnumpy()             # bn shift
    mov_mean = net_param['aux:' + bn_name + '_moving_mean'].asnumpy()  # bn moving mean
    mov_var = net_param['aux:' + bn_name + '_moving_var'].asnumpy()    # bn moving variance
    mov_std = np.sqrt(mov_var + eps)                                   # std from variance

    # conv weights (and bias) before the merge
    part_0_conv_weight = net_param['arg:' + conv_name + '_weight'].asnumpy()

    output_channel = part_0_conv_weight.shape[0]
    # pdb.set_trace()  # debugging breakpoint from the original, disabled here
    if no_bias == True:
        # no existing bias: start from zeros
        # (note: this is the path that still has problems, see the heading)
        part_0_conv_bias = np.zeros((output_channel,), dtype=np.float64)
        for i in range(output_channel):  # weight.shape = [out, in, kernel, kernel]
            part_0_conv_weight[i, :, :, :] *= float(gamma[i] / mov_std[i])               # update conv weight
            part_0_conv_bias[i] += beta[i] - float(gamma[i] * mov_mean[i] / mov_std[i])  # update conv bias
    else:
        # existing bias: rescale it, then shift
        part_0_conv_bias = net_param['arg:' + conv_name + '_bias'].asnumpy()
        for i in range(output_channel):  # weight.shape = [out, in, kernel, kernel]
            part_0_conv_weight[i, :, :, :] *= float(gamma[i] / mov_std[i])               # update conv weight
            part_0_conv_bias[i] *= float(gamma[i] / mov_std[i])                          # rescale conv bias
            part_0_conv_bias[i] += beta[i] - float(gamma[i] * mov_mean[i] / mov_std[i])  # shift conv bias

    # write the merged params back (the bias is stored under the conv name in both cases)
    net_param['arg:' + conv_name + '_weight'] = mx.nd.array(part_0_conv_weight)
    net_param['arg:' + conv_name + '_bias'] = mx.nd.array(part_0_conv_bias)
    return net_param


# usage: argv[1] = params file (e.g. resnet/base-14-9999.params),
#        argv[2] = symbol json (e.g. resnet/base-14.json)
input_param = sys.argv[1]
input_json = file(sys.argv[2])
net_param = mx.nd.load(input_param)
new_json_str, new_param = merge_bn_into_conv_or_fc(input_json.read(), net_param)

# single-layer variants kept from development:
#new_json = merge_bn_conv_after_bn_json(json_file=input_json, bn_name="part_0_bn_conv1", conv_name="part_0_conv", fix_gamma=True, no_bias=False, eps=0.001)
#net_param = merge_bn_conv_after(net_param=net_param, bn_name="part_0_bn_conv1", conv_name="part_0_conv", fix_gamma=True, no_bias=False, eps=0.001)
#net_param = merge_bn_conv_after(net_param=net_param, bn_name="part_0_bn0", conv_name="part_0_conv0", fix_gamma=True, no_bias=True)  # for resnet_divide4
#new_json_str = json.dumps(new_json, indent=4)

open((sys.argv[2]).replace(".json", "_change.json"), "w").write(new_json_str)
mx.nd.save(input_param.replace(".params", "_change.params"), new_param)