Network Slimming-神经网络剪枝的精细控制实现¶
本文写于2022年1月17日16时
本文介绍如何复现网络剪枝中的一篇经典的文章Learning Efficient Convolutional Networks Through Network Slimming 该文章提出了一种channel-level的裁剪方案,可以通过稀疏化尺度因子(BN层的scaling factor)来裁掉“不重要”的channel。即通过BN层的scale factor权重的大小来反应其对应的channel的重要性,然后按照scale factor绝对值的大小排列后百分比设定一个阈值(比如剪掉30%,那么就从小到大排列,取排在30%位置的数作为阈值),剪掉不那么重要的输出通道所对应的权重,从而实现结构化剪枝。 为了保证scale factor的稀疏性,作者提出给其加上l1 norm的正则得到稀疏解。 思考:输出通道删除后,还需要调整什么呢? 还需要调整下一层的输入,只保留与其上一层保留输出对应的部分,所以对channel的剪枝影响两层,即当前层的输出channel和下一层的输入channel的剪枝。 在复现之前,不妨先讨论一个比较深层的问题。剪枝的本质真的是筛选重要的权重吗? 个人感觉并不是,在使用该文章的方法做实验的时候,误打误撞没有设置l1正则,然后训练得到的结果与设置l1正则得到的结果没啥区别,也尝试过设置过不同l1正则系数,但是得到的结论相同。当l1系数设置太大的时候反而会导致还不如不加。 Rethinking the Value of Network Pruning指出剪枝的本质并不应该是选择重要的权重,而是确定权重的数量,在此基础上,从零开始训练也可以达到原来的性能。所以剪枝是神经网络结构搜索领域的子一个子任务,自然可以采用神经网络结构搜索相关的方法来做。 然而NAS却不是“一般玩家”可以做的,因此针对小模型和小型任务,使用剪枝得到一个更加紧凑的结构,我认为也是较为适合的一个方案。
模型定义和训练过程¶
接下来就说下如何使用pytorch实现吧 我们对通过bn层衡量其对应算子,此处假设为卷积层,那需要
- bn层对象与conv对象的连接吗,目的是可以通过bn找到其对应conv层。
- 为了更改conv层的结构,那么我们还需要得到conv层的双亲对象节点。
- 为了更改获得下一层或者上一层conv的结构,我们还需要建立conv层与conv层之间的联系。
通过以上分析,我们可以设定这样一个Conv聚合层 实现如下功能:
- 里面的bn可以通过一定方式访问父节点
- bn可以通过一定方式访问其对应conv
- 可以通过一定方式访问其上一个Conv和下一个Conv
class ModuleWrapper:
def __init__(self, module) -> None:
self.module = module
def autopad(k, p=None):
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, inplace=True): # ch_in, ch_out, kernel, stride, padding, groups
super(Conv, self).__init__()
self._conv_configs = {
'in_channels': c1,
'out_channels': c2,
'kernel_size': k,
'stride': s,
'padding': autopad(k, p),
'groups': g,
'bias': False
}
self.conv = nn.Conv2d(**self._conv_configs)
self.bn = nn.BatchNorm2d(c2)
self.bn.conv = ModuleWrapper(self.conv)
# 标记要不要对当前卷积的输出进行裁剪
self.bn.is_pruned = False
self.bn.last_bn = ModuleWrapper(None)
self.bn.next_bn = ModuleWrapper(None)
self.bn.parent = ModuleWrapper(self)
self.act = nn.ReLU(inplace=inplace) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def set_bn(self, bn):
ori_bn = self.bn
bn.conv = ori_bn.conv
bn.is_pruned = ori_bn.is_pruned
bn.last_bn = ori_bn.last_bn
bn.next_bn = ori_bn.next_bn
bn.parent = ori_bn.parent
self.bn = bn
def set_conv(self, conv):
self.conv = conv
def set_pruned(self, is_pruned):
self.bn.is_pruned = is_pruned
def set_last_bn(self, last_bn):
self.bn.last_bn.module = last_bn
def set_next_bn(self, next_bn):
self.bn.next_bn.module = next_bn
def set_last_conv(self, conv):
self.bn.last_bn.module = conv.bn
def set_next_conv(self, conv):
self.bn.next_bn.module = conv.bn
def get_conv_config_params(self):
return self._conv_configs
以上代码是将bn作为双向链表中的一个节点,通过ModuleWrapper
包裹需要访问的结点,这么做的目的是为了降低多个节点的耦合,因为nn.Module
类重写了__setitem__
和__getitem__
方法,会将nn.Module
实例放入一个专门的dict中,将其视为其子Module,双亲节点training状态的变化会影响其子节点的变换,显然这不是我们想要的,所以我们使用ModuleWrapper
包括nn.Module
为一个普通对象,尽管这么做不够优雅。
我们设定了bn的如下属性
# 标记要不要对当前卷积的输出进行裁剪
self.bn.is_pruned = False
self.bn.last_bn = ModuleWrapper(None)
self.bn.next_bn = ModuleWrapper(None)
self.bn.parent = ModuleWrapper(self)
is_pruned
当前bn层是否需要裁减last_bn
与表示当前bn层所属的Conv层连接的上一个Conv层的bn层next_bn
同理parent
bn层的parent,即双亲节点,指向Conv层- 设定conv和bn方法:
set_conv
、set_bn
- 获得卷积配置参数:
get_conv_config_params
我们可以通过如下方式访问各个节点
# 执行到此表示已经构造好了连接关系... 假设 bn 对象 为遍历的时候的 module
# 访问bn的父亲节点
Conv = module.parent.module
# 访问bn所属的conv
conv_node = Conv.conv
# 访问上一层bn
last_bn = module.last_bn.module
# 上文上一层conv
last_conv = last_bn.parent.module.conv
通过构造上面的链表和树的形式,我们可以很方便的获得bn层相关的节点。 不过在构造模型时需要我们自己构造出这种连接关系来 下面举一个例子
class Block(nn.Sequential):
def __init__(self, c1, c2, n=3):
super(Block, self).__init__(*[
Conv(c1,c2,k=3,s=1) for _ in range(n)])
childrens = list(self.children())
if n >= 2:
for i in range(n - 1):
childrens[i].set_next_conv(childrens[i + 1])
for i in range(1, n):
childrens[i].set_last_conv(childrens[i - 1])
通过上面例子我们可以知道,我们每写一个模块,就需要需要将这种连接关系构造清楚,此外模块与模块之间的连接关系也要构造清除,这里就不再展示
在训练之前,我们还需要做一些准备,我们需要设定bn层的is_pruned
状态来标记当前bn是否需要被剪枝,以方便后续加正则loss。
这里我直接给出需要is_pruned=True
的条件,即设定了next_bn.module
不为空即可,这样就可以保证我们对当前节点的输出channel和下一个节点的输入channel的修改的访问过程不会失败。
pruned_dict = set()
for name, module in model.named_modules():
if isinstance(module, nn.BatchNorm2d) and hasattr(module, 'is_pruned'):
if module.next_bn.module is not None:
module.is_pruned = True
pruned_dict.add(name)
print('module', name, 'is_pruned == True')
至此我们在训练的时候就可以使用如下方式加正则项
l1_norm_loss = 0
lambda = 5e-6
for name,module in model.named_modules():
if name in pruned_dict:
l1_norm_loss += lambda* torch.abs(module.weight).sum()
loss += l1_norm_loss
这里我直接给出了多卡下处理的方案,即将module的name放入set
中,单卡下也可以直接判断is_pruned
属性,因为多卡下训练会将BatchNorm
转成SyncBatchNorm
,我们自己设计的属性会丢失,只能训练完成后再将权重导出到单卡下定义的model中。
模型剪枝过程¶
剪枝的过程:
- 获取所有需要剪枝的bn的weight
- 对weight排序,确定阈值
- 构造一个待剪枝的新模型,结构与内部命名方式与旧模型完全一样(这里我们在函数外面构造好,直接传入)
-
遍历新模型所有需要剪枝的bn层开始剪枝
-
替换当前层的conv
- 确定输出channel的mask
- 判断当前层的输入channel是否需要被剪枝,并确定mask
- 构造新的conv,替换原来的conv
- 构造新的bn,替换原来的bn
- 处理下一个module的输入
综上,我们给出剪枝的代码实现
import torch
import torch.nn as nn
from copy import deepcopy
def prune_model(model, new_model, percent = 0.4):
total = 0
for module in model.modules():
if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
total += module.weight.data.shape[0]
print('total',total)
total_bn_weights = torch.zeros(total)
index = 0
for module in model.modules():
if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
size = module.weight.data.shape[0]
total_bn_weights[index:(index+size)] = module.weight.detach().data.abs()
index += size
topk_values,_ = torch.topk(total_bn_weights ,k = int(total * percent), sorted=True, largest=False)
threhold = topk_values[-1]
pruned = 0
reserved_channels = dict()
channel_masks = dict()
for name, module in model.named_modules():
if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
weight_copy = module.weight.detach().data.abs()
mask = weight_copy.gt(threhold).float()
pruned += mask.shape[0] - torch.sum(mask).item()
# module.weight.data.mul_(mask)
# module.bias.data.mul_(mask)
reserved_channels[name] = torch.sum(mask).item()
channel_masks[name] = mask.clone()
# new_model = deepcopy(model)
new_model.load_state_dict(model.state_dict())
# 这里是为了根据name访问旧模型中的参数
def get_module(model, name):
tokens = name.split('.')
sub_tokens = tokens
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)
return cur_mod
for name, module in new_model.named_modules():
module.name = name
if name in reserved_channels:
module.is_pruned = True
for name, module in new_model.named_modules():
if name in reserved_channels:
reserved_channel = reserved_channels[name]
module_parent = module.parent.module
mask = channel_masks[name].bool()
# 只关注当前module
conv_config_params = module_parent.get_conv_config_params()
# 判断是否精简输入channel
new_in_channels = conv_config_params['in_channels']
last_bn_weight_mask = torch.ones(new_in_channels).to(torch.bool)
if module.last_bn.module is not None and getattr(module.last_bn.module, 'is_pruned', False):
last_bn_weight_mask = channel_masks[module.last_bn.module.name]
new_in_channels = int(torch.sum(last_bn_weight_mask))
last_bn_weight_mask = last_bn_weight_mask.bool()
new_output_channels = int(reserved_channel)
new_configs = deepcopy(conv_config_params)
new_configs['in_channels'] = new_in_channels
new_configs['out_channels'] = new_output_channels
new_conv = nn.Conv2d(**new_configs)
ori_bn_parent = get_module(model,name).parent.module
ori_weight = ori_bn_parent.conv.weight.detach().data
ori_weight = ori_weight[mask,:,:,:]
ori_weight = ori_weight[:,last_bn_weight_mask,:,:]
new_conv.weight.data.copy_(ori_weight)
if new_configs['bias']:
ori_bias = module_parent.conv.bias.detach().data
new_conv.bias.data.copy_(ori_bias[mask])
module_parent.set_conv(new_conv)
# 更改bn
new_bn = nn.BatchNorm2d(new_output_channels, eps=module.eps, momentum=module.momentum, affine=module.affine, track_running_stats=module.track_running_stats)
new_bn.weight.data.copy_(module.weight.data[mask])
new_bn.bias.data.copy_(module.bias.data[mask])
if new_bn.track_running_stats:
new_bn.running_mean.data.copy_(module.running_mean.data[mask])
new_bn.running_var.data.copy_(module.running_var.data[mask])
new_bn.num_batches_tracked.data.copy_(module.num_batches_tracked.data)
module_parent.set_bn(new_bn)
# 如果下一个module没有剪枝,那么需要更改处理下一个modude的输入
if module.next_bn.module is not None and getattr(module.next_bn.module, 'is_pruned', True) == False:
# 更改下一个module的输入
next_module_parent = module.next_bn.module.parent.module
next_conv_config_params = deepcopy(next_module_parent.get_conv_config_params())
next_conv_config_params['in_channels'] = new_output_channels
new_conv = nn.Conv2d(**next_conv_config_params)
ori_conv= get_module(model,next_module_parent.conv.name)
ori_weight = ori_conv.weight.detach().data
ori_weight = ori_weight[:,mask,:,:]
new_conv.weight.data.copy_(ori_weight)
next_module_parent.set_conv(new_conv)
return new_model
结语¶
网络剪枝是一个在发展的研究方向,有很多经典的文章值得去探索。剪枝的自动化过程以及其真正的落地中是否会用到剪枝来做处理呢?这或许要未来的我来给自己一个答案。 或许,torch官方会出一个工具,以一种直接更改计算图的方式来实现,就像量化和nihui大佬的PNNX那样。