Trt utils
All the tooling to ease TensorRT usage.
TensorRTShape dataclass
Store input shapes for a TensorRT build. Three shapes per input tensor are required (as tuples of integers):
- minimum input shape
- optimal shape, used for the benchmarks during the build
- maximum input shape
Set the input name to None for a default shape applied to every input.
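Example (a minimal sketch of shape definitions for a transformer-like model with a dynamic batch size from 1 to 32 and a fixed sequence length of 128; the input names input_ids and attention_mask are placeholders that depend on your ONNX export):
from transformer_deploy.backends.trt_utils import TensorRTShape

# one TensorRTShape per input tensor, each shape expressed as (batch size, sequence length)
input_shapes = [
    TensorRTShape(
        min_shape=[1, 128],      # smallest input the engine must accept
        optimal_shape=[8, 128],  # shape used for benchmarking during the build
        max_shape=[32, 128],     # largest input the engine must accept
        input_name=name,
    )
    for name in ["input_ids", "attention_mask"]
]
for shape in input_shapes:
    shape.check_validity()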
Source code in src/transformer_deploy/backends/trt_utils.py
@dataclass
class TensorRTShape:
    """
    Store input shapes for TensorRT build.
    3 shapes per input tensor are required (as tuple of integers):
    * minimum input shape
    * optimal size used for the benchmarks during building
    * maximum input shape
    Set input name to None for default shape.
    """

    min_shape: List[int]
    optimal_shape: List[int]
    max_shape: List[int]
    input_name: Optional[str]

    def check_validity(self) -> None:
        """
        Basic checks of provided shapes
        """
        assert len(self.min_shape) == len(self.optimal_shape) == len(self.max_shape)
        assert len(self.min_shape) > 0
        assert self.min_shape[0] > 0 and self.optimal_shape[0] > 0 and self.max_shape[0] > 0
        assert self.input_name is not None

    def make_copy(self, input_name: str) -> "TensorRTShape":
        """
        Make a copy of the current instance, with a different input name.
        :param input_name: new input name to use
        :return: a copy of the current shape with a different name
        """
        instance_copy = dataclasses.replace(self)
        instance_copy.input_name = input_name
        return instance_copy

    def generate_multiple_shapes(self, input_names: List[str]) -> List["TensorRTShape"]:
        """
        Generate multiple shapes when only a single default one is defined.
        :param input_names: input names used by the model
        :return: a list of shapes
        """
        assert self.input_name is None, f"input name is not None: {self.input_name}"
        result = list()
        for name in input_names:
            shape = self.make_copy(input_name=name)
            result.append(shape)
        return result
check_validity(self)
Basic checks of provided shapes
Source code in src/transformer_deploy/backends/trt_utils.py
def check_validity(self) -> None:
    """
    Basic checks of provided shapes
    """
    assert len(self.min_shape) == len(self.optimal_shape) == len(self.max_shape)
    assert len(self.min_shape) > 0
    assert self.min_shape[0] > 0 and self.optimal_shape[0] > 0 and self.max_shape[0] > 0
    assert self.input_name is not None
generate_multiple_shapes(self, input_names)
Generate multiple shapes when only a single default one is defined.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
input_names | List[str] | input names used by the model | required

Returns:

Type | Description
--- | ---
List[TensorRTShape] | a list of shapes
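Example (a sketch of expanding a single default shape; the input names are placeholders):
from transformer_deploy.backends.trt_utils import TensorRTShape

# a default shape has no input name; it is duplicated once per model input
default_shape = TensorRTShape(
    min_shape=[1, 128], optimal_shape=[8, 128], max_shape=[32, 128], input_name=None
)
shapes = default_shape.generate_multiple_shapes(
    input_names=["input_ids", "attention_mask", "token_type_ids"]
)
assert [s.input_name for s in shapes] == ["input_ids", "attention_mask", "token_type_ids"]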
Source code in src/transformer_deploy/backends/trt_utils.py
def generate_multiple_shapes(self, input_names: List[str]) -> List["TensorRTShape"]:
    """
    Generate multiple shapes when only a single default one is defined.
    :param input_names: input names used by the model
    :return: a list of shapes
    """
    assert self.input_name is None, f"input name is not None: {self.input_name}"
    result = list()
    for name in input_names:
        shape = self.make_copy(input_name=name)
        result.append(shape)
    return result
make_copy(self, input_name)
Make a copy of the current instance, with a different input name.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
input_name | str | new input name to use | required

Returns:

Type | Description
--- | ---
TensorRTShape | a copy of the current shape with a different name
Source code in src/transformer_deploy/backends/trt_utils.py
def make_copy(self, input_name: str) -> "TensorRTShape":
    """
    Make a copy of the current instance, with a different input name.
    :param input_name: new input name to use
    :return: a copy of the current shape with a different name
    """
    instance_copy = dataclasses.replace(self)
    instance_copy.input_name = input_name
    return instance_copy
build_engine(runtime, onnx_file_path, logger, workspace_size, fp16, int8, fp16_fix=fix_fp16_network, **kwargs)
Convert an ONNX file to a TensorRT engine. It supports dynamic shapes, however it is advised to keep the sequence length fixed, as making it dynamic hurts performance. A dynamic batch size doesn't hurt performance and is highly advised. Batch size can be provided in different ways:
- min_shape, optimal_shape, max_shape: for the simple case, 3 tuples of int when all input tensors have the same shape
- input_shapes: a list of TensorRTShape with names, if there are several input tensors with different shapes
TIP: the minimum batch size should be 1 in most cases.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
runtime | Runtime | global variable shared across inference calls / model building | required
onnx_file_path | str | path to the ONNX file | required
logger | Logger | logger specific to TensorRT | required
workspace_size | int | GPU memory to use during the build, more is always better. If there is not enough memory, some optimizations may fail and the whole conversion process will crash. | required
fp16 | bool | enable FP16 precision, it usually provides a 20-30% boost compared to ONNX Runtime. | required
int8 | bool | enable INT-8 quantization, best performance but the model should have been quantized beforehand. | required
fp16_fix | Callable[[tensorrt.tensorrt.INetworkDefinition], tensorrt.tensorrt.INetworkDefinition] | a function to set FP32 precision on some nodes to fix FP16 overflow | fix_fp16_network

Returns:

Type | Description
--- | ---
ICudaEngine | TensorRT engine to use during inference
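Example (a minimal sketch of a typical build, assuming a dynamic batch size, a fixed sequence length of 128 and placeholder file paths; input names are read from the ONNX file, so a single default shape is enough):
import tensorrt as trt
from tensorrt.tensorrt import Logger, Runtime

from transformer_deploy.backends.trt_utils import build_engine, save_engine

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
engine = build_engine(
    runtime=runtime,
    onnx_file_path="model.onnx",    # placeholder path
    logger=trt_logger,
    workspace_size=12 * 1024 ** 3,  # 12 GB reserved for the build
    fp16=True,
    int8=False,
    # default shape applied to every input: (batch size, sequence length)
    min_shape=[1, 128],
    optimal_shape=[8, 128],
    max_shape=[32, 128],
)
save_engine(engine=engine, engine_file_path="model.plan")  # placeholder path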
Source code in src/transformer_deploy/backends/trt_utils.py
def build_engine(
    runtime: Runtime,
    onnx_file_path: str,
    logger: Logger,
    workspace_size: int,
    fp16: bool,
    int8: bool,
    fp16_fix: Callable[[INetworkDefinition], INetworkDefinition] = fix_fp16_network,
    **kwargs,
) -> ICudaEngine:
    """
    Convert ONNX file to TensorRT engine.
    It supports dynamic shape, however it's advised to keep sequence length fix as it hurts performance otherwise.
    Dynamic batch size doesn't hurt performance and is highly advised.
    Batch size can provided through different ways:
    * **min_shape**, **optimal_shape**, **max_shape**: for simple case, 3 tuples of int when all
      input tensors have the same shape
    * **input_shapes**: a list of TensorRTShape with names if there are several input tensors with different shapes
    **TIP**: minimum batch size should be 1 in most cases.
    :param runtime: global variable shared accross inference call / model building
    :param onnx_file_path: path to the ONNX file
    :param logger: specific logger to TensorRT
    :param workspace_size: GPU memory to use during the building, more is always better.
        If there is not enough memory, some optimization may fail, and the whole conversion process will crash.
    :param fp16: enable FP16 precision, it usually provide a 20-30% boost compared to ONNX Runtime.
    :param int8: enable INT-8 quantization, best performance but model should have been quantized.
    :param fp16_fix: a function to set FP32 precision on some nodes to fix FP16 overflow
    :return: TensorRT engine to use during inference
    """
    # default input shape
    if "min_shape" in kwargs and "optimal_shape" in kwargs and "max_shape" in kwargs:
        default_shape = TensorRTShape(
            min_shape=kwargs["min_shape"],
            optimal_shape=kwargs["optimal_shape"],
            max_shape=kwargs["max_shape"],
            input_name=None,
        )
        input_shapes = [default_shape]
    else:
        assert "input_shapes" in kwargs, "missing input shapes"
        input_shapes: List[TensorRTShape] = kwargs["input_shapes"]

    with trt.Builder(logger) as builder:  # type: Builder
        with builder.create_network(
            flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        ) as network_def:  # type: INetworkDefinition
            with trt.OnnxParser(network_def, logger) as parser:  # type: OnnxParser
                # The maximum batch size which can be used at execution time,
                # and also the batch size for which the ICudaEngine will be optimized.
                builder.max_batch_size = max([s.max_shape[0] for s in input_shapes])
                config: IBuilderConfig = builder.create_builder_config()
                config.max_workspace_size = workspace_size
                # to enable complete trt inspector debugging, only for TensorRT >= 8.2
                # config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
                # disable CUDNN optimizations
                config.set_tactic_sources(
                    tactic_sources=1 << int(trt.TacticSource.CUBLAS) | 1 << int(trt.TacticSource.CUBLAS_LT)
                )
                if int8:
                    config.set_flag(trt.BuilderFlag.INT8)
                if fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
                # https://github.com/NVIDIA/TensorRT/issues/1196 (sometimes big diff in output when using FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
                with open(onnx_file_path, "rb") as f:
                    parser.parse(f.read())
                profile: IOptimizationProfile = builder.create_optimization_profile()
                # duplicate default shape (one for each input)
                if len(input_shapes) == 1 and input_shapes[0].input_name is None:
                    names = [network_def.get_input(num_input).name for num_input in range(network_def.num_inputs)]
                    input_shapes = input_shapes[0].generate_multiple_shapes(input_names=names)
                for shape in input_shapes:
                    shape.check_validity()
                    profile.set_shape(
                        input=shape.input_name,
                        min=shape.min_shape,
                        opt=shape.optimal_shape,
                        max=shape.max_shape,
                    )
                if "shape_tensors" in kwargs:
                    for shape in kwargs["shape_tensors"]:
                        profile.set_shape_input(
                            input=shape.input_name,
                            min=shape.min_shape,
                            opt=shape.optimal_shape,
                            max=shape.max_shape,
                        )
                config.add_optimization_profile(profile)
                if fp16:
                    network_def = fp16_fix(network_def)
                trt_engine = builder.build_serialized_network(network_def, config)
                engine: ICudaEngine = runtime.deserialize_cuda_engine(trt_engine)
                assert engine is not None, "error during engine generation, check error messages above :-("
                return engine
fix_fp16_network(network_definition)
Mixed precision on TensorRT can generate scores very far from PyTorch because some operators saturate. Indeed, FP16 can't store very large or very small numbers the way FP32 can. Here, we search for some patterns of operators to keep in FP32; in most cases, this is enough to fix the inference without hurting performance.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
network_definition | INetworkDefinition | graph generated by TensorRT after parsing the ONNX file (during model building) | required

Returns:

Type | Description
--- | ---
INetworkDefinition | patched network definition
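Example (fix_fp16_network is the default value of build_engine's fp16_fix argument, so it is applied automatically when fp16=True; passing it explicitly, as sketched below with placeholder paths and shapes, is equivalent):
import tensorrt as trt

from transformer_deploy.backends.trt_utils import build_engine, fix_fp16_network

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
engine = build_engine(
    runtime=runtime,
    onnx_file_path="model.onnx",  # placeholder path
    logger=trt_logger,
    workspace_size=12 * 1024 ** 3,
    fp16=True,
    int8=False,
    min_shape=[1, 128],
    optimal_shape=[8, 128],
    max_shape=[32, 128],
    fp16_fix=fix_fp16_network,    # same behavior as leaving the argument to its default
)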
Source code in src/transformer_deploy/backends/trt_utils.py
def fix_fp16_network(network_definition: INetworkDefinition) -> INetworkDefinition:
    """
    Mixed precision on TensorRT can generate scores very far from Pytorch because of some operator being saturated.
    Indeed, FP16 can't store very large and very small numbers like FP32.
    Here, we search for some patterns of operators to keep in FP32, in most cases, it is enough to fix the inference
    and don't hurt performances.
    :param network_definition: graph generated by TensorRT after parsing ONNX file (during the model building)
    :return: patched network definition
    """
    # search for patterns which may overflow in FP16 precision, we force FP32 precisions for those nodes
    for layer_index in range(network_definition.num_layers - 1):
        layer: ILayer = network_definition.get_layer(layer_index)
        next_layer: ILayer = network_definition.get_layer(layer_index + 1)
        # POW operation usually followed by mean reduce
        if layer.type == trt.LayerType.ELEMENTWISE and next_layer.type == trt.LayerType.REDUCE:
            # casting to get access to op attribute
            layer.__class__ = IElementWiseLayer
            next_layer.__class__ = IReduceLayer
            if layer.op == trt.ElementWiseOperation.POW:
                layer.precision = trt.DataType.FLOAT
                next_layer.precision = trt.DataType.FLOAT
                layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
                next_layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
    return network_definition
get_binding_idxs(engine, profile_index)
Calculate start/end binding indices for the current context's profile. See https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles_bindings
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
engine | ICudaEngine | TensorRT engine generated during the model building | required
profile_index | int | profile to use (several profiles can be set during building) | required

Returns:

Type | Description
--- | ---
Tuple[List[int], List[int]] | input and output tensor indexes
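Example (a sketch assuming a serialized engine at a placeholder path and a single optimization profile):
import tensorrt as trt

from transformer_deploy.backends.trt_utils import get_binding_idxs

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
with open("model.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())
# with a single optimization profile, profile 0 owns all the bindings
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine=engine, profile_index=0)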
Source code in src/transformer_deploy/backends/trt_utils.py
def get_binding_idxs(engine: trt.ICudaEngine, profile_index: int):
    """
    Calculate start/end binding indices for current context's profile
    https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles_bindings
    :param engine: TensorRT engine generated during the model building
    :param profile_index: profile to use (several profiles can be set during building)
    :return: input and output tensor indexes
    """
    num_bindings_per_profile = engine.num_bindings // engine.num_optimization_profiles
    start_binding = profile_index * num_bindings_per_profile
    end_binding = start_binding + num_bindings_per_profile
    # Separate input and output binding indices for convenience
    input_binding_idxs: List[int] = []
    output_binding_idxs: List[int] = []
    for binding_index in range(start_binding, end_binding):
        if engine.binding_is_input(binding_index):
            input_binding_idxs.append(binding_index)
        else:
            output_binding_idxs.append(binding_index)
    return input_binding_idxs, output_binding_idxs
get_fix_fp16_network_func(keep_fp32)
Generate a function for the TensorRT engine that sets the precision of specific nodes to FP32, to keep the TensorRT FP16 output close to the FP32 one.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
keep_fp32 | List[str] | nodes to keep in FP32 | required

Returns:

Type | Description
--- | ---
Callable[[tensorrt.tensorrt.INetworkDefinition], tensorrt.tensorrt.INetworkDefinition] | a function to set node precisions
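Example (a sketch with placeholder layer names; the real names depend on the network TensorRT builds from your ONNX file):
from transformer_deploy.backends.trt_utils import get_fix_fp16_network_func

# force FP32 precision on an explicit list of layers (names below are placeholders)
fp16_fix = get_fix_fp16_network_func(keep_fp32=["layer_norm_1", "layer_norm_2"])
# pass the generated function to build_engine through its fp16_fix argument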
Source code in src/transformer_deploy/backends/trt_utils.py
def get_fix_fp16_network_func(keep_fp32: List[str]) -> Callable[[INetworkDefinition], INetworkDefinition]:
    """
    Generate a function for TensorRT engine to set precision of specific nodes to FP32 to keep tensorrt FP16 output
    close to FP32 nodes.
    :param keep_fp32: nodes to keep in FP32
    :return: a function to set node precisions
    """

    def f(network_definition: INetworkDefinition) -> INetworkDefinition:
        for layer_index in range(network_definition.num_layers):
            layer: ILayer = network_definition.get_layer(layer_index)
            # identity function is mainly used for casting
            # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#iidentitylayer
            # if layer.type == LayerType.IDENTITY:
            #     continue
            if layer.name in keep_fp32:
                layer.precision = trt.DataType.FLOAT
                assert layer.num_outputs == 1, f"unexpected # output: {layer.num_outputs}"
                layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
        return network_definition

    return f
get_output_tensors(context, host_inputs, input_binding_idxs, output_binding_idxs)
Reserve memory on the GPU for input and output tensors.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
context | IExecutionContext | TensorRT context shared across inference steps | required
host_inputs | List[torch.Tensor] | input tensors | required
input_binding_idxs | List[int] | indexes of each input vector (should be the same as during building) | required
output_binding_idxs | List[int] | indexes of each output vector (should be the same as during building) | required

Returns:

Type | Description
--- | ---
List[torch.Tensor] | tensors where the output will be stored
Source code in src/transformer_deploy/backends/trt_utils.py
def get_output_tensors(
    context: trt.IExecutionContext,
    host_inputs: List[torch.Tensor],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
) -> List[torch.Tensor]:
    """
    Reserve memory in GPU for input and output tensors.
    :param context: TensorRT context shared accross inference steps
    :param host_inputs: input tensor
    :param input_binding_idxs: indexes of each input vector (should be the same than during building)
    :param output_binding_idxs: indexes of each output vector (should be the same than during building)
    :return: tensors where output will be stored
    """
    # explicitly set dynamic input shapes, so dynamic output shapes can be computed internally
    for host_input, binding_index in zip(host_inputs, input_binding_idxs):
        context.set_binding_shape(binding_index, tuple(host_input.shape))
    # assert context.all_binding_shapes_specified
    device_outputs: List[torch.Tensor] = []
    for binding_index in output_binding_idxs:
        # TensorRT computes output shape based on input shape provided above
        output_shape = context.get_binding_shape(binding_index)
        # allocate buffers to hold output results
        output = torch.empty(tuple(output_shape), device="cuda")
        device_outputs.append(output)
    return device_outputs
infer_tensorrt(context, host_inputs, input_binding_idxs, output_binding_idxs)
Perform inference with TensorRT.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
context | IExecutionContext | shared TensorRT execution context | required
host_inputs | Dict[str, torch.Tensor] | input tensors, keyed by input name | required
input_binding_idxs | List[int] | input tensor indexes | required
output_binding_idxs | List[int] | output tensor indexes | required

Returns:

Type | Description
--- | ---
List[torch.Tensor] | output tensors
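Example (a sketch assuming a serialized engine at a placeholder path and placeholder input names/shapes):
import tensorrt as trt
import torch

from transformer_deploy.backends.trt_utils import get_binding_idxs, infer_tensorrt

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
with open("model.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
context.set_optimization_profile_async(profile_index=0, stream_handle=torch.cuda.current_stream().cuda_stream)
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, 0)
# input names and shapes are placeholders, they depend on the exported model
inputs = {
    "input_ids": torch.ones((4, 128), dtype=torch.int32, device="cuda"),
    "attention_mask": torch.ones((4, 128), dtype=torch.int32, device="cuda"),
}
outputs = infer_tensorrt(
    context=context,
    host_inputs=inputs,
    input_binding_idxs=input_binding_idxs,
    output_binding_idxs=output_binding_idxs,
)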
Source code in src/transformer_deploy/backends/trt_utils.py
def infer_tensorrt(
    context: IExecutionContext,
    host_inputs: Dict[str, torch.Tensor],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
) -> List[torch.Tensor]:
    """
    Perform inference with TensorRT.
    :param context: shared variable
    :param host_inputs: input tensor
    :param input_binding_idxs: input tensor indexes
    :param output_binding_idxs: output tensor indexes
    :return: output tensor
    """
    input_tensors: List[torch.Tensor] = list()
    for i in range(context.engine.num_bindings):
        if not context.engine.binding_is_input(index=i):
            continue
        tensor_name = context.engine.get_binding_name(i)
        assert tensor_name in host_inputs, f"missing input: {tensor_name}"
        tensor = host_inputs[tensor_name]
        assert isinstance(tensor, torch.Tensor), f"unexpected tensor class: {type(tensor)}"
        # warning: small changes in output if int64 is used instead of int32
        if tensor.dtype in [torch.int64, torch.long]:
            tensor = tensor.type(torch.int32)
        if tensor.device.type != "cuda":
            tensor = tensor.to("cuda")
        input_tensors.append(tensor)
    # calculate input shape, bind it, allocate GPU memory for the output
    output_tensors: List[torch.Tensor] = get_output_tensors(
        context, input_tensors, input_binding_idxs, output_binding_idxs
    )
    bindings = [int(i.data_ptr()) for i in input_tensors + output_tensors]
    assert context.execute_async_v2(
        bindings, torch.cuda.current_stream().cuda_stream
    ), "failure during execution of inference"
    torch.cuda.current_stream().synchronize()  # sync all CUDA ops
    return output_tensors
load_engine(runtime, engine_file_path, profile_index=0)
Load serialized TensorRT engine.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
runtime | Runtime | shared TensorRT runtime | required
engine_file_path | str | path to the serialized engine | required
profile_index | int | which profile to load, 0 if you have not used multiple profiles | 0

Returns:

Type | Description
--- | ---
Callable[[Dict[str, torch.Tensor]], torch.Tensor] | a function to perform inference
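Example (a minimal end-to-end sketch with a placeholder engine path and placeholder input names/shapes):
import tensorrt as trt
import torch

from transformer_deploy.backends.trt_utils import load_engine

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
tensorrt_model = load_engine(runtime=runtime, engine_file_path="model.plan", profile_index=0)
# the returned callable takes named tensors and returns the first output tensor
output = tensorrt_model({
    "input_ids": torch.ones((4, 128), dtype=torch.int32, device="cuda"),
    "attention_mask": torch.ones((4, 128), dtype=torch.int32, device="cuda"),
})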
Source code in src/transformer_deploy/backends/trt_utils.py
def load_engine(
    runtime: Runtime, engine_file_path: str, profile_index: int = 0
) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]:
    """
    Load serialized TensorRT engine.
    :param runtime: shared variable
    :param engine_file_path: path to the serialized engine
    :param profile_index: which profile to load, 0 if you have not used multiple profiles
    :return: A function to perform inference
    """
    with open(file=engine_file_path, mode="rb") as f:
        engine: ICudaEngine = runtime.deserialize_cuda_engine(f.read())
        stream: int = torch.cuda.current_stream().cuda_stream
        context: IExecutionContext = engine.create_execution_context()
        context.set_optimization_profile_async(profile_index=profile_index, stream_handle=stream)
        # retrieve input/output IDs
        input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index)  # type: List[int], List[int]

        def tensorrt_model(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
            return infer_tensorrt(
                context=context,
                host_inputs=inputs,
                input_binding_idxs=input_binding_idxs,
                output_binding_idxs=output_binding_idxs,
            )[0]

        return tensorrt_model
save_engine(engine, engine_file_path)
Serialize TensorRT engine to file.
Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
engine | ICudaEngine | TensorRT engine | required
engine_file_path | str | output path | required
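Example (a sketch that deserializes an existing engine and serializes it back; both paths are placeholders, and in practice the engine usually comes from build_engine as shown earlier):
import tensorrt as trt

from transformer_deploy.backends.trt_utils import save_engine

trt_logger = trt.Logger(trt.Logger.ERROR)
runtime = trt.Runtime(trt_logger)
with open("model.plan", "rb") as f:  # placeholder path
    engine = runtime.deserialize_cuda_engine(f.read())
save_engine(engine=engine, engine_file_path="model_copy.plan")  # placeholder path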