Skip to content

Trt utils

All the tooling to ease TensorRT usage.

TensorRTShape dataclass #

Store input shapes for TensorRT build. 3 shapes per input tensor are required (as tuples of integers):

  • minimum input shape
  • optimal size used for the benchmarks during building
  • maximum input shape

Set input name to None for default shape.

Source code in src/transformer_deploy/backends/trt_utils.py
@dataclass
class TensorRTShape:
    """
    Store input shapes for TensorRT build.
    3 shapes per input tensor are required (as sequences of integers):

    * minimum input shape
    * optimal size used for the benchmarks during building
    * maximum input shape

    The first dimension of each shape is the batch dimension.
    Set input name to None for default shape (it is then duplicated for every network input).
    """

    # minimum shape accepted by the optimization profile
    min_shape: List[int]
    # shape used for the benchmarks during building
    optimal_shape: List[int]
    # maximum shape accepted by the optimization profile
    max_shape: List[int]
    # name of the network input these shapes apply to, None for a default shape
    input_name: Optional[str]

    def check_validity(self) -> None:
        """
        Basic checks of provided shapes: the 3 shapes have the same (non-zero) rank,
        a strictly positive first (batch) dimension, and an input name is set.
        :raises AssertionError: if one of the checks fails
        """
        assert len(self.min_shape) == len(self.optimal_shape) == len(self.max_shape)
        assert len(self.min_shape) > 0
        assert self.min_shape[0] > 0 and self.optimal_shape[0] > 0 and self.max_shape[0] > 0
        assert self.input_name is not None

    def make_copy(self, input_name: str) -> "TensorRTShape":
        """
        Make a copy of the current instance, with a different input name.
        Shape sequences are copied as well, so mutating a copy never alters the original or sibling copies.
        :param input_name: new input name to use
        :return: a copy of the current shape with a different name
        """
        # slicing copies lists (tuples are immutable and come back as-is), keeping the caller's sequence type
        return dataclasses.replace(
            self,
            min_shape=self.min_shape[:],
            optimal_shape=self.optimal_shape[:],
            max_shape=self.max_shape[:],
            input_name=input_name,
        )

    def generate_multiple_shapes(self, input_names: List[str]) -> List["TensorRTShape"]:
        """
        Generate multiple shapes when only a single default one is defined.
        :param input_names: input names used by the model
        :return: a list of shapes, one independent copy per input name (same order as input_names)
        :raises AssertionError: if this instance is not a default shape (input_name is set)
        """
        assert self.input_name is None, f"input name is not None: {self.input_name}"
        return [self.make_copy(input_name=name) for name in input_names]

check_validity(self) #

Basic checks of provided shapes

Source code in src/transformer_deploy/backends/trt_utils.py
def check_validity(self) -> None:
    """
    Sanity-check the stored shapes.

    Checks that the min/optimal/max shapes share the same non-zero rank, that their first (batch)
    dimension is strictly positive, and that an input name has been set.
    :raises AssertionError: when any of these conditions does not hold
    """
    rank = len(self.min_shape)
    assert rank == len(self.optimal_shape) == len(self.max_shape)
    assert rank > 0
    assert all(bound[0] > 0 for bound in (self.min_shape, self.optimal_shape, self.max_shape))
    assert self.input_name is not None

generate_multiple_shapes(self, input_names) #

Generate multiple shapes when only a single default one is defined.

Parameters:

Name Type Description Default
input_names List[str]

input names used by the model

required

Returns:

Type Description
List[TensorRTShape]

a list of shapes

Source code in src/transformer_deploy/backends/trt_utils.py
def generate_multiple_shapes(self, input_names: List[str]) -> List["TensorRTShape"]:
    """
    Expand a single default shape (one without input name) into one named shape per model input.

    :param input_names: input names used by the model
    :return: one copy of this shape per name, in the order of input_names
    :raises AssertionError: if this shape already has an input name
    """
    assert self.input_name is None, f"input name is not None: {self.input_name}"
    return [self.make_copy(input_name=name) for name in input_names]

make_copy(self, input_name) #

Make a copy of the current instance, with a different input name.

Parameters:

Name Type Description Default
input_name str

new input name to use

required

Returns:

Type Description
TensorRTShape

a copy of the current shape with a different name

Source code in src/transformer_deploy/backends/trt_utils.py
def make_copy(self, input_name: str) -> "TensorRTShape":
    """
    Make a copy of the current instance, with a different input name.
    Shape sequences are copied as well, so mutating a copy never alters the original or sibling copies.
    :param input_name: new input name to use
    :return: a copy of the current shape with a different name
    """
    # slicing copies lists (tuples are immutable and come back as-is), keeping the caller's sequence type
    return dataclasses.replace(
        self,
        min_shape=self.min_shape[:],
        optimal_shape=self.optimal_shape[:],
        max_shape=self.max_shape[:],
        input_name=input_name,
    )

build_engine(runtime, onnx_file_path, logger, workspace_size, fp16, int8, fp16_fix=<function fix_fp16_network at 0x7f6f4129a0d0>, **kwargs) #

Convert ONNX file to TensorRT engine. It supports dynamic shape, however it's advised to keep the sequence length fixed, as a dynamic one hurts performance. Dynamic batch size doesn't hurt performance and is highly advised. Shapes (including the batch size) can be provided in two different ways:

  • min_shape, optimal_shape, max_shape: for the simple case, 3 tuples of int when all input tensors have the same shape
  • input_shapes: a list of TensorRTShape with names if there are several input tensors with different shapes

TIP: minimum batch size should be 1 in most cases.

Parameters:

Name Type Description Default
runtime Runtime

global variable shared across inference calls / model building

required
onnx_file_path str

path to the ONNX file

required
logger Logger

specific logger to TensorRT

required
workspace_size int

GPU memory to use during the building, more is always better. If there is not enough memory, some optimization may fail, and the whole conversion process will crash.

required
fp16 bool

enable FP16 precision, it usually provides a 20-30% boost compared to ONNX Runtime.

required
int8 bool

enable INT-8 quantization, best performance but model should have been quantized.

required
fp16_fix Callable[[tensorrt.tensorrt.INetworkDefinition], tensorrt.tensorrt.INetworkDefinition]

a function to set FP32 precision on some nodes to fix FP16 overflow

<function fix_fp16_network at 0x7f6f4129a0d0>

Returns:

Type Description
ICudaEngine

TensorRT engine to use during inference

Source code in src/transformer_deploy/backends/trt_utils.py
def build_engine(
    runtime: Runtime,
    onnx_file_path: str,
    logger: Logger,
    workspace_size: int,
    fp16: bool,
    int8: bool,
    fp16_fix: Callable[[INetworkDefinition], INetworkDefinition] = fix_fp16_network,
    **kwargs,
) -> ICudaEngine:
    """
    Convert ONNX file to TensorRT engine.
    It supports dynamic shape, however it's advised to keep sequence length fixed as it hurts performance otherwise.
    Dynamic batch size doesn't hurt performance and is highly advised.
    Shapes can be provided in two different ways:

    * **min_shape**, **optimal_shape**, **max_shape**: for the simple case, 3 tuples of int when all
    input tensors have the same shape
    * **input_shapes**: a list of TensorRTShape with names if there are several input tensors with different shapes

    Optionally, **shape_tensors** (a list of TensorRTShape) sets the value ranges of shape input tensors.

    **TIP**: minimum batch size should be 1 in most cases.

    :param runtime: global variable shared across inference calls / model building
    :param onnx_file_path: path to the ONNX file
    :param logger: specific logger to TensorRT
    :param workspace_size: GPU memory to use during the building, more is always better.
        If there is not enough memory, some optimization may fail, and the whole conversion process will crash.
    :param fp16: enable FP16 precision, it usually provides a 20-30% boost compared to ONNX Runtime.
    :param int8: enable INT-8 quantization, best performance but model should have been quantized.
    :param fp16_fix: a function to set FP32 precision on some nodes to fix FP16 overflow
    :return: TensorRT engine to use during inference
    :raises RuntimeError: if the ONNX file can't be parsed or the engine can't be built
    """
    # default input shape
    if "min_shape" in kwargs and "optimal_shape" in kwargs and "max_shape" in kwargs:
        default_shape = TensorRTShape(
            min_shape=kwargs["min_shape"],
            optimal_shape=kwargs["optimal_shape"],
            max_shape=kwargs["max_shape"],
            input_name=None,
        )
        input_shapes = [default_shape]
    else:
        assert "input_shapes" in kwargs, "missing input shapes"
        input_shapes: List[TensorRTShape] = kwargs["input_shapes"]

    with trt.Builder(logger) as builder:  # type: Builder
        with builder.create_network(
            flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        ) as network_def:  # type: INetworkDefinition
            with trt.OnnxParser(network_def, logger) as parser:  # type: OnnxParser
                # The maximum batch size which can be used at execution time,
                # and also the batch size for which the ICudaEngine will be optimized.
                builder.max_batch_size = max([s.max_shape[0] for s in input_shapes])
                config: IBuilderConfig = builder.create_builder_config()
                config.max_workspace_size = workspace_size
                # to enable complete trt inspector debugging, only for TensorRT >= 8.2
                # config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
                # disable CUDNN optimizations
                config.set_tactic_sources(
                    tactic_sources=1 << int(trt.TacticSource.CUBLAS) | 1 << int(trt.TacticSource.CUBLAS_LT)
                )
                if int8:
                    config.set_flag(trt.BuilderFlag.INT8)
                if fp16:
                    config.set_flag(trt.BuilderFlag.FP16)
                config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
                # https://github.com/NVIDIA/TensorRT/issues/1196 (sometimes big diff in output when using FP16)
                config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
                with open(onnx_file_path, "rb") as f:
                    onnx_model = f.read()
                # parse() returns False on failure: stop here with the parser's own messages instead of
                # letting the build fail later on a partially populated network
                if not parser.parse(onnx_model):
                    errors = [str(parser.get_error(error_index)) for error_index in range(parser.num_errors)]
                    raise RuntimeError(f"failed to parse ONNX file {onnx_file_path}: " + "; ".join(errors))
                profile: IOptimizationProfile = builder.create_optimization_profile()
                # duplicate default shape (one for each input)
                if len(input_shapes) == 1 and input_shapes[0].input_name is None:
                    names = [network_def.get_input(num_input).name for num_input in range(network_def.num_inputs)]
                    input_shapes = input_shapes[0].generate_multiple_shapes(input_names=names)

                for shape in input_shapes:
                    shape.check_validity()
                    profile.set_shape(
                        input=shape.input_name,
                        min=shape.min_shape,
                        opt=shape.optimal_shape,
                        max=shape.max_shape,
                    )
                if "shape_tensors" in kwargs:
                    for shape in kwargs["shape_tensors"]:
                        profile.set_shape_input(
                            input=shape.input_name,
                            min=shape.min_shape,
                            opt=shape.optimal_shape,
                            max=shape.max_shape,
                        )
                config.add_optimization_profile(profile)
                if fp16:
                    network_def = fp16_fix(network_def)
                trt_engine = builder.build_serialized_network(network_def, config)
                # build failures are reported through the TensorRT logger and a None result
                if trt_engine is None:
                    raise RuntimeError("error during engine generation, check error messages above :-(")
                engine: ICudaEngine = runtime.deserialize_cuda_engine(trt_engine)
                assert engine is not None, "error during engine generation, check error messages above :-("
                return engine

fix_fp16_network(network_definition) #

Mixed precision on TensorRT can generate scores very far from PyTorch because some operators saturate. Indeed, FP16 can't store very large and very small numbers like FP32 can. Here, we search for some patterns of operators to keep in FP32; in most cases, this is enough to fix the inference and doesn't hurt performance.

Parameters:

Name Type Description Default
network_definition INetworkDefinition

graph generated by TensorRT after parsing ONNX file (during the model building)

required

Returns:

Type Description
INetworkDefinition

patched network definition

Source code in src/transformer_deploy/backends/trt_utils.py
def fix_fp16_network(network_definition: INetworkDefinition) -> INetworkDefinition:
    """
    Mixed precision on TensorRT can generate scores very far from PyTorch because some operators saturate.
    Indeed, FP16 can't store very large and very small numbers like FP32 can.
    Here, we search for some patterns of operators to keep in FP32; in most cases, this is enough to fix the
    inference and doesn't hurt performance.

    Only pairs of layers that are adjacent by layer index are inspected: an ELEMENTWISE layer immediately
    followed by a REDUCE layer. The network definition is mutated in place and the same object is returned.

    :param network_definition: graph generated by TensorRT after parsing ONNX file (during the model building)
    :return: patched network definition
    """
    # search for patterns which may overflow in FP16 precision, we force FP32 precisions for those nodes
    # (num_layers - 1 iterations so that next_layer always exists; an empty network yields no iteration)
    for layer_index in range(network_definition.num_layers - 1):
        layer: ILayer = network_definition.get_layer(layer_index)
        next_layer: ILayer = network_definition.get_layer(layer_index + 1)
        # POW operation usually followed by mean reduce
        # NOTE(review): adjacency by index is used as a proxy for "feeds into"; presumably the parser
        # emits layers in topological order — confirm the two layers are actually connected.
        if layer.type == trt.LayerType.ELEMENTWISE and next_layer.type == trt.LayerType.REDUCE:
            # casting to get access to op attribute (get_layer returns the generic ILayer view)
            layer.__class__ = IElementWiseLayer
            next_layer.__class__ = IReduceLayer
            if layer.op == trt.ElementWiseOperation.POW:
                layer.precision = trt.DataType.FLOAT
                next_layer.precision = trt.DataType.FLOAT
            # NOTE(review): these two calls are outside the POW check, so the output type is forced to FLOAT
            # for EVERY ElementWise -> Reduce pair, not only POW -> Reduce — confirm this is intended.
            layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
            next_layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)
    return network_definition

get_binding_idxs(engine, profile_index) #

Calculate start/end binding indices for current context's profile https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles_bindings

Parameters:

Name Type Description Default
engine ICudaEngine

TensorRT engine generated during the model building

required
profile_index int

profile to use (several profiles can be set during building)

required

Returns:

Type Description

input and output tensor indexes

Source code in src/transformer_deploy/backends/trt_utils.py
def get_binding_idxs(engine: trt.ICudaEngine, profile_index: int):
    """
    Split the binding indices owned by one optimization profile into inputs and outputs.

    Bindings are laid out profile after profile, each profile owning
    ``num_bindings // num_optimization_profiles`` consecutive indices, see
    https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#opt_profiles_bindings
    :param engine: TensorRT engine generated during the model building
    :param profile_index: profile to use (several profiles can be set during building)
    :return: a tuple (input binding indexes, output binding indexes), each in increasing order
    """
    per_profile = engine.num_bindings // engine.num_optimization_profiles
    first_index = profile_index * per_profile
    inputs: List[int] = []
    outputs: List[int] = []
    for index in range(first_index, first_index + per_profile):
        target = inputs if engine.binding_is_input(index) else outputs
        target.append(index)
    return inputs, outputs

get_fix_fp16_network_func(keep_fp32) #

Generate a function for the TensorRT engine build that sets the precision of specific nodes to FP32, to keep the TensorRT FP16 output close to the FP32 output.

Parameters:

Name Type Description Default
keep_fp32 List[str]

nodes to keep in FP32

required

Returns:

Type Description
Callable[[tensorrt.tensorrt.INetworkDefinition], tensorrt.tensorrt.INetworkDefinition]

a function to set node precisions

Source code in src/transformer_deploy/backends/trt_utils.py
def get_fix_fp16_network_func(keep_fp32: List[str]) -> Callable[[INetworkDefinition], INetworkDefinition]:
    """
    Generate a function for the TensorRT engine build that sets the precision of specific nodes to FP32,
    to keep the TensorRT FP16 output close to the FP32 output.
    :param keep_fp32: names of the nodes (TensorRT layers) to keep in FP32; the names are copied when the
        function is generated
    :return: a function to set node precisions, which patches the network definition in place and returns it
    """
    # built once: membership is tested for every layer of the network
    keep_fp32_names = frozenset(keep_fp32)

    def f(network_definition: INetworkDefinition) -> INetworkDefinition:
        for layer_index in range(network_definition.num_layers):
            layer: ILayer = network_definition.get_layer(layer_index)
            if layer.name not in keep_fp32_names:
                continue
            layer.precision = trt.DataType.FLOAT
            # only the first output type is forced below, so layers with several outputs are rejected
            assert layer.num_outputs == 1, f"unexpected # output: {layer.num_outputs}"
            layer.set_output_type(index=0, dtype=trt.DataType.FLOAT)

        return network_definition

    return f

get_output_tensors(context, host_inputs, input_binding_idxs, output_binding_idxs) #

Set the input shapes on the context and reserve GPU memory for the output tensors.

Parameters:

Name Type Description Default
context IExecutionContext

TensorRT context shared across inference steps

required
host_inputs List[torch.Tensor]

input tensor

required
input_binding_idxs List[int]

indexes of each input vector (should be the same as during building)

required
output_binding_idxs List[int]

indexes of each output vector (should be the same as during building)

required

Returns:

Type Description
List[torch.Tensor]

tensors where output will be stored

Source code in src/transformer_deploy/backends/trt_utils.py
def get_output_tensors(
    context: trt.IExecutionContext,
    host_inputs: List[torch.Tensor],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
) -> List[torch.Tensor]:
    """
    Bind input shapes and reserve GPU memory for the output tensors.

    Input shapes are set on the context first so TensorRT can compute the (possibly dynamic) output shapes,
    then one CUDA buffer is allocated per output binding, with the dtype TensorRT will write into it.
    :param context: TensorRT context shared across inference steps
    :param host_inputs: input tensors, in the same order as input_binding_idxs
    :param input_binding_idxs: indexes of each input vector (should be the same as during building)
    :param output_binding_idxs: indexes of each output vector (should be the same as during building)
    :return: tensors where output will be stored, in output_binding_idxs order
    :raises ValueError: if an output binding has a dtype with no torch equivalent in the mapping below
    """
    # TensorRT writes raw bytes into each output buffer: the torch dtype must match the binding dtype
    trt_to_torch_dtype = {
        trt.DataType.FLOAT: torch.float32,
        trt.DataType.HALF: torch.float16,
        trt.DataType.INT8: torch.int8,
        trt.DataType.INT32: torch.int32,
        trt.DataType.BOOL: torch.bool,
    }
    # explicitly set dynamic input shapes, so dynamic output shapes can be computed internally
    for host_input, binding_index in zip(host_inputs, input_binding_idxs):
        context.set_binding_shape(binding_index, tuple(host_input.shape))
    device_outputs: List[torch.Tensor] = []
    for binding_index in output_binding_idxs:
        # TensorRT computes output shape based on input shape provided above
        output_shape = context.get_binding_shape(binding_index)
        binding_dtype = context.engine.get_binding_dtype(binding_index)
        if binding_dtype not in trt_to_torch_dtype:
            raise ValueError(f"unsupported output dtype for binding {binding_index}: {binding_dtype}")
        # allocate buffers to hold output results
        output = torch.empty(tuple(output_shape), dtype=trt_to_torch_dtype[binding_dtype], device="cuda")
        device_outputs.append(output)
    return device_outputs

infer_tensorrt(context, host_inputs, input_binding_idxs, output_binding_idxs) #

Perform inference with TensorRT.

Parameters:

Name Type Description Default
context IExecutionContext

TensorRT execution context, shared across inference calls

required
host_inputs Dict[str, torch.Tensor]

input tensor

required
input_binding_idxs List[int]

input tensor indexes

required
output_binding_idxs List[int]

output tensor indexes

required

Returns:

Type Description
List[torch.Tensor]

output tensor

Source code in src/transformer_deploy/backends/trt_utils.py
def infer_tensorrt(
    context: IExecutionContext,
    host_inputs: Dict[str, torch.Tensor],
    input_binding_idxs: List[int],
    output_binding_idxs: List[int],
) -> List[torch.Tensor]:
    """
    Perform inference with TensorRT.

    Each input binding in input_binding_idxs is filled from host_inputs (looked up by input name):
    int64 tensors are cast to int32, and tensors are moved to CUDA and made contiguous.
    Output buffers are allocated by get_output_tensors. Inference is enqueued on the current PyTorch CUDA
    stream, which is synchronized before returning.
    :param context: TensorRT execution context shared across inference calls
    :param host_inputs: input tensors, keyed by input name
    :param input_binding_idxs: input tensor indexes (bindings of the profile selected on the context)
    :param output_binding_idxs: output tensor indexes (bindings of the profile selected on the context)
    :return: output tensors, in output_binding_idxs order
    :raises RuntimeError: if TensorRT reports a failure while enqueuing the inference
    """
    engine: ICudaEngine = context.engine
    # bindings of profile k > 0 are named "<name> [profile k]": read the name of the matching profile-0
    # binding so host_inputs keeps using plain input names whatever the selected profile
    num_bindings_per_profile = engine.num_bindings // engine.num_optimization_profiles
    input_tensors: List[torch.Tensor] = list()
    for binding_index in input_binding_idxs:
        tensor_name = engine.get_binding_name(binding_index % num_bindings_per_profile)
        assert tensor_name in host_inputs, f"missing input: {tensor_name}"
        tensor = host_inputs[tensor_name]
        assert isinstance(tensor, torch.Tensor), f"unexpected tensor class: {type(tensor)}"
        # warning: small changes in output if int64 is used instead of int32
        if tensor.dtype == torch.int64:
            tensor = tensor.type(torch.int32)
        if tensor.device.type != "cuda":
            tensor = tensor.to("cuda")
        # TensorRT reads the raw memory behind data_ptr(), which must be densely packed
        input_tensors.append(tensor.contiguous())
    # calculate input shape, bind it, allocate GPU memory for the output
    output_tensors: List[torch.Tensor] = get_output_tensors(
        context, input_tensors, input_binding_idxs, output_binding_idxs
    )
    # one device pointer per engine binding, stored at its binding index (other profiles' bindings stay null)
    bindings = [0] * engine.num_bindings
    for binding_index, tensor in zip(input_binding_idxs, input_tensors):
        bindings[binding_index] = int(tensor.data_ptr())
    for binding_index, tensor in zip(output_binding_idxs, output_tensors):
        bindings[binding_index] = int(tensor.data_ptr())
    # called outside of an assert so that inference still runs under `python -O`
    succeeded = context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    if not succeeded:
        raise RuntimeError("failure during execution of inference")
    torch.cuda.current_stream().synchronize()  # sync all CUDA ops
    return output_tensors

load_engine(runtime, engine_file_path, profile_index=0) #

Load serialized TensorRT engine.

Parameters:

Name Type Description Default
runtime Runtime

shared TensorRT runtime, used to deserialize the engine

required
engine_file_path str

path to the serialized engine

required
profile_index int

which profile to load, 0 if you have not used multiple profiles

0

Returns:

Type Description
Callable[[Dict[str, torch.Tensor]], torch.Tensor]

A function to perform inference

Source code in src/transformer_deploy/backends/trt_utils.py
def load_engine(
    runtime: Runtime, engine_file_path: str, profile_index: int = 0
) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]:
    """
    Load a serialized TensorRT engine and wrap it into an inference function.

    A single execution context is created, set to the requested optimization profile, and reused by
    every call of the returned function.
    :param runtime: TensorRT runtime used to deserialize the engine
    :param engine_file_path: path to the serialized engine
    :param profile_index: which profile to load, 0 if you have not used multiple profiles
    :return: a function taking a dict of named input tensors and returning the first output tensor
    """
    with open(file=engine_file_path, mode="rb") as engine_file:
        serialized_engine = engine_file.read()
    engine: ICudaEngine = runtime.deserialize_cuda_engine(serialized_engine)
    cuda_stream: int = torch.cuda.current_stream().cuda_stream
    context: IExecutionContext = engine.create_execution_context()
    context.set_optimization_profile_async(profile_index=profile_index, stream_handle=cuda_stream)
    # input/output binding indexes belonging to the selected profile
    input_idxs, output_idxs = get_binding_idxs(engine, profile_index)  # type: List[int], List[int]

    def tensorrt_model(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        outputs = infer_tensorrt(
            context=context,
            host_inputs=inputs,
            input_binding_idxs=input_idxs,
            output_binding_idxs=output_idxs,
        )
        return outputs[0]

    return tensorrt_model

save_engine(engine, engine_file_path) #

Serialize TensorRT engine to file.

Parameters:

Name Type Description Default
engine ICudaEngine

TensorRT engine

required
engine_file_path str

output path

required
Source code in src/transformer_deploy/backends/trt_utils.py
def save_engine(engine: ICudaEngine, engine_file_path: str) -> None:
    """
    Serialize TensorRT engine to file.

    The engine is serialized before the output file is opened, so a failed serialization leaves any
    existing file at engine_file_path untouched instead of truncating it.
    :param engine: TensorRT engine
    :param engine_file_path: output path
    :raises RuntimeError: if the engine serialization returns nothing
    """
    serialized_engine = engine.serialize()
    if serialized_engine is None:
        raise RuntimeError(f"failed to serialize TensorRT engine to {engine_file_path}")
    with open(engine_file_path, "wb") as f:
        f.write(serialized_engine)