跳转至

Network Slimming-神经网络剪枝的精细控制实现

本文介绍如何复现网络剪枝中的一篇经典的文章Learning Efficient Convolutional Networks Through Network Slimming 该文章提出了一种channel-level的裁剪方案,可以通过稀疏化尺度因子(BN层的scaling factor)来裁掉“不重要”的channel。即通过BN层的scale factor权重的大小来反应其对应的channel的重要性,然后按照scale factor绝对值的大小排列后百分比设定一个阈值(比如剪掉30%,那么就从小到大排列,取排在30%位置的数作为阈值),剪掉不那么重要的输出通道所对应的权重,从而实现结构化剪枝。 为了保证scale factor的稀疏性,作者提出给其加上l1 norm的正则得到稀疏解。 思考:输出通道删除后,还需要调整什么呢? 还需要调整下一层的输入,只保留与其上一层保留输出对应的部分,所以对channel的剪枝影响两层,即当前层的输出channel和下一层的输入channel的剪枝。 在复现之前,不妨先讨论一个比较深层的问题。剪枝的本质真的是筛选重要的权重吗? 个人感觉并不是,在使用该文章的方法做实验的时候,误打误撞没有设置l1正则,然后训练得到的结果与设置l1正则得到的结果没啥区别,也尝试过设置过不同l1正则系数,但是得到的结论相同。当l1系数设置太大的时候反而会导致还不如不加。 Rethinking the Value of Network Pruning指出剪枝的本质并不应该是选择重要的权重,而是确定权重的数量,在此基础上,从零开始训练也可以达到原来的性能。所以剪枝是神经网络结构搜索领域的子一个子任务,自然可以采用神经网络结构搜索相关的方法来做。 然而NAS却不是“一般玩家”可以做的,因此针对小模型和小型任务,使用剪枝得到一个更加紧凑的结构,我认为也是较为适合的一个方案。

模型定义和训练过程

接下来就说下如何使用pytorch实现吧 我们对通过bn层衡量其对应算子,此处假设为卷积层,那需要

  1. bn层对象与conv对象的连接吗,目的是可以通过bn找到其对应conv层。
  2. 为了更改conv层的结构,那么我们还需要得到conv层的双亲对象节点。
  3. 为了更改获得下一层或者上一层conv的结构,我们还需要建立conv层与conv层之间的联系。

通过以上分析,我们可以设定这样一个Conv聚合层 实现如下功能:

  1. 里面的bn可以通过一定方式访问父节点
  2. bn可以通过一定方式访问其对应conv
  3. 可以通过一定方式访问其上一个Conv和下一个Conv

class ModuleWrapper:
    def __init__(self, module) -> None:
        self.module = module
def autopad(k, p=None):
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, inplace=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self._conv_configs = {
            'in_channels': c1,
            'out_channels': c2,
            'kernel_size': k,
            'stride': s,
            'padding': autopad(k, p),
            'groups': g,
            'bias': False
        }
        self.conv = nn.Conv2d(**self._conv_configs)
        self.bn = nn.BatchNorm2d(c2)
        self.bn.conv = ModuleWrapper(self.conv)
        # 标记要不要对当前卷积的输出进行裁剪
        self.bn.is_pruned = False
        self.bn.last_bn = ModuleWrapper(None)
        self.bn.next_bn = ModuleWrapper(None)
        self.bn.parent = ModuleWrapper(self)
        self.act = nn.ReLU(inplace=inplace) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
    def set_bn(self, bn):
        ori_bn = self.bn
        bn.conv = ori_bn.conv
        bn.is_pruned = ori_bn.is_pruned
        bn.last_bn = ori_bn.last_bn
        bn.next_bn = ori_bn.next_bn
        bn.parent = ori_bn.parent
        self.bn = bn
    def set_conv(self, conv):
        self.conv = conv
    def set_pruned(self, is_pruned):
        self.bn.is_pruned = is_pruned
    def set_last_bn(self, last_bn):
        self.bn.last_bn.module = last_bn
    def set_next_bn(self, next_bn):
        self.bn.next_bn.module = next_bn
    def set_last_conv(self, conv):
        self.bn.last_bn.module = conv.bn
    def set_next_conv(self, conv):
        self.bn.next_bn.module = conv.bn
    def get_conv_config_params(self):
        return self._conv_configs
以上代码是将bn作为双向链表中的一个节点,通过ModuleWrapper包裹需要访问的结点,这么做的目的是为了降低多个节点的耦合,因为nn.Module类重写了__setitem____getitem__方法,会将nn.Module实例放入一个专门的dict中,将其视为其子Module,双亲节点training状态的变化会影响其子节点的变换,显然这不是我们想要的,所以我们使用ModuleWrapper包括nn.Module为一个普通对象,尽管这么做不够优雅。 我们设定了bn的如下属性
1
2
3
4
5
# 标记要不要对当前卷积的输出进行裁剪
self.bn.is_pruned = False
self.bn.last_bn = ModuleWrapper(None)
self.bn.next_bn = ModuleWrapper(None)
self.bn.parent = ModuleWrapper(self)

  1. is_pruned当前bn层是否需要裁减
  2. last_bn 与表示当前bn层所属的Conv层连接的上一个Conv层的bn层
  3. next_bn 同理
  4. parent bn层的parent,即双亲节点,指向Conv层
  5. 设定conv和bn方法:set_convset_bn
  6. 获得卷积配置参数:get_conv_config_params

我们可以通过如下方式访问各个节点

1
2
3
4
5
6
7
8
9
# 执行到此表示已经构造好了连接关系... 假设 bn 对象 为遍历的时候的 module
# 访问bn的父亲节点
Conv = module.parent.module
# 访问bn所属的conv
conv_node = Conv.conv
# 访问上一层bn
last_bn = module.last_bn.module
# 上文上一层conv
last_conv = last_bn.parent.module.conv
通过构造上面的链表和树的形式,我们可以很方便的获得bn层相关的节点。 不过在构造模型时需要我们自己构造出这种连接关系来 下面举一个例子
class Block(nn.Sequential):
    def __init__(self, c1, c2, n=3):
        super(Block, self).__init__(*[
            Conv(c1,c2,k=3,s=1)  for _ in range(n)])
        childrens = list(self.children())
        if n >= 2:
            for i in range(n - 1):

                childrens[i].set_next_conv(childrens[i + 1])
            for i in range(1, n):
                childrens[i].set_last_conv(childrens[i - 1])
通过上面例子我们可以知道,我们每写一个模块,就需要需要将这种连接关系构造清楚,此外模块与模块之间的连接关系也要构造清除,这里就不再展示 在训练之前,我们还需要做一些准备,我们需要设定bn层的is_pruned状态来标记当前bn是否需要被剪枝,以方便后续加正则loss。 这里我直接给出需要is_pruned=True的条件,即设定了next_bn.module不为空即可,这样就可以保证我们对当前节点的输出channel和下一个节点的输入channel的修改的访问过程不会失败。
1
2
3
4
5
6
7
pruned_dict = set()
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d) and hasattr(module, 'is_pruned'):
        if module.next_bn.module is not None:
            module.is_pruned = True
            pruned_dict.add(name)
            print('module', name, 'is_pruned == True')
至此我们在训练的时候就可以使用如下方式加正则项
1
2
3
4
5
6
l1_norm_loss = 0
lambda = 5e-6
for name,module in model.named_modules():
    if name in pruned_dict:
        l1_norm_loss += lambda* torch.abs(module.weight).sum()
loss += l1_norm_loss
这里我直接给出了多卡下处理的方案,即将module的name放入set中,单卡下也可以直接判断is_pruned属性,因为多卡下训练会将BatchNorm转成SyncBatchNorm,我们自己设计的属性会丢失,只能训练完成后再将权重导出到单卡下定义的model中。

模型剪枝过程

剪枝的过程:

  1. 获取所有需要剪枝的bn的weight
  2. 对weight排序,确定阈值
  3. 构造一个待剪枝的新模型,结构与内部命名方式与旧模型完全一样(这里我们在函数外面构造好,直接传入)
  4. 遍历新模型所有需要剪枝的bn层开始剪枝

    1. 替换当前层的conv

    2. 确定输出channel的mask

    3. 判断当前层的输入channel是否需要被剪枝,并确定mask

    4. 构造新的conv,替换原来的conv

    5. 构造新的bn,替换原来的bn

    6. 处理下一个module的输入

综上,我们给出剪枝的代码实现

import torch
import torch.nn as nn
from copy import deepcopy
def prune_model(model, new_model, percent = 0.4):
    total = 0
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
            total += module.weight.data.shape[0]
    print('total',total)
    total_bn_weights = torch.zeros(total)
    index = 0
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
            size = module.weight.data.shape[0]
            total_bn_weights[index:(index+size)] = module.weight.detach().data.abs()
            index += size
    topk_values,_ = torch.topk(total_bn_weights ,k = int(total * percent), sorted=True, largest=False)
    threhold = topk_values[-1]
    pruned = 0
    reserved_channels = dict()
    channel_masks = dict()
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and getattr(module, 'is_pruned', False):
            weight_copy = module.weight.detach().data.abs()
            mask = weight_copy.gt(threhold).float()
            pruned += mask.shape[0] - torch.sum(mask).item()

#             module.weight.data.mul_(mask)
#             module.bias.data.mul_(mask)
            reserved_channels[name] = torch.sum(mask).item()
            channel_masks[name] = mask.clone()
    # new_model = deepcopy(model)
    new_model.load_state_dict(model.state_dict())
    # 这里是为了根据name访问旧模型中的参数
    def get_module(model, name):
        tokens = name.split('.')
        sub_tokens = tokens
        cur_mod = model
        for s in sub_tokens:
            cur_mod = getattr(cur_mod, s)
        return cur_mod
    for name, module in new_model.named_modules():
        module.name = name
        if name in reserved_channels:
            module.is_pruned = True
    for name, module in new_model.named_modules():
        if name in reserved_channels:
            reserved_channel =  reserved_channels[name]
            module_parent = module.parent.module
            mask = channel_masks[name].bool()
            # 只关注当前module
            conv_config_params = module_parent.get_conv_config_params()
            # 判断是否精简输入channel
            new_in_channels = conv_config_params['in_channels']
            last_bn_weight_mask = torch.ones(new_in_channels).to(torch.bool)
            if module.last_bn.module is not None and getattr(module.last_bn.module, 'is_pruned', False):
                last_bn_weight_mask = channel_masks[module.last_bn.module.name]
                new_in_channels = int(torch.sum(last_bn_weight_mask))
                last_bn_weight_mask = last_bn_weight_mask.bool()
            new_output_channels = int(reserved_channel)
            new_configs = deepcopy(conv_config_params)
            new_configs['in_channels'] = new_in_channels
            new_configs['out_channels'] = new_output_channels
            new_conv = nn.Conv2d(**new_configs)
            ori_bn_parent = get_module(model,name).parent.module
            ori_weight = ori_bn_parent.conv.weight.detach().data
            ori_weight = ori_weight[mask,:,:,:]
            ori_weight = ori_weight[:,last_bn_weight_mask,:,:]
            new_conv.weight.data.copy_(ori_weight)
            if new_configs['bias']:
                ori_bias = module_parent.conv.bias.detach().data
                new_conv.bias.data.copy_(ori_bias[mask])
            module_parent.set_conv(new_conv)
            # 更改bn
            new_bn = nn.BatchNorm2d(new_output_channels, eps=module.eps, momentum=module.momentum, affine=module.affine, track_running_stats=module.track_running_stats)
            new_bn.weight.data.copy_(module.weight.data[mask])
            new_bn.bias.data.copy_(module.bias.data[mask])
            if new_bn.track_running_stats:
                new_bn.running_mean.data.copy_(module.running_mean.data[mask])
                new_bn.running_var.data.copy_(module.running_var.data[mask])
                new_bn.num_batches_tracked.data.copy_(module.num_batches_tracked.data)
            module_parent.set_bn(new_bn)
            # 如果下一个module没有剪枝,那么需要更改处理下一个modude的输入
            if module.next_bn.module is not None and getattr(module.next_bn.module, 'is_pruned', True) == False:
                # 更改下一个module的输入
                next_module_parent = module.next_bn.module.parent.module
                next_conv_config_params = deepcopy(next_module_parent.get_conv_config_params())
                next_conv_config_params['in_channels'] = new_output_channels
                new_conv = nn.Conv2d(**next_conv_config_params)
                ori_conv= get_module(model,next_module_parent.conv.name)
                ori_weight = ori_conv.weight.detach().data
                ori_weight = ori_weight[:,mask,:,:]
                new_conv.weight.data.copy_(ori_weight)
                next_module_parent.set_conv(new_conv)
    return new_model

结语

网络剪枝是一个在发展的研究方向,有很多经典的文章值得去探索。剪枝的自动化过程以及其真正的落地中是否会用到剪枝来做处理呢?这或许要未来的我来给自己一个答案。 或许,torch官方会出一个工具,以一种直接更改计算图的方式来实现,就像量化和nihui大佬的PNNX那样。


最后更新: March 21, 2024
创建日期: March 21, 2024