TensorRT代码框架

发布时间 2023-05-05 15:35:01作者: stardsd

TensorRT是一个用于高性能深度学习推理的SDK,它可以将已经训练好的网络转换为TensorRT引擎,从而提高推理速度和效率。要使用TensorRT,你需要以下几个步骤:

  1. 安装TensorRT和相关的系统包,如CUDA、cuDNN、Python等。
  2. 选择一个深度学习框架,如PyTorch、TensorFlow、ONNX等,将你的网络导出为TensorRT支持的格式,如ONNX或Caffe。
  3. 使用TensorRT的解析器(parser)或编译器(compiler)将网络转换为TensorRT引擎(engine),并保存为文件。
  4. 使用TensorRT的运行时(runtime)加载引擎文件,并输入数据进行推理。

下面是一个使用Python API的TensorRT使用示例代码²:

import tensorrt as trt

# 创建一个logger对象,用于报告错误和警告
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# 创建一个builder对象,用于构建引擎
builder = trt.Builder(TRT_LOGGER)

# 创建一个network对象,用于定义网络结构
network = builder.create_network()

# 创建一个parser对象,用于解析ONNX模型文件
parser = trt.OnnxParser(network, TRT_LOGGER)

# 从文件中读取ONNX模型
with open("model.onnx", "rb") as f:
    parser.parse(f.read())

# 设置最大批量大小和最大工作空间大小
builder.max_batch_size = 1
builder.max_workspace_size = 1 << 20

# 构建引擎并保存为文件
with builder.build_cuda_engine(network) as engine:
    with open("model.trt", "wb") as f:
        f.write(engine.serialize())

# 创建一个runtime对象,用于加载引擎文件
runtime = trt.Runtime(TRT_LOGGER)

# 从文件中加载引擎
with open("model.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

# 创建一个context对象,用于执行引擎
context = engine.create_execution_context()

# 分配输入和输出的内存空间
import pycuda.driver as cuda
import pycuda.autoinit

input_size = trt.volume(engine.get_binding_shape(0))
output_size = trt.volume(engine.get_binding_shape(1))
input_dtype = trt.nptype(engine.get_binding_dtype(0))
output_dtype = trt.nptype(engine.get_binding_dtype(1))

input_buffer = cuda.mem_alloc(input_size * input_dtype.itemsize)
output_buffer = cuda.mem_alloc(output_size * output_dtype.itemsize)

# 将输入数据拷贝到GPU内存中
import numpy as np

input_data = np.random.randn(input_size).astype(input_dtype)
cuda.memcpy_htod(input_buffer, input_data)

# 执行引擎并同步结果
context.execute_async(bindings=[int(input_buffer), int(output_buffer)], stream_handle=pycuda.autoinit.stream.handle)
cuda.Context.synchronize()

# 将输出数据从GPU内存中拷贝到CPU内存中
output_data = np.empty(output_size, dtype=output_dtype)
cuda.memcpy_dtoh(output_data, output_buffer)

# 打印输出数据
print(output_data)

Source: 2023/5/5
(1) GitHub - NVIDIA/TensorRT: NVIDIA® TensorRT™, an SDK for high .... https://github.com/NVIDIA/TensorRT.
(2) Developer Guide :: NVIDIA Deep Learning TensorRT Documentation. https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html.
(3) 使用TensorRT部署你的神经网络(1) - 知乎 - 知乎专栏. https://zhuanlan.zhihu.com/p/259539097.
(4) 使用Python部署TensorRT_python tensorrt_wq_0708的博客-CSDN博客. https://blog.csdn.net/wq_0708/article/details/121266031.