# TensorRT usage in Python script
There are two ways to use a TensorRT optimized model:
- deploy it on a Triton server
- use it directly in Python
This document is about the second option.
## High-level explanations
- call `load_engine()` to parse an existing TensorRT engine, or `build_engine()` to convert an ONNX file (see the sketch after this list)
- set up a CUDA `stream` (for async calls), a TensorRT `runtime` and a `context`
- load your `profile`(s)
- call `infer_tensorrt()`
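If you already have a serialized engine on disk, the plain TensorRT runtime is enough to reload it. Below is a minimal sketch of that path; the `model.plan` file name is an assumption, and the project's `load_engine()` helper is the higher-level alternative mentioned above.

```python
import tensorrt as trt

# deserialize an engine previously saved to disk ("model.plan" is an example name)
trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
with open("model.plan", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
```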
## Build engine
We assume that you have already prepared your ONNX file.
Now we need to convert it to a TensorRT engine:
```python
import tensorrt as trt
from tensorrt.tensorrt import Logger, Runtime

from transformer_deploy.backends.trt_utils import build_engine

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
profile_index = 0

# shapes used for the optimization profile: (batch size, sequence length)
max_seq_len = 256
batch_size = 32

engine = build_engine(
    runtime=runtime,
    onnx_file_path="model_qat.onnx",
    logger=trt_logger,
    min_shape=(1, max_seq_len),
    optimal_shape=(batch_size, max_seq_len),
    max_shape=(batch_size, max_seq_len),
    workspace_size=10000 * 1024 * 1024,  # memory budget for the builder
    fp16=True,  # enable mixed precision
    int8=True,  # enable INT8 kernels (the ONNX file comes from quantization-aware training)
)
```
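Building an engine can take a while, so you may want to persist it and skip the build on later runs. Here is a short sketch using the standard TensorRT serialization API; the `model.plan` file name is an example, not something the library imposes.

```python
# Optional: save the freshly built engine so later runs can deserialize it
# instead of rebuilding from ONNX ("model.plan" is an arbitrary example name).
with open("model.plan", "wb") as f:
    f.write(bytes(engine.serialize()))
```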
## Prepare inference
Now that the engine is ready, we can prepare the inference:
```python
import torch
from tensorrt.tensorrt import IExecutionContext

from transformer_deploy.backends.trt_utils import get_binding_idxs

context: IExecutionContext = engine.create_execution_context()
# select the optimization profile on the current CUDA stream
context.set_optimization_profile_async(
    profile_index=profile_index, stream_handle=torch.cuda.current_stream().cuda_stream
)
# retrieve the indexes of the input and output bindings for this profile
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index)  # type: List[int], List[int]
```
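The inference call below also needs host-side input tensors. Here is a hedged example of preparing them with a Hugging Face tokenizer; the model name, the `int32` dtype, and the dict-of-numpy-arrays layout keyed by input name are assumptions to adapt to your own model, not guarantees about what `infer_tensorrt()` requires.

```python
import numpy as np
from transformers import AutoTokenizer

# Hypothetical input preparation: tokenize a batch and keep the tensors on the
# host as numpy arrays keyed by input name (layout assumed, adapt to your model).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model
encodings = tokenizer(
    ["TensorRT inference in Python, without Triton."] * batch_size,
    max_length=max_seq_len,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
input_np = {name: tensor.astype(np.int32) for name, tensor in encodings.items()}
```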
## Inference
```python
from transformer_deploy.backends.trt_utils import infer_tensorrt

input_np = ...  # host inputs, e.g. prepared as sketched above

tensorrt_output = infer_tensorrt(
    context=context,
    host_inputs=input_np,
    input_binding_idxs=input_binding_idxs,
    output_binding_idxs=output_binding_idxs,
)
print(tensorrt_output)
```
... and you are done! 🎉
Tip

To go deeper, check the API documentation:

- Convert
- Backends/Trt utils

... and if you are looking for inspiration, check onnx-tensorrt.