跳转至

模型压缩框架nncf模型量化中QAT量化参数的梯度推导

NNCF(https://github.com/openvinotoolkit/nncf)是英特尔为自家模型推理框架openvino推出的一款模型压缩框架,支持Quantization、Binarization、Sparsity和Filter pruning。相关的论文可以在https://arxiv.org/abs/2002.08679下载。做项目时需要用到在英特尔CPU上的量化功能。在看该仓库的量化源码的时候,量化参数的求导困扰了我好多天,一直和伪量化的前向传播对不上,最终发现是作者的变量名误导了我,下面给出在nncf中量化的代码和公式推导过程。 这里简单起见,以对称量化为例,来看NNCF中的QAT量化原理 打开git仓库,找到nncf/torch/extensions/src/quantization/cpu/functions_cpu.cpp,可以看到伪量化的前向传播和反向传播方式。 这里首先贴出伪量化的前向传播的代码

template <typename scalar_t>
at::Tensor q_cpu_forward(
        at::Tensor input,
        at::Tensor input_low,
        at::Tensor input_range,
        scalar_t levels) {
    at::Tensor s = (levels - 1) / input_range;

    auto output = at::max(at::min(input, input_low + input_range), input_low);
    output -= input_low;
    output *= s;
    output = output.round_();
    output = output.div_(s);
    output += input_low;
    return output;
}
同时通过底部的代码
1
2
3
4
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("Quantize_forward", &q_forward, "Quantize forward");
  m.def("Quantize_backward", &q_backward, "Quantize backward");
}
搜索Quantize_forward可以追溯到nncf/nncf/torch/quantization/quantize_functions.py文件,这里找到了对称量化的torch function定义
class QuantizeSymmetric(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, scale, level_low, level_high, levels):
        input_low = scale * (level_low / level_high)
        input_range = scale - input_low

        if input_.is_cuda:
            if not input_.is_contiguous():
                warnings.warn("input_ is not contiguous!", RuntimeWarning)
                input_ = input_.contiguous()
            output = QuantizedFunctionsCUDA.Quantize_forward(input_, input_low, input_range, levels)
        else:
            output = QuantizedFunctionsCPU.Quantize_forward(input_, input_low, input_range, levels)
        ctx.save_for_backward(input_, input_low, input_range)
        ctx.levels = levels
        ctx.level_low = level_low
        ctx.level_high = level_high
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input_, input_low, input_range = ctx.saved_tensors
        levels = ctx.levels
        level_low = ctx.level_low
        level_high = ctx.level_high
        if grad_output.is_cuda:
            if not grad_output.is_contiguous():
                warnings.warn("grad_output is not contiguous!", RuntimeWarning)
                grad_output = grad_output.contiguous()
            grad_input, _, grad_scale = QuantizedFunctionsCUDA.Quantize_backward(
                grad_output, input_, input_low, input_range, levels, level_low, level_high
            )
        else:
            grad_input, _, grad_scale = QuantizedFunctionsCPU.Quantize_backward(
                grad_output, input_, input_low, input_range, levels, level_low, level_high, False
            )
        return grad_input, grad_scale, None, None, None
综合上面的forward函数和c++的代码,可以写出前向传播的公式

\[ s = \frac{levels - 1}{q_{range}} \]
\[ q_{low} = scale * \frac{level_{low}}{level_{high}} \]
\[ q_{range} = scale - q_{low} \]
\[ q_{high} = q_{low} + q_{range} \]
\[ \operatorname{fakequantize}( x,q_{low},q_{range} ) = \frac{round(s \cdot (\operatorname{clip}(x, q_{low}, q_{high}) - q_{low}))}{s} + q_{low} \]

下面贴出反向传播的代码

template <typename scalar_t>
std::vector<at::Tensor> q_cpu_backward(
        at::Tensor grad_output,
        at::Tensor input,
        at::Tensor input_low,
        at::Tensor input_range,
        scalar_t levels,
        scalar_t levels_low,
        scalar_t levels_high,
        bool is_asymmetric) {
    auto output = q_cpu_forward<scalar_t>(input, input_low, input_range, levels);
    auto reverted_range = 1 / input_range;
    scalar_t alpha = levels_low / levels_high;
    auto mask_hi = input.gt(input_low + input_range);
    auto mask_lo = input.lt(input_low);
    auto err = at::sub(output, input);
    err.mul_(reverted_range);
    err = err.masked_fill_(mask_hi, 1);
    err = err.masked_fill_(mask_lo, alpha);
    err = err.mul_(grad_output);
    auto grad_input_range = err;
    sum_like(grad_input_range, input_range);
    auto grad_input = grad_output.clone();
    auto outside_mask = mask_hi.add_(mask_lo);
    grad_input = grad_input.masked_fill_(outside_mask, 0);
    if (is_asymmetric) {
        auto grad_input_low = grad_output.clone();
        auto all_ones = torch::ones_like(outside_mask);
        grad_input_low = grad_input_low.masked_fill_(at::__xor__(all_ones, outside_mask), 0);
        sum_like(grad_input_low, input_low);
        return {grad_input, grad_input_low, grad_input_range};
    }
    auto dummy_variable = torch::autograd::make_variable(at::empty(input_low.sizes()), true);
    return {grad_input, dummy_variable, grad_input_range};
}
此时,我们先只关注对称的情况,上面代码中grad_input_range表示的并不是对input_range的偏导数,而是scale的偏导数,原因可以看下前面QuantizeSymmetric类的源代码
1
2
3
grad_input, _, grad_scale = QuantizedFunctionsCPU.Quantize_backward(
                grad_output, input_, input_low, input_range, levels, level_low, level_high, False
            )
从该行代码可以看出grad_input_range赋值给了grad_scale。也就是说grad_input_range代表scale的偏导数,那么反向传播的时候对grad_input_range就好解释了。这里给出伪量化的输出对输入xscale的公式

\[ output := fakequantize( x,q_{low},q_{range} ) \]
\[ \nabla_{x} fakequantize( x,q_{low},q_{range} ) = \begin{cases} 1&\text{ if } q_{low} \leq x \leq q_{high}, \\ 0&\text{ if } x < q_{low} \ or \ x > q_{high}. \end{cases} \]
\[ \nabla_{scale} fakequantize( x,q_{low},q_{range} ) = \begin{cases} \frac{(output-x)(1 - \frac{level_{low}}{level_{high}})}{q_{range}} &\\ \text{if}~q_{low} \leq x \leq q_{high} \\ \frac{level_{low}}{level_{high}} &\text { if } x < q_{low}, \\ 1 & \text { if } x > q_{high} .\end{cases} \]

levels表示量化区间可以表示的总个数,比如8位的情况下levels = 256。由于是对称量化,scale为唯一的量化参数,其余的量化参数比如\(q_{low}\)\(q_{range}\)\(q_{high}\)都可以由\(scale\)所表示。 \(level_{low}\)表示量化区间的最小值,\(level_{high}\)表示量化区间的最大值。\(round\)操作表示四舍五入取整(round)操作。\(fakequantize( x,q_{low},q_{range} )\)表示对输入x的线性伪量化函数。\(\nabla_{x} fakequantize( x,q_{low},q_{range} )\)\(\nabla_{scale} fakequantize( x,q_{low},q_{range} )\)表示伪量化函数对\(x\)\(scale\)的偏导数。 关于公式对scale求偏导的中第一个条件的推导过程如下,这里需要用到STE(straight through estimator)假设,即对round函数中自变量求偏导约等于把round函数去掉后的结果。

\[ \begin{align} \nabla_{s} fakequantize( x,q_{low},q_{range} ) &= -\frac{1}{s^2} round(s \cdot (x - q_{low})) + \frac{1}{s}(x - q_{low}) \\ &= \frac{1}{s}[(x - q_{low}) - (output - q_{low})] \\ &= \frac{1}{s}(x - output) \end{align} \]
\[ \begin{align} &\nabla_{q_{low}} fakequantize( x,q_{low},q_{range} ) = \frac{1}{s} (-s) + 1 = 0 \\ &\nabla_{q_{range}} s = - \frac{levels - 1}{{q_{range}}^2} = -s \cdot \frac{1}{q_{range}} \\ &\nabla_{scale} q_{range} = 1 - \frac{level_{low}}{level_{high}} \end{align} \]

整合以上公式得

\[ \begin{align} \nabla_{scale} fakequantize( x,q_{low},q_{range} ) &= \nabla_{s} output \cdot \nabla_{q_{range}} s \cdot \nabla_{scale} q_{range} + \nabla_{q_{low}} output \cdot \nabla_{scale} q_{low} \\ &= \frac{1}{s}(x - output) \cdot (-s \cdot \frac{1}{q_{range}}) \cdot (1 - \frac{level_{low}}{level_{high}}) \\ &= \frac{output - x}{q_{range}}(1 - \frac{level_{low}}{level_{high}}) \end{align} \]

在NNCF中代码没有\((1 - \frac{level_{low}}{level_{high}})\)这一常数项,可能是因为NNCF为了简化运算,对量化参数和模型参数使用不同的学习率将该常数项合并到学习率中。其他条件的证明相对简单,先对前向传播的公式化简就可以求得,这里就不在给出。

综上,NNCF对量化参数的求导依然采用LSQ和TQT的方法,只是变换了参数,这里需要采用多元函数微分的办法对\(scale\)进行求偏导。


最后更新: September 17, 2024
创建日期: September 17, 2024