Hitting unsupported layers when pruning a network with the TensorFlow Model Optimization toolkit

Problem description:
While pruning my own network with TensorFlow Model Optimization 0.6.0, I ran into layers that are not yet officially supported:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tensorflow.python.keras.layers.convolutional.Conv1DTranspose'>

Case 1: The unsupported layer is a built-in Keras layer

Following a post by the blogger 专业混水:
prune_registry.py lists the layers TensorFlow currently supports for pruning, along with the weights to prune in each (the default is the kernel; layers like ReLU and BatchNormalization return an empty list, since they have no prunable parameters).
Find the dictionary _LAYERS_WEIGHTS_MAP in prune_registry.py and edit it by hand.
This runs with no errors, but it only works because the layer is a built-in Keras layer.
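Instead of editing the installed package, the same entry can be patched in at runtime. A minimal sketch, assuming tfmot 0.6.0 where the registry lives at the internal path below, using Conv1DTranspose (the layer from the error above) and its 'kernel' weight:

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry

# Register Conv1DTranspose so prune_low_magnitude accepts it;
# 'kernel' is the weight tensor that will be sparsified.
prune_registry.PruneRegistry._LAYERS_WEIGHTS_MAP[
    tf.keras.layers.Conv1DTranspose] = ['kernel']

Since this touches a private attribute, it may break across tfmot versions; run it before calling prune_low_magnitude.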

Case 2: The unsupported layer is a custom layer (a composition of built-in Keras layers)

First, here is my custom layer: the Transition Layer from DenseNet, but in a one-dimensional version.

class TransitionBlock(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):
    def __init__(self, num_channels, name, **kwargs):
        super(TransitionBlock, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.block_name = name
        self.batch_norm = tf.keras.layers.BatchNormalization(name="{}_BatchNorm".format(name))
        self.relu = tf.keras.layers.ReLU(name="{}_ReLU".format(name))
        self.conv = tf.keras.layers.Conv1D(num_channels, kernel_size=1,
                                           kernel_initializer='he_normal',
                                           name="{}_Conv1D".format(name))
        self.max_pool = tf.keras.layers.MaxPooling1D(name="{}_MaxPool".format(name))

    def get_config(self):
        # Serialize the constructor arguments, not the sub-layer objects,
        # so the layer can be rebuilt from its config.
        config = super().get_config().copy()
        config.update({
            "num_channels": self.num_channels,
            "name": self.block_name,
        })
        return config

    def get_prunable_weights(self):
        return [self.conv.kernel]  # returns a list, following the official example

    def call(self, x):
        x = self.batch_norm(x)
        x = self.relu(x)
        x_concat = self.conv(x)
        x = self.max_pool(x_concat)
        return x_concat, x  # x_concat is reused for the concat during upsampling

Two things to note here:
1. When defining the custom layer, inherit from both tf.keras.layers.Layer and tfmot.sparsity.keras.PrunableLayer, following the official TensorFlow example "Prune custom Keras layer or modify parts of layer to prune".

2. The get_prunable_weights() method.
Normally, when pruning a tf.keras.layers.Dense or tf.keras.layers.Conv1D, its kernel is pruned by default. Inside a custom layer, however, the toolkit cannot detect this on its own, which is exactly what raises the ValueError above, so get_prunable_weights must be overridden by hand.
In this case, among self.batch_norm, self.relu, self.conv, and self.max_pool, only self.conv has weights that can be pruned (a quick registry check is sketched below).
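To confirm those defaults, the registry can be inspected directly. A minimal sketch, again assuming the tfmot 0.6.0 internal module path (it may move in other versions):

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry

registry = prune_registry.PruneRegistry
print(registry._LAYERS_WEIGHTS_MAP[tf.keras.layers.Conv1D])  # ['kernel']
print(registry._LAYERS_WEIGHTS_MAP[tf.keras.layers.ReLU])    # [] -- nothing to prune
print(registry.supports(tf.keras.layers.Conv1D(8, kernel_size=3)))  # True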

In the official example, get_prunable_weights() returns a list.
Looking at the source code of tfmot.sparsity.keras.PrunableLayer:

import abc


class PrunableLayer(object):
  """Abstract Base Class for making your own keras layer prunable.
  Custom keras layers which want to add pruning should implement this class.
  """

  @abc.abstractmethod
  def get_prunable_weights(self):
    """Returns list of prunable weight tensors.
    All the weight tensors which the layer wants to be pruned during
    training must be returned by this method.
    Returns: List of weight tensors/kernels in the keras layer which must be
        pruned during training.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

This shows that inheriting from tfmot.sparsity.keras.PrunableLayer really just means inheriting one abstract method.

Accordingly, I override it in my custom Transition layer:

    def get_prunable_weights(self):
        return [self.conv.kernel]

and that is all it takes.
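As a quick sanity check, wrapping an instance no longer raises the ValueError (num_channels and the name below are hypothetical, purely for illustration):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

block = TransitionBlock(num_channels=64, name="trans1")
wrapped = tfmot.sparsity.keras.prune_low_magnitude(block)
print(type(wrapped).__name__)  # PruneLowMagnitude -- accepted, no ValueError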

Case 3: Layers that are not tf.keras.layers at all, such as TFOpLambda

Use tf.keras.models.clone_model to wrap only selected layers:

def apply_pruning_to(layer):
    accepted_layers = [
        tf.keras.layers.Conv1D,
        tf.keras.layers.MaxPooling1D,
        tf.keras.layers.BatchNormalization,
        tf.keras.layers.ReLU,
        tf.keras.layers.GlobalAveragePooling1D,  # channel attention
        tf.keras.layers.GlobalMaxPooling1D,      # channel attention
        tf.keras.layers.Reshape,                 # channel attention
        tf.keras.layers.Add,                     # channel attention
        tf.keras.layers.Activation,              # channel attention
        tf.keras.layers.Multiply,
        tf.keras.layers.Concatenate,
        tf.keras.layers.Conv1DTranspose,
        # ConvBlock,
        DenseBlock,
        TransitionBlock,
    ]

    for accepted in accepted_layers:
        if isinstance(layer, accepted):
            return tfmot.sparsity.keras.prune_low_magnitude(layer)
    return layer

model_for_pruning = tf.keras.models.clone_model(
    model,
    clone_function=apply_pruning_to,
)
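After cloning, the wrapped model trains like any other Keras model as long as the pruning-step callback is attached. A minimal sketch (the optimizer, loss, and x_train/y_train below are placeholders):

import tensorflow_model_optimization as tfmot

model_for_pruning.compile(optimizer='adam', loss='mse')

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),  # required: advances the pruning step every batch
]
model_for_pruning.fit(x_train, y_train, epochs=2, callbacks=callbacks)

# Remove the pruning wrappers before saving/exporting the sparse model.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)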

Reference 1: Yamnet clustering AttributeError: Exception encountered when calling layer "tf.__operators__.add" (type TFOpLambda). #972
Reference 2: Prune some layers (Sequential and Functional)

For reference, here are the weights each built-in layer prunes:

class PruneRegistry(object):
  """Registry responsible for built-in keras layers."""

  # The keys represent built-in keras layers and the values represent the
  # the variables within the layers which hold the kernel weights. This
  # allows the wrapper to access and modify the weights.
  _LAYERS_WEIGHTS_MAP = {
      layers.ELU: [],
      layers.LeakyReLU: [],
      layers.ReLU: [],
      layers.Softmax: [],
      layers.ThresholdedReLU: [],
      layers.Conv1D: ['kernel'],
      layers.Conv2D: ['kernel'],
      layers.Conv2DTranspose: ['kernel'],
      layers.Conv3D: ['kernel'],
      layers.Conv3DTranspose: ['kernel'],
      layers.Cropping1D: [],
      layers.Cropping2D: [],
      layers.Cropping3D: [],
      layers.DepthwiseConv2D: [],
      layers.SeparableConv1D: ['pointwise_kernel'],
      layers.SeparableConv2D: ['pointwise_kernel'],
      layers.UpSampling1D: [],
      layers.UpSampling2D: [],
      layers.UpSampling3D: [],
      layers.ZeroPadding1D: [],
      layers.ZeroPadding2D: [],
      layers.ZeroPadding3D: [],
      layers.Activation: [],
      layers.ActivityRegularization: [],
      layers.Dense: ['kernel'],
      layers.Dropout: [],
      layers.Flatten: [],
      layers.Lambda: [],
      layers.Masking: [],
      layers.Permute: [],
      layers.RepeatVector: [],
      layers.Reshape: [],
      layers.SpatialDropout1D: [],
      layers.SpatialDropout2D: [],
      layers.SpatialDropout3D: [],
      layers.Embedding: ['embeddings'],
      layers.LocallyConnected1D: ['kernel'],
      layers.LocallyConnected2D: ['kernel'],
      layers.Add: [],
      layers.Average: [],
      layers.Concatenate: [],
      layers.Dot: [],
      layers.Maximum: [],
      layers.Minimum: [],
      layers.Multiply: [],
      layers.Subtract: [],
      layers.AlphaDropout: [],
      layers.GaussianDropout: [],
      layers.GaussianNoise: [],
      layers.BatchNormalization: [],
      layers.LayerNormalization: [],
      layers.AveragePooling1D: [],
      layers.AveragePooling2D: [],
      layers.AveragePooling3D: [],
      layers.GlobalAveragePooling1D: [],
      layers.GlobalAveragePooling2D: [],
      layers.GlobalAveragePooling3D: [],
      layers.GlobalMaxPooling1D: [],
      layers.GlobalMaxPooling2D: [],
      layers.GlobalMaxPooling3D: [],
      layers.MaxPooling1D: [],
      layers.MaxPooling2D: [],
      layers.MaxPooling3D: [],
      layers.MultiHeadAttention: [
          '_query_dense.kernel', '_key_dense.kernel', '_value_dense.kernel',
          '_output_dense.kernel'
      ],
      layers.experimental.SyncBatchNormalization: [],
      layers.experimental.preprocessing.Rescaling.__class__: [],
      TensorFlowOpLayer: [],
      layers_compat_v1.BatchNormalization: [],
  }


Reposted from blog.csdn.net/aa2962985/article/details/125361923