TensorRT Learning (Part 2)

Published 2023-09-02 16:38:43  Author: silence_cho

Continuing to organize my TensorRT study notes, for easy reference later.

1. TensorRT plugins

For operators that TensorRT does not support, you can implement them yourself through the plugin mechanism. The approach taken here is to define a custom operator in PyTorch, export it into ONNX, and then implement a TensorRT plugin that parses this custom operator.

1.1 Defining a custom ONNX operator in PyTorch

Official documentation: https://pytorch.org/docs/1.10/onnx.html#torch-autograd-functions

Reference: https://zhuanlan.zhihu.com/p/513387413

By subclassing torch.autograd.Function and implementing its forward() and backward() methods, the result can be used in a network like any ordinary PyTorch function. If you also implement its static symbolic method, it can be converted into an ONNX operator when torch.onnx.export() is called. To summarize:

  • For model inference and training, the Function class represents a differentiable PyTorch function: once its forward and backward implementations are defined, it can be used like any ordinary PyTorch function, and PyTorch will automatically dispatch the forward and backward computations as appropriate
  • For model deployment, the Function class has a very useful property: if it defines a static symbolic method, the Function can be converted into ONNX operators when torch.onnx.export() runs, following the rules defined in symbolic
  • symbolic is the symbolic function, and it usually returns a g.op() object. g.op() maps one PyTorch operator onto one or more ONNX operators, which may be standard or custom ONNX operators (a one-line example of the standard case follows)
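For instance, here is a minimal sketch (not part of the MYSELU example below) of a symbolic method that maps straight onto the standard ONNX Selu operator:

import torch

class StandardSelu(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x):
        # map directly onto the built-in ONNX "Selu" operator
        return g.op("Selu", x)

    @staticmethod
    def forward(ctx, x):
        return torch.selu(x)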

Below is the code implementing the MYSELU activation (note that its forward pass actually computes x * sigmoid(x)):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import torch.autograd
import os

class MYSELUImpl(torch.autograd.Function):

    # reference: https://pytorch.org/docs/1.10/onnx.html#torch-autograd-functions
    @staticmethod
    def symbolic(g, x, p):
        print("==================================call symbolic")
        return g.op("MYSELU", x, p, 
            g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
            attr1_s="this is a string attribute", 
            attr2_i=[1, 2, 3], 
            attr3_f=222
        )

    @staticmethod
    def forward(ctx, x, p):
        return x * 1 / (1 + torch.exp(-x))


class MYSELU(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x):
        return MYSELUImpl.apply(x, self.param)


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.myselu = MYSELU(3)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.myselu(x)
        return x


# This module contains the export code for opset 11; to tweak export details, the code can be modified there
# import torch.onnx.symbolic_opset11
print("The opset export code lives here:", os.path.dirname(torch.onnx.__file__))

model = Model().eval()
input = torch.tensor([
    # batch 0
    [
        [1,   1,   1],
        [1,   1,   1],
        [1,   1,   1],
    ],
    # batch 1
    [
        [-1,   1,   1],
        [1,   0,   1],
        [1,   1,   -1]
    ]
], dtype=torch.float32).view(2, 1, 3, 3)

output = model(input)
print(f"inference output = \n{output}")

dummy = torch.zeros(1, 1, 3, 3)
torch.onnx.export(
    model, 

    # args: the inputs passed to the model; they must form a tuple, hence the parentheses
    (dummy,), 

    # path of the output file
    "workspace/demo.onnx", 

    # print verbose information
    verbose=True, 

    # name the input and output nodes, which makes them easier to inspect and manipulate later
    input_names=["image"], 
    output_names=["output"], 

    # opset_version controls how the various operators are exported; 11 corresponds to symbolic_opset11
    opset_version=11, 

    # declare the batch, height, and width dimensions as dynamic; ONNX records them as -1
    # usually only batch is made dynamic, and dynamic axes are avoided elsewhere
    dynamic_axes={
        "image": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    },

    # for custom plugin operators, the ONNX checker must be disabled
    enable_onnx_checker=False
)

print("Done.!")

The g.op() call returned above is worth a closer look:

g.op("MYSELU", x, p,  # 表示onnx算子的名称为MYSELU
	# 给算子传一个常数参数
	g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
	attr1_s="这是字符串属性",   # s表示字符串
	attr2_i=[1, 2, 3],       # i表示整数
	attr3_f=222              # f表示浮点数
)

Below is the MYSELU node as it appears in the exported ONNX graph (the original post shows a screenshot with the parameters from g.op() highlighted in red; the figure is not reproduced here).
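Since the screenshot is not reproduced here, the same information can be recovered by loading the exported model and printing the node; a small sketch, assuming the onnx package is installed and the export path used above:

import onnx

model = onnx.load("workspace/demo.onnx")
for node in model.graph.node:
    if node.op_type == "MYSELU":
        print(node.input)      # x, the parameter p, and the Constant added in symbolic()
        print(node.attribute)  # attr1 (string), attr2 (ints), attr3 (float)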

1.2 Parsing the ONNX operator with a TensorRT plugin

Official documentation: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#extending

To define a custom plugin in TensorRT, you need to subclass and implement two classes, and then register the plugin's creator:

  • 1 Subclass nvinfer1::IPluginV2DynamicExt to provide the plugin implementation itself

    class MySELUPlugin : public nvinfer1::IPluginV2DynamicExt {};
    
    • Note that the name returned by getPluginType() must match the name of the operator in the ONNX graph
    • The operator's actual computation is usually written as a CUDA kernel, which is launched from enqueue()
  • 2 Subclass nvinfer1::IPluginCreator, a plugin factory class used to create plugin instances

    class MySELUPluginCreator : public nvinfer1::IPluginCreator {};
    
    • Note that the name returned by getPluginName() must likewise match the name of the operator in the ONNX graph
  • 3 Register the plugin with the REGISTER_TENSORRT_PLUGIN macro:

    REGISTER_TENSORRT_PLUGIN(MySELUPluginCreator);
    

IPluginV2DynamicExt inherits from IPluginV2Ext, which in turn inherits from IPluginV2, so the virtual functions of these three base classes all need to be implemented. The main ones are:

  • IPluginV2DynamicExt base class:
    • constructor and destructor
    • virtual DimsExprs getOutputDimensions(): the dimensions of the output tensors
    • virtual bool supportsFormatCombination(): the supported data types (int8, float16, float32, etc.)
    • virtual void configurePlugin(): configure the plugin format (the data format and type this operator will use)
    • virtual size_t getWorkspaceSize(): the amount of extra workspace memory required
    • virtual int enqueue(): the actual inference logic
  • IPluginV2Ext base class:
    • virtual nvinfer1::DataType getOutputDataType()
  • IPluginV2 base class:
    • virtual AsciiChar const* getPluginType()
    • virtual AsciiChar const* getPluginVersion()
    • virtual int32_t getNbOutputs()
    • virtual size_t getSerializationSize()
    • virtual void serialize(void* buffer)
For the IPluginCreator base class, the main virtual functions to implement are as follows:

  • constructor and destructor
  • virtual AsciiChar const* getPluginName()
  • virtual AsciiChar const* getPluginVersion()
  • virtual PluginFieldCollection const* getFieldNames()
  • virtual IPluginV2* createPlugin()
  • virtual IPluginV2* deserializePlugin()
  • virtual void setPluginNamespace()
  • virtual AsciiChar const* getPluginNamespace()

Below is the code implementing the plugin:

myselu_plugin.hpp:

#ifndef CUSTOM_MYSELU_PLUGIN_H
#define CUSTOM_MYSELU_PLUGIN_H

#include <NvInferPlugin.h>
#include <string>
#include <vector>

class MySELUPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
	MySELUPlugin(const std::string name, const std::string attr1, float attr3);  // takes the operator name and attributes; the constructor used when building the engine
	MySELUPlugin(const std::string name, const void* data, size_t length);  // takes the operator name and serialized engine data; the constructor used at inference time
	MySELUPlugin() = delete;

	int getNbOutputs() const noexcept override;
	virtual nvinfer1::DataType getOutputDataType(int32_t index,
		nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override {
		return inputTypes[0];
	}
	virtual nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
		const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

	int initialize() noexcept override;
	void terminate() noexcept override;

	virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
		int32_t nbInputs, const nvinfer1::PluginTensorDesc* outputs,
		int32_t nbOutputs) const noexcept override {
		return 0;
	};

	int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
		const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

	size_t getSerializationSize() const noexcept override;
	void serialize(void* buffer)  const noexcept override;

	virtual void configurePlugin(const  nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
		const  nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;

	virtual bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
		int32_t nbOutputs) noexcept override;

	const char* getPluginType() const noexcept override;
	const char* getPluginVersion() const noexcept override;

	void destroy() noexcept override;
	nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
	void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
	const char* getPluginNamespace()const noexcept override;

private:
	const std::string mLayerName;
	std::string mattr1;
	float mattr3;
	size_t mInputVolume;
	std::string mNamespace;
};


class MySELUPluginCreator : public nvinfer1::IPluginCreator {
public:
	MySELUPluginCreator();
	const char* getPluginName() const noexcept override;
	const char* getPluginVersion() const noexcept override;
	const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

	nvinfer1::IPluginV2* createPlugin(nvinfer1::AsciiChar const* name,
		nvinfer1::PluginFieldCollection const* fc) noexcept override;
	nvinfer1::IPluginV2* deserializePlugin(nvinfer1::AsciiChar const* name,
		void const* serialData, size_t serialLength)noexcept override;
	void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
	const char* getPluginNamespace() const noexcept override;
private:
	static nvinfer1::PluginFieldCollection mfc;
	static std::vector<nvinfer1::PluginField> mPluginAttributes;
	std::string mNamespace;
};

#endif

myselu_plugin.cpp:

#include "myselu_plugin.hpp"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>

void myselu_inference(const float* x, float* output, int n, cudaStream_t stream);

// constants specific to the MySELU plugin
namespace {
	const char* MYSELU_PLUGIN_VERSION{ "1" };
	const char* MYSELU_PLUGIN_NAME{ "MYSELU" };  // must match the operator name in the ONNX graph
}

// initialization of the static class members
nvinfer1::PluginFieldCollection MySELUPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> MySELUPluginCreator::mPluginAttributes;

// what actually gets registered is the creator, which is then managed by TensorRT
REGISTER_TENSORRT_PLUGIN(MySELUPluginCreator);

// helper function for serializing the plugin
template <typename T>
void writeToBuffer(char*& buffer, const T& val) {
	*reinterpret_cast<T*>(buffer) = val;
	buffer += sizeof(T);
}

// helper function for deserializing the plugin
template <typename T>
T readFromBuffer(const char*& buffer) {
	T val = *reinterpret_cast<const T*>(buffer);
	buffer += sizeof(T);
	return val;
}

// definition of the plugin class MySELUPlugin
MySELUPlugin::MySELUPlugin(const std::string name, const std::string attr1, float attr3)
	:mLayerName(name), mattr1(attr1), mattr3(attr3)
{
	printf("==================== build phase, attr1 = %s, attr3 = %f\n", attr1.c_str(), attr3);
}

MySELUPlugin::MySELUPlugin(const std::string name, const void* data, size_t length) 
	:mLayerName(name)
{
	// Deserialize in the same order as serialization
	const char* d = static_cast<const char*>(data);
	const char* a = d;
	int nstr = readFromBuffer<int>(d);
	mattr1 = std::string(d, d + nstr);

	d += nstr;
	mattr3 = readFromBuffer<float>(d);
	assert(d == (a + length));
	printf("==================== 推理阶段,attr1 = %s, attr3 = %f\n", mattr1.c_str(), mattr3);
};

const char* MySELUPlugin::getPluginType() const noexcept
{
	return MYSELU_PLUGIN_NAME;
}

const char* MySELUPlugin::getPluginVersion() const noexcept
{
	return MYSELU_PLUGIN_VERSION;
}

int MySELUPlugin::getNbOutputs() const noexcept {
	return 1;
}

// get the output dimensions of this layer
nvinfer1::DimsExprs MySELUPlugin::getOutputDimensions(int32_t outputIndex,
	const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept 
{
	// MySELU does not change the input shape, so the output dims are identical to the input dims
	return *inputs;

}

int MySELUPlugin::initialize() noexcept
{
	return 0;
}


int MySELUPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
	const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
	void* output = outputs[0];
	size_t volume = 1;
	for (int i = 0; i < inputDesc->dims.nbDims; ++i) {
		volume *= inputDesc->dims.d[i];
	}
	mInputVolume = volume;
	myselu_inference(static_cast<const float*>(inputs[0]),
		static_cast<float*>(output),
		mInputVolume,
		stream
	);
	return 0;
}

size_t MySELUPlugin::getSerializationSize() const noexcept
{
	return sizeof(int) + mattr1.size() + sizeof(mattr3);
}

// serialize this layer's parameters, to be stored in the trtmodel file
void MySELUPlugin::serialize(void* buffer)  const noexcept
{
	char* d = static_cast<char*>(buffer);
	const char* a = d;
	int nstr = mattr1.size();
	writeToBuffer(d, nstr);
	memcpy(d, mattr1.data(), nstr);
	
	d += nstr;
	writeToBuffer(d, mattr3);
	assert(d == a + getSerializationSize());
}

// report which data formats and types the plugin supports
bool MySELUPlugin::supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
	int32_t nbOutputs) noexcept
{
	auto type = inOut[pos].type;
	auto format = inOut[pos].format;
	// this plugin only supports ordinary float tensors in linear (NCHW) format
	if (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kLINEAR) {
		return true;
	}
	else {
		return false;
	}
}

void MySELUPlugin::terminate() noexcept {}

void MySELUPlugin::destroy() noexcept
{
	// This gets called when the network containing plugin is destroyed
	delete this;
}

// configure the plugin format: the data format and type this layer will actually use
void MySELUPlugin::configurePlugin(const  nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
	const  nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{
	auto type = in->desc.type;
	auto format = in->desc.format;
	assert(nbOutputs == 1);
	assert(type == nvinfer1::DataType::kFLOAT);
	assert(format == nvinfer1::PluginFormat::kLINEAR);
}

// clone the plugin
nvinfer1::IPluginV2DynamicExt* MySELUPlugin::clone() const noexcept
{
	printf("=================== cloning plugin =================\n");
	auto plugin = new MySELUPlugin(mLayerName, mattr1, mattr3);
	plugin->setPluginNamespace(mNamespace.c_str());
	return plugin;
}

void MySELUPlugin::setPluginNamespace(const char* libNamespace) noexcept
{
	mNamespace = libNamespace;
}

const char* MySELUPlugin::getPluginNamespace() const noexcept
{
	return mNamespace.c_str();
}


// the plugin creator
MySELUPluginCreator::MySELUPluginCreator()
{
	// declare the PluginField parameters that MySELUPlugin requires
	mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));
	mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));
	
	// gather the PluginField parameters
	mfc.nbFields = mPluginAttributes.size();
	mfc.fields = mPluginAttributes.data();

}

const char* MySELUPluginCreator::getPluginName() const noexcept
{
	return MYSELU_PLUGIN_NAME;
}

const char* MySELUPluginCreator::getPluginVersion() const noexcept
{
	return MYSELU_PLUGIN_VERSION;
}
const nvinfer1::PluginFieldCollection* MySELUPluginCreator::getFieldNames() noexcept
{
	return &mfc;
}

// create the plugin
nvinfer1::IPluginV2* MySELUPluginCreator::createPlugin(nvinfer1::AsciiChar const* name,
	nvinfer1::PluginFieldCollection const* fc) noexcept
{
	std::string attr1;
	float attr3 = 0;
	const nvinfer1::PluginField* fields = fc->fields;

	// Parse fields from PluginFieldCollection
	for (int i = 0; i < fc->nbFields; ++i) {
		if (strcmp(fields[i].name, "attr1")==0) {
			assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);
			auto cp = static_cast<const char*>(fields[i].data);
			attr1 = std::string(cp, cp + fields[i].length);
		}
		else if (strcmp(fields[i].name, "attr3") == 0) {
			assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
			attr3 = *(static_cast<const float*>(fields[i].data));
		}
	}
	return new MySELUPlugin(name, attr1, attr3);
}

// create the plugin by deserializing its parameters
nvinfer1::IPluginV2* MySELUPluginCreator::deserializePlugin(nvinfer1::AsciiChar const* name,
	void const* serialData, size_t serialLength)noexcept
{
	// This object will be deleted when the network is destroyed, which will
	// call MySELUPlugin::destroy()
	return new MySELUPlugin(name, serialData, serialLength);
}

void MySELUPluginCreator::setPluginNamespace(const char* libNamespace) noexcept
{
	mNamespace = libNamespace;
}

const char* MySELUPluginCreator::getPluginNamespace() const noexcept
{
	return mNamespace.c_str();
}

The CUDA kernel, myselu_kernel.cu:

#include <cuda_runtime.h>
#include <cmath>

static __device__ float sigmoid(float x) {
	return 1 / (1 + expf(-x));
}

static __global__ void myselu_kernel(const float* x, float* output, int n)
{
	int position = threadIdx.x + blockDim.x*blockIdx.x;
	if (position >= n) return;
	output[position] = x[position]*sigmoid(x[position]);
}

void myselu_inference(const float* x, float* output, int n, cudaStream_t stream)
{
	const int nthreads = 512;
	int block_size = n > nthreads ? nthreads : n;
	int grid_size = (n + block_size - 1) / block_size;
	myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, output, n);
}
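Since the kernel computes the same x * sigmoid(x) as the Python forward() above, its math is easy to sanity-check against PyTorch; a quick sketch, independent of the plugin code:

import torch

x = torch.tensor([-1.0, 0.0, 2.0])
# element-wise, this is exactly what myselu_kernel computes
print(x * torch.sigmoid(x))  # tensor([-0.2689, 0.0000, 1.7616])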

The main program, main.cpp:

// tensorRT include
// headers needed at build time
#include <NvInfer.h>

// ONNX parser header
#include <NvOnnxParser.h>

// runtime header needed for inference
#include <NvInferRuntime.h>

// cuda include
#include <cuda_runtime.h>

// system include
#include <stdio.h>
#include <math.h>

#include <iostream>
#include <fstream>
#include <vector>

using namespace std;

inline const char* severity_string(nvinfer1::ILogger::Severity t) {
	switch (t) {
	case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: return "internal_error";
	case nvinfer1::ILogger::Severity::kERROR:   return "error";
	case nvinfer1::ILogger::Severity::kWARNING: return "warning";
	case nvinfer1::ILogger::Severity::kINFO:    return "info";
	case nvinfer1::ILogger::Severity::kVERBOSE: return "verbose";
	default: return "unknown";
	}
}

class TRTLogger : public nvinfer1::ILogger {
public:
	virtual void log(Severity severity, nvinfer1::AsciiChar const* msg) noexcept override {
		if (severity <= Severity::kINFO) {
			// print colored text; the format is:
			// printf("\033[47;33msome text\033[0m");
			// where \033[ is the opening escape
			//       47    is the background color
			//       ;     is the separator
			//       33    is the foreground color
			//       m     closes the escape sequence
			//       \033[0m is the reset marker
			// the background or foreground color code may be omitted
			// some color codes: https://blog.csdn.net/ericbar/article/details/79652086
			if (severity == Severity::kWARNING) {
				printf("\033[33m%s: %s\033[0m\n", severity_string(severity), msg);
			}
			else if (severity <= Severity::kERROR) {
				printf("\033[31m%s: %s\033[0m\n", severity_string(severity), msg);
			}
			else {
				printf("%s: %s\n", severity_string(severity), msg);
			}
		}
	}
} logger;

bool build_model() {
	TRTLogger logger;

	// these are the basic components required
	nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
	nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
	nvinfer1::INetworkDefinition* network = builder->createNetworkV2(1);  // 1 corresponds to the kEXPLICIT_BATCH flag

	// the results parsed by the onnxparser are populated into network, much as layers are added via addConv
	nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);
	if (!parser->parseFromFile("myselu.onnx", 1)) {  // the exported ONNX file (rename workspace/demo.onnx from the export step accordingly)
		network->destroy();
		config->destroy();
		builder->destroy();
		printf("load onnx file failed\n");
		return false;
	}

	int maxBatchSize = 10;
	printf("Workspace Size = %.2f MB\n", (1 << 28) / 1024.0f / 1024.0f);
	config->setMaxWorkspaceSize(1 << 28);

	// if the model has multiple inputs, multiple profiles are required
	auto profile = builder->createOptimizationProfile();
	auto input_tensor = network->getInput(0);
	int input_channel = input_tensor->getDimensions().d[1];

	// configure the minimum, optimal, and maximum ranges of the input shape
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4(1, input_channel, 3, 3));
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4(1, input_channel, 3, 3));
	profile->setDimensions(input_tensor->getName(), nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4(maxBatchSize, input_channel, 5, 5));
	config->addOptimizationProfile(profile);

	nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
	if (engine == nullptr) {
		printf("build engine failed\n");
		network->destroy();
		config->destroy();
		builder->destroy();
		return false;
	}

	// serialize the model and store it as a file
	nvinfer1::IHostMemory* model_data = engine->serialize();
	FILE* f = fopen("myselu.trtmodel", "wb");
	fwrite(model_data->data(), 1, model_data->size(), f);
	fclose(f);

	// release in the reverse of the construction order
	model_data->destroy();
	parser->destroy();
	engine->destroy();
	network->destroy();
	config->destroy();
	builder->destroy();
	printf("Done.\n");
	return true;
}

///////////////////////////////////////////////////////////////////////////////////////////////////////

vector<unsigned char> load_file(const string& file) {
	ifstream in(file, ios::in | ios::binary);
	if (!in.is_open())
		return {};

	in.seekg(0, ios::end);
	size_t length = in.tellg();

	std::vector<uint8_t> data;
	if (length > 0) {
		in.seekg(0, ios::beg);
		data.resize(length);

		in.read((char*)&data[0], length);
	}
	in.close();
	return data;
}

void inference() {

	TRTLogger logger;
	auto engine_data = load_file("myselu.trtmodel");
	nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
	nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engine_data.data(), engine_data.size());
	if (engine == nullptr) {
		printf("Deserialize cuda engine failed.\n");
		runtime->destroy();
		return;
	}

	nvinfer1::IExecutionContext* execution_context = engine->createExecutionContext();
	cudaStream_t stream = nullptr;
	cudaStreamCreate(&stream);

	float input_data_host[] = {
		// batch 0
		1,   1,   1,
		1,   1,   1,
		1,   1,   1,

		// batch 1
		-1,   1,   1,
		1,   0,   1,
		1,   1,   -1
	};
	float* input_data_device = nullptr;

	// 3x3 input, giving 3x3 output
	const int ib = 2;
	const int iw = 3;
	const int ih = 3;
	float output_data_host[ib * iw * ih];
	float* output_data_device = nullptr;
	cudaMalloc(&input_data_device, sizeof(input_data_host));
	cudaMalloc(&output_data_device, sizeof(output_data_host));
	cudaMemcpyAsync(input_data_device, input_data_host, sizeof(input_data_host), cudaMemcpyHostToDevice, stream);

	// make explicit the input dimensions used for this inference
	execution_context->setBindingDimensions(0, nvinfer1::Dims4(ib, 1, ih, iw));
	float* bindings[] = { input_data_device, output_data_device };
	bool success = execution_context->enqueueV2((void**)bindings, stream, nullptr);
	cudaMemcpyAsync(output_data_host, output_data_device, sizeof(output_data_host), cudaMemcpyDeviceToHost, stream);
	cudaStreamSynchronize(stream);

	for (int b = 0; b < ib; ++b) {
		printf("batch %d. output_data_host = \n", b);
		for (int i = 0; i < iw * ih; ++i) {
			printf("%f, ", output_data_host[b * iw * ih + i]);
			if ((i + 1) % iw == 0)
				printf("\n");
		}
	}

	printf("Clean memory\n");
	cudaStreamDestroy(stream);
	cudaFree(input_data_device);
	cudaFree(output_data_device);
	execution_context->destroy();
	engine->destroy();
	runtime->destroy();
}

int main() {
	if (!build_model()) {
		return -1;
	}
	inference();

	std::cin.get();
	return 0;
}

While the code above runs, you can observe the stages the plugin goes through:

  1. Build phase

      1. MySELUPluginCreator::createPlugin is called to create the plugin
      2. During this, MySELUPlugin::clone is called to clone the plugin
      3. MySELUPlugin::supportsFormatCombination is called to determine which data formats and types the plugin supports
      • here we tell the engine what kinds of inference the plugin can handle
      • several may be supported, e.g. fp32, fp16, int8, and so on
      4. MySELUPlugin::getOutputDimensions is called to get the layer's output dimensions
      5. MySELUPlugin::enqueue is called for performance timing (not guaranteed to run)
      • if several formats are supported, each is actually benchmarked and the best-performing configuration is selected
      6. MySELUPlugin::configurePlugin is called to configure the plugin format
      • this tells you the data format and type the layer has ended up using
      7. MySELUPlugin::serialize is called to serialize the layer's parameters into the trtmodel file
  2. Inference phase

      1. MySELUPluginCreator::deserializePlugin is called to create the plugin from its deserialized parameters
      2. During this, MySELUPlugin::clone is called to clone the plugin
      3. MySELUPlugin::configurePlugin configures the data type and format the plugin will use
      4. MySELUPlugin::enqueue performs the actual inference
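Finally, if the plugin is compiled into a shared library (the name libmyselu_plugin.so below is an assumption; adjust it to your build), its registration can also be checked from Python through TensorRT's plugin registry; a rough sketch:

import ctypes
import tensorrt as trt

# loading the library runs the static registration done by REGISTER_TENSORRT_PLUGIN
ctypes.CDLL("libmyselu_plugin.so")

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, "")  # also registers TensorRT's built-in plugins

registry = trt.get_plugin_registry()
print([c.name for c in registry.plugin_creator_list])  # should include "MYSELU"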