Network Slimming-神经网络剪枝的精细控制实现
本文介绍如何复现网络剪枝中的一篇经典的文章Learning Efficient Convolutional Networks Through Network Slimming 该文章提出了一种channel-level的裁剪方案,可以通过稀疏化尺度因子(BN层的scaling factor)来裁掉“不重要”的channel。即通过BN层的scale factor权重的大小来反应其对应的channel的重要性,然后按照scale factor绝对值的大小排列后百分比设定一个阈值(比如剪掉30%,那么就从小到大排列,取排在30%位置的数作为阈值),剪掉不那么重要的输出通道所对应的权重,从而实现结构化剪枝。 为了保证scale factor的稀疏性,作者提出给其加上l1 norm的正则得到稀疏解。 思考:输出通道删除后,还需要调整什么呢? 还需要调整下一层的输入,只保留与其上一层保留输出对应的部分,所以对channel的剪枝影响两层,即当前层的输出channel和下一层的输入channel的剪枝。 在复现之前,不妨先讨论一个比较深层的问题。剪枝的本质真的是筛选重要的权重吗? 个人感觉并不是,在使用该文章的方法做实验的时候,误打误撞没有设置l1正则,然后训练得到的结果与设置l1正则得到的结果没啥区别,也尝试过设置过不同l1正则系数,但是得到的结论相同。当l1系数设置太大的时候反而会导致还不如不加。 Rethinking the Value of Network Pruning指出剪枝的本质并不应该是选择重要的权重,而是确定权重的数量,在此基础上,从零开始训练也可以达到原来的性能。所以剪枝是神经网络结构搜索领域的子一个子任务,自然可以采用神经网络结构搜索相关的方法来做。 然而NAS却不是“一般玩家”可以做的,因此针对小模型和小型任务,使用剪枝得到一个更加紧凑的结构,我认为也是较为适合的一个方案。
模型定义和训练过程
接下来就说下如何使用pytorch实现吧 我们对通过bn层衡量其对应算子,此处假设为卷积层,那需要
- bn层对象与conv对象的连接吗,目的是可以通过bn找到其对应conv层。
- 为了更改conv层的结构,那么我们还需要得到conv层的双亲对象节点。
- 为了更改获得下一层或者上一层conv的结构,我们还需要建立conv层与conv层之间的联系。
通过以上分析,我们可以设定这样一个Conv聚合层 实现如下功能:
- 里面的bn可以通过一定方式访问父节点
- bn可以通过一定方式访问其对应conv
- 可以通过一定方式访问其上一个Conv和下一个Conv
ModuleWrapper
包裹需要访问的结点,这么做的目的是为了降低多个节点的耦合,因为nn.Module
类重写了__setitem__
和__getitem__
方法,会将nn.Module
实例放入一个专门的dict中,将其视为其子Module,双亲节点training状态的变化会影响其子节点的变换,显然这不是我们想要的,所以我们使用ModuleWrapper
包括nn.Module
为一个普通对象,尽管这么做不够优雅。
我们设定了bn的如下属性
is_pruned
当前bn层是否需要裁减last_bn
与表示当前bn层所属的Conv层连接的上一个Conv层的bn层next_bn
同理parent
bn层的parent,即双亲节点,指向Conv层- 设定conv和bn方法:
set_conv
、set_bn
- 获得卷积配置参数:
get_conv_config_params
我们可以通过如下方式访问各个节点
is_pruned
状态来标记当前bn是否需要被剪枝,以方便后续加正则loss。
这里我直接给出需要is_pruned=True
的条件,即设定了next_bn.module
不为空即可,这样就可以保证我们对当前节点的输出channel和下一个节点的输入channel的修改的访问过程不会失败。
set
中,单卡下也可以直接判断is_pruned
属性,因为多卡下训练会将BatchNorm
转成SyncBatchNorm
,我们自己设计的属性会丢失,只能训练完成后再将权重导出到单卡下定义的model中。
模型剪枝过程
剪枝的过程:
- 获取所有需要剪枝的bn的weight
- 对weight排序,确定阈值
- 构造一个待剪枝的新模型,结构与内部命名方式与旧模型完全一样(这里我们在函数外面构造好,直接传入)
-
遍历新模型所有需要剪枝的bn层开始剪枝
-
替换当前层的conv
-
确定输出channel的mask
-
判断当前层的输入channel是否需要被剪枝,并确定mask
-
构造新的conv,替换原来的conv
-
构造新的bn,替换原来的bn
-
处理下一个module的输入
-
综上,我们给出剪枝的代码实现
结语
网络剪枝是一个在发展的研究方向,有很多经典的文章值得去探索。剪枝的自动化过程以及其真正的落地中是否会用到剪枝来做处理呢?这或许要未来的我来给自己一个答案。 或许,torch官方会出一个工具,以一种直接更改计算图的方式来实现,就像量化和nihui大佬的PNNX那样。
创建日期: September 17, 2024