【PyTorch模型剪枝实例教程3(多参数与全局剪枝)】


目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要。

稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。

本文将通过学习官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。

相关链接

  • 深度学习模型压缩与加速技术(一):参数剪枝

  • PyTorch模型剪枝实例教程一、非结构化剪枝

  • PyTorch模型剪枝实例教程二、结构化剪枝

  • PyTorch模型剪枝实例教程三、多参数与全局剪枝

通过教程一和教程二,我们可以了解如何通过PyTorch进行非结构化和结构化的剪枝,一般而言,我们会考虑将较深的网络进行参数剪枝,此时,通过一个个检查模块诶个给它们剪枝就比较麻烦,我们可以利用多参数和全局剪枝的方法对同类型参数进行剪枝。

一.导包&定义一个简单的网络

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''搭建类LeNet网络'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 单通道图像输入,5×5核尺寸
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

二.多参数剪枝

new_model = LeNet()
for name, module in new_model.named_modules():
    # 对所有Conv2d的参数进行20%的L1非结构化剪枝
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # 对所有Linear的参数进行20%的L1非结构化剪枝
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())  # 验证一下下

输出:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

三.全局剪枝

前面所有提到的方法,都是局部剪枝方法,我们还可以使用全局剪枝方法,通过删除整个模型最低的20%的连接,而非删除每个层中最低20%的连接,也就是说,可能会出现层与层之间删除的百分比不一样的情况。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

print(
    "稀疏性 in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "稀疏性 in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "稀疏性 in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "稀疏性 in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "稀疏性 in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "全局稀疏性: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

输出

稀疏性 in conv1.weight: 8.00%
稀疏性 in conv2.weight: 9.33%
稀疏性 in fc1.weight: 22.07%
稀疏性 in fc2.weight: 12.20%
稀疏性 in fc3.weight: 11.31%

全局稀疏性: 20.00%

四.总结

本示例首先搭建了一个类LeNet网络模型,为了进行多参数剪枝,我们使用.named_modules()遍历了所有层,并利用isinstance()方法判断是否为Conv2d或Linear结构,以此来对相同结构参数进行同等类型剪枝。为了进行全局剪枝,我们使用了 .global_unstructured参数进行剪枝,可以发现,全局剪枝与多参数剪枝不一样的地方在于,全局剪枝最终的稀疏性虽然和多参数剪枝稀疏性相同,但全局剪枝稀疏性并非对每层均等稀疏的。

本文用到的核心函数方法:

.named_modules(),获取模型的参数名和结构
isinstance(),判断类型是否一致
.global_unstructured,全局剪枝方法
参考:

Torch官方剪枝教程

猜你喜欢

转载自blog.csdn.net/weixin_42483745/article/details/125035007