跳转至

TensorRT Plugin的实现、调试与验证:以实现Layernorm为例

一、Plugin是什么

TensorRT的Plugin直译应该为插件,望文生义的解释是插入到作为一个计算节点插入到TensorRT构造的计算图中,所以Plugin本质上可以理解为一个计算节点,会以动态链接库的形式插入到网络模型中实现某些算子。

二、什么时候需要Plugin

为了实现以下几种需求时:

  • 实现TensorRT原生不支持的层或者结构
  • 替换TensorRT原本实现性能不够好的层或者结构
  • 手动合并没有自动融合的层或者结构

三、Plugin的特点

  • 自定义cuda kerkel,功能和性能实现完全自定义
  • Plugin实现的算子不可以与其他Layer之间融合,如果Plugin覆盖的结构中有参与算子融合的结构,那么此处算子融合会被破坏。
  • Plugin节点前后可能会插入reformatting节点,增加开销。

四、onnx计算节点与TensorRT Plugin之间的联系

源码地址: https://github.com/onnx/onnx-tensorrt/blob/main/builtin_op_importers.cpp#L4881

// Any ops that are not supported will attempt to import as plugins.
DEFINE_BUILTIN_OP_IMPORTER(FallbackPluginImporter)
{
    OnnxAttrs attrs(node, ctx);
    const std::string pluginName{node.op_type()};
    const std::string pluginVersion{attrs.get<std::string>("plugin_version", "1")};
    const std::string pluginNamespace{attrs.get<std::string>("plugin_namespace", "")};
    LOG_INFO("Searching for plugin: " << pluginName << ", plugin_version: " << pluginVersion << ", plugin_namespace: " << pluginNamespace);
    nvinfer1::IPluginCreator* creator = importPluginCreator(pluginName, pluginVersion, pluginNamespace);
    ASSERT(creator && "Plugin not found, are the plugin name, version, and namespace correct?", ErrorCode::kUNSUPPORTED_NODE);
    const nvinfer1::PluginFieldCollection* fieldNames = creator->getFieldNames();
    // Field data needs to be type erased, we use fieldData for temporary allocations.
    string_map<std::vector<uint8_t>> fieldData{};
    std::vector<nvinfer1::PluginField> fields = loadFields(fieldData, attrs, fieldNames, ctx);
    const auto plugin = createPlugin(getNodeName(node), creator, fields);
    ASSERT(plugin && "Could not create plugin", ErrorCode::kUNSUPPORTED_NODE);
    std::vector<nvinfer1::ITensor*> pluginInputs{};
    for (auto& input : inputs)
    {
        pluginInputs.emplace_back(&convertToTensor(input, ctx));
    }
    LOG_INFO("Successfully created plugin: " << pluginName);
    auto* layer = ctx->network()->addPluginV2(pluginInputs.data(), pluginInputs.size(), *plugin);
    ctx->registerLayer(layer, getNodeName(node));
    RETURN_ALL_OUTPUTS(layer);
}
通过看onnx-parser的源码可知,对于一个onnx中的计算节点,当解析器中其他的解析规则都不满足时会转向使用TensorRT Plugin构造,onnx中计算节点的参数会成为TensorRT Plugin中构造时的传参。按照源码里面的要求,假设我们想对一个onnx计算节点编写Plugin,那么Plugin的名字就要与onnx计算节点的名字相同,且版本为"1"

使用技巧: 假设我们想融合onnx中一些节点的操作,比如LayerNorm,那么就可以对Layernorm牵扯的节点的参数收集起来,然后删除这些节点,用新的名字为“Layernorm”的节点来替换这些节点,接着写TensorRT Plugin,名字为“Layernorm”,来实现上述的融合操作。

五、Plugin的实现流程

5.1 Plugin的种类

我们可以通过实现Plugin基类定义好的接口(虚函数)来自定义Plugin。 根据继承基类的种类,可以将Plugin的种类分为三种

Introduced in TensorRT version? Mixed I/O formats/types Dynamic shapes? Supports implicit/explicit batch mode?
IPluginV2Ext 5.1 Limited No Both implicit and explicit batch modes
IPluginV2IOExt 6.0.1 General No Both implicit and explicit batch modes
IPluginV2DynamicExt 6.0.1 General Yes Explicit batch mode only

上面的表格抄录自官方文档,作者写这篇文章时,TensorRT版本已经升到8.4,所以Plugin的实现通常就是继承表格中最后一行的IPluginV2DynamicExt类,也就是可以实现动态shape的。

5.2 Plugin的定义实现相关接口

Plugin的实现步骤:

  1. 继承IPluginV2DynamicExt类实现一个新的Plugin类
  2. 继承IPluginCreator类实现一个PluginCreator类
  3. 实现用于计算的CUDA C++ kernel
  4. 将Plugin编译为动态链接库.so
  5. c++/python代码中加载so文件

Plugin构建期和运行期要执行的交互(下面所述功能均反映在IPluginV2DynamicExt的接口定义上) 构建期

  • 用户或者解析器需要向Plugin传递构造参数和权重
  • Plugin需要报告其支持的输入输出张量信息(包括类型,形状、数量、数据分布等)
  • Plugin需要报告所需的workspace大小
  • TensorRT会尝试所有允许的组合,然后选择性能最佳的组合(可能会在Plugin前后插入reformat节点)
  • Plugin要实现序列化和反序列化接口

运行期

  • TensorRT为Plugin提供输入输出张量的所有信息,workspace地址还有cuda stream

注意:不要在运行期enqueue函数里面使用cudaMalloc等耗时的函数 Plugin需要实现的关键API

  • getOutputDimensions

根据输入张量形状,向 TensorRT 报告每个输出张量的形状

  • supportsFormatCombination

向TensorRT报告支持的数据类型和数据排布 Plugin 可以同时支持多种数据类型(DataType)和数据排布(LayerOut)的输入输出张量组合 TensorRT 深度优先遍历各个输入张量索引,尝试各种组合,该函数返回“Plugin 是否支持当前尝试的组合” 保证性能的前提下, 多实现一些 format 的组合, 消除 TenosrRT 自动插入 reformat 节点的开销

  • configurePlugin

  • 在推理前将调用该成员函数

  • Dynamic Shape 模式中,每当输入数据形状发生变化(调用 context.set_binding_shape)时,该成员函数被调用

  • 构建期调用时 in/out 张量形状中含有 -1

  • 运行期调用时 in/out 张量形状为真实绑定的形状

  • getWorkspaceSize

  • 向 TensorRT 报告中间计算结果的存储空间

  • workspace显存申请和释放由 TensorRT 管理

  • enqueue

  • 调用 CUDA C++ kernel 计算的地方

  • 可以根据输入张量的不同形状、数据类型等条件选择不同 kernel 执行计算

  • 不要在 enqueue 中使用 cudaMalloc* 等函数

  • 资源管理类

  • initialize(engine 创建时被调用,用于初始化 Plugin 层)

  • terminate (engine 销毁时被调用,用于释放 initialize 函数申请的资源)

  • clone(创建多个 context ,可以与源对象共享本 engine 的资源)

  • attachToContext(申请使用 context 独占的 cudnn 或 cublas 资源)

  • detachFromContext(销毁 context 独占的 cudnn 或 cublas 资源)

  • destroy(当 context/engine 销毁时被调用)

  • 序列化反序列化

  • 序列化(Plugin 负责)

    • getSerializationSize(报告序列化需要的空间大小,单位 Byte)

    • serialize(将Plugin 数据序列化到给定的 buffer 中)

  • 反序列化(PluginCreator 负责)

    • deserializePlugin(把序列化的 buffer 传给 Plugin 的构造函数)

    • Plugin 构造函数(从 buffer 中读取数据并完成 Plugin 构造)

PluginCreator的关键API

  • createPlugin(依照传入参数调用 Plugin 构造函数)
  • 注册 PluginCreator

REGISTER_TENSORRT_PLUGIN(xxxPluginCreator);

  • Version 和 namespace 相关

  • 序列化时 TensorRT 会将 Plugin 的名字(Name)、类型(Type)、版本号(Version)、命名空间(Namespace)信息写入 engine

  • 通常不用修改

5.3 Plugin的功能实现技巧

  • 先保证原生计算的正确性
  • 尝试采用或者基于自带Plugin修改
  • 尝试采用其他框架的cuda实现
  • 采用cublas和cudnn自行实现
  • 采用原生cuda自行编写

六、LayerNorm的原理与Plugin的实现

开源地址:
https://github.com/thb1314/tensorrt-layernorm-plugin
版本要求: TensorRT8.x

作者实现的Layernorm插件性能是有保证的,cuba部分移植的oneflow的官方实现。
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
移植技巧:根据cuh函数名字在github仓库中搜索具体用例,然后写一个适配器函数或者直接修改原函数接口来满足自己的使用。 Layernorm的原理 这里仅给出Layernorm的计算过程,为什么需要Layernorm请自行搜索。 torch伪代码:

1
2
3
4
5
6
7
8
# x shape: [B, S, E]
# B代表 batch size, S 代表 Sequence length E代表Embedding dim length
gamma = nn.Parameter(torch.ones(E))
beta = nn.Parameter(torch.zeros(E))
mean = x.mean(-1, keepdim=True) # mean: [bsz, max_len, 1]
std = x.std(-1, keepdim=True) # std: [bsz, max_len, 1]
# 注意这里也在最后一个维度发生了广播
layernorm_output =  gamma * (x - mean) / (std + self.eps) + beta

拓展
1. 如果是应用在CV领域,x的shape为[B,C,H,W]的是时候应该对谁进行归一化和求mean呢?
2. torch自带的nn.LayerNorm怎么适配这种情况呢?
在插件的enqueue函数,做了如下类型判断,也就是当前的LayerNorm Plugin fp32/fp16都是支持的。

int LayerNormalizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    // Get the input dimensions
    nvinfer1::Dims input_dims = inputDesc[0].dims;
    int batchSize = input_dims.d[0];
    int nbChannels = input_dims.d[1];
    bool use_fp16 = inputDesc[0].type == DataType::kHALF;
    bool use_fp32 = inputDesc[0].type == DataType::kFLOAT;
    mChannelVolume = std::accumulate(input_dims.d + 2, input_dims.d + inputDesc[0].dims.nbDims, 1, std::multiplies<int>());
    int need_size = batchSize * nbChannels;
    if(mean_len < need_size) {
        if(mean != nullptr) {
            cudaFree(mean);
            mean = nullptr;
        }
        mean_len = need_size;
        cudaMalloc(&mean, mean_len * sizeof(float));
    }
    if(inv_variance_len < need_size) {
        if(inv_variance != nullptr) {
            cudaFree(inv_variance);
            inv_variance = nullptr;
        }
        inv_variance_len = need_size;
        cudaMalloc(&inv_variance, inv_variance_len * sizeof(float));
    }
    if(use_fp16) {
        LayerNormForwardGpu(stream, batchSize * nbChannels, mChannelVolume,
                        mEpsilon, reinterpret_cast<const __half*>(inputs[0]), reinterpret_cast<const __half*>(inputs[1]),
                        reinterpret_cast<const __half*>(inputs[2]), reinterpret_cast<__half*>(outputs[0]), mean,
                        inv_variance);
    } else if(use_fp32) {
        LayerNormForwardGpu(stream, batchSize * nbChannels, mChannelVolume,
                        mEpsilon, reinterpret_cast<const float*>(inputs[0]), reinterpret_cast<const float*>(inputs[1]),
                        reinterpret_cast<const float*>(inputs[2]), reinterpret_cast<float*>(outputs[0]),mean,
                        inv_variance);
    } else {
        printf("unsupported type!");
    }
    return 0;
}
这里先判断输入数据的类型,再决定采用那个类型的计算函数,后续文章为讲一下int8的实现与计算。 读者可能会困惑为什么这里enqueue函数内部有cudaMalloc函数,前面不是说不能用么,这里的逻辑在configurePlugin逻辑也有实现,目的是为了按需动态申请mean和std的显存。按理说这里的if判断是进不去的,这么写是为了安全一些。

七、Plugin的Debug技巧

如果读者由仔细阅读源码的话,可以发现里面的trt_tensor相关的类与插件编写没什么关联,放在这里是为了后续的调试。 调试的主要操作就是获取中间结果并保存然后采用其他方式验证。

  1. enqueue函数中间结果调试使用cnpy保存为npz文件
  2. 单个cuda kernel的调试建议另外起一个新项目,仿照cublas项目

八、LayerNorm的验证

LayerNorm的验证可以采用TensorRT python的API来验证。 输入输出由numpy tensor定义,具体函数实现也可以通过numpy来模拟,然后对比tensorRT的输出结果和numpy的输出结果。 代码

import os
import ctypes
import numpy as np
# cuda: https://nvidia.github.io/cuda-python/
from cuda import cudart
import tensorrt as trt
import torch
soFile = "./layernorm_plugin.so"
epsilon = 1.0e-2
np.random.seed(97)
def printArrayInfo(x, description=""):
    print('%s: %s\n  Mean=%.5e,SumAbs=%.5e,Var=%.5e,Max=%.5f,Min=%.5f,SAD=%.5e' % (
        description, str(x.shape), np.mean(x), np.sum(abs(x)), np.var(x), np.max(x), np.min(x), np.sum(np.abs(np.diff(x.reshape(-1))))))
    print("\t", x.reshape(-1)[:10])
def check(a, b, weak=False):
    if weak:
        res = np.all(np.abs(a - b) < epsilon)

    else:
        res = np.all(a == b)
    diff0 = np.max(np.abs(a - b))

    diff1 = np.max(np.abs(a - b) / (np.abs(b) + epsilon))

    print("check:", res, "maxAbsDiff:", diff0, "maxRelDiff:", diff1)
def getLayerNormalizationPlugin():
    for c in trt.get_plugin_registry().plugin_creator_list:
        if c.name == 'LayerNormalizationPlugin':
            parameterList = []
            parameterList.append(trt.PluginField(
                "eps", np.float32(1e-5), trt.PluginFieldType.FLOAT32))
            return c.create_plugin(c.name, trt.PluginFieldCollection(parameterList))
    return None
use_fp16 = True
def run():
    trtFile = "./layernorm-plugin.plan"
    logger = trt.Logger(trt.Logger.ERROR)
    trt.init_libnvinfer_plugins(logger, '')
    ctypes.cdll.LoadLibrary(soFile)
    if os.path.isfile(trtFile):
        with open(trtFile, 'rb') as f:
            engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
        if engine is None:
            print("Failed loading engine!")
            return
        print("Succeeded loading engine!")
    else:
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        profile = builder.create_optimization_profile()
        config = builder.create_builder_config()
        config.max_workspace_size = 6 << 30
        if builder.platform_has_fast_fp16 and use_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        if use_fp16:
            inputT0 = network.add_input(
                'x', trt.DataType.HALF, [-1 for i in range(3)])
            weight = network.add_input('weight', trt.DataType.HALF, [-1])
            bias = network.add_input('bias', trt.DataType.HALF, [-1])
        else:
            inputT0 = network.add_input(
                'x', trt.DataType.FLOAT, [-1 for i in range(3)])
            weight = network.add_input('weight', trt.DataType.FLOAT, [-1])
            bias = network.add_input('bias', trt.DataType.FLOAT, [-1])
        profile.set_shape(inputT0.name, [1, 1, 1], [8, 63, 256], [64, 63, 256])
        profile.set_shape(weight.name, [1, ], [8, ], [64, ])
        profile.set_shape(bias.name, [1, ], [8, ], [64, ])
        config.add_optimization_profile(profile)
        pluginLayer = network.add_plugin_v2(
            [inputT0, weight, bias], getLayerNormalizationPlugin())
        # pluginLayer.
        network.mark_output(pluginLayer.get_output(0))
        if use_fp16:
            pluginLayer.precision = trt.float16
            pluginLayer.set_output_type(0, trt.float16)
            network.get_output(0).dtype = trt.float16
        print('type', network.get_output(0).dtype)
        engineString = builder.build_serialized_network(network, config)
        if engineString is None:
            print("Failed building engine!")
            return
        print("Succeeded building engine!")
        with open(trtFile, 'wb') as f:
            f.write(engineString)
        engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()
    shape = (2, 32, 10)
    context.set_binding_shape(0, shape)
    context.set_binding_shape(1, [shape[-1]])
    context.set_binding_shape(2, [shape[-1]])
    _, stream = cudart.cudaStreamCreate()
    nInput = np.sum([engine.binding_is_input(i) for i in range(engine.num_bindings)])
    nOutput = engine.num_bindings - nInput

    bufferH = []
    data_type = np.float16 if use_fp16 else np.float32
    data = np.random.rand(np.prod(shape)).astype(data_type).reshape(shape) * 200 - 100

    print("min, max:", data.min(), data.max())
    weight_data = np.ones((shape[-1], ), dtype=data_type)
    bias_data = np.zeros((shape[-1], ), dtype=data_type)
    bufferH.append(data)
    bufferH.append(weight_data)
    bufferH.append(bias_data)
    print('nOutput:', nOutput)
    for i in range(nOutput):
        print('context.get_binding_shape(nInput + i)',
              context.get_binding_shape(nInput + i))
        bufferH.append(np.empty(context.get_binding_shape(
            nInput + i), dtype=trt.nptype(engine.get_binding_dtype(nInput + i))))
    bufferD = []
    for i in range(engine.num_bindings):
        bufferD.append(cudart.cudaMallocAsync(bufferH[i].nbytes, stream)[1])
    for i in range(nInput):
        cudart.cudaMemcpyAsync(bufferD[i], np.ascontiguousarray(
            bufferH[i].reshape(-1)).ctypes.data, bufferH[i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    context.execute_async_v2(bufferD, stream)
    for i in range(nOutput):
        cudart.cudaMemcpyAsync(bufferH[nInput + i].ctypes.data, bufferD[nInput + i],
                               bufferH[nInput + i].nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    cudart.cudaStreamSynchronize(stream)
    mean = np.mean(bufferH[0], axis=-1, keepdims=True)
    std = np.sqrt((np.mean((bufferH[0] - mean) ** 2, axis=-1, keepdims=True) + 1e-5))

    a = (bufferH[0] - mean) / std

    weight = bufferH[1]
    bias = bufferH[2]
    a = weight.reshape(1, 1, -1) * a + bias.reshape(1, 1, -1)
    print("bufferH[-1].dtype: ", bufferH[-1].dtype)
    print('diff abs max', np.abs(a - bufferH[-1].astype(data_type)).max())

    t = torch.as_tensor(bufferH[0])
    cudart.cudaStreamDestroy(stream)
    for buffer in bufferD:
        cudart.cudaFree(buffer)
if __name__ == '__main__':
    os.system('rm ./layernorm-plugin.plan')
    np.set_printoptions(precision=3, linewidth=100, suppress=True)
    run()
    print("test finish!")

九、总结

TensorRT Plugin的实现不算很复杂,接口定义也很清晰。本文详细介绍了TensorRT Plugin的基础使用和实现,学海无涯,作者会继续学习为读者分享更多关于Plugin的高级用法。


最后更新: September 17, 2024
创建日期: September 17, 2024