Ort utils

All the tooling to ease ONNX Runtime usage.

add_output_nodes(model) #

Set each node as an output node, for debugging purposes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | ModelProto | ONNX model in protobuf format | required |

Returns:

| Type | Description |
| --- | --- |
| ModelProto | modified ONNX model |

Source code in src/transformer_deploy/backends/ort_utils.py
def add_output_nodes(model: ModelProto) -> ModelProto:
    """
    Set each node as output node for debugging purpose.
    :param model: ONNX model in protobuf format
    :return: modified ONNX model
    """
    model = copy.deepcopy(model)
    output_nodes = list()
    for n in model.graph.node:
        for output_name in n.output:
            output_nodes.append(onnx.ValueInfoProto(name=output_name))
    # clear the output array (protobuf way...)
    while model.graph.output:
        model.graph.output.pop()
    model.graph.output.extend(output_nodes)
    return model
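
A minimal usage sketch, assuming a local model.onnx file (the path is a placeholder): expose every intermediate output, then save the debug copy next to the original.

import onnx

from transformer_deploy.backends.ort_utils import add_output_nodes

model = onnx.load_model("model.onnx")  # placeholder path
debug_model = add_output_nodes(model=model)  # every node output is now also a graph output
onnx.save_model(debug_model, "model_all_outputs.onnx")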

convert_fp16(onnx_model, nodes_to_exclude) #

Convert an ONNX model to FP16, while still being able to exclude a list of nodes (kept in FP32).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| onnx_model | str | original FP32 model | required |
| nodes_to_exclude | List[str] | nodes that should stay in FP32 | required |

Returns:

| Type | Description |
| --- | --- |
| ModelProto | mostly FP16 model |

Source code in src/transformer_deploy/backends/ort_utils.py
def convert_fp16(onnx_model: str, nodes_to_exclude: List[str]) -> ModelProto:
    """
    Convert an ONNX model to FP16, while still being able to exclude a list of nodes.
    :param onnx_model: original FP32 model
    :param nodes_to_exclude: nodes that should stay in FP32
    :return: mostly FP16 model
    """
    # add value info related to each node, required for the conversion
    output_path = onnx_model + "_shape_inference.onnx"
    infer_shapes_path(model_path=onnx_model, output_path=output_path)
    model_fp16 = onnx.load_model(output_path)
    model_fp16 = convert_float_to_float16(model=model_fp16, keep_io_types=False, node_block_list=nodes_to_exclude)
    # clean casting nodes before returning the model
    wrapped_fp16_model = OnnxModel(model_fp16)
    fusion_utils = FusionUtils(wrapped_fp16_model)
    fusion_utils.remove_cascaded_cast_nodes()
    fusion_utils.remove_useless_cast_nodes()
    wrapped_fp16_model.topological_sort()
    return wrapped_fp16_model.model
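
Below is a hedged sketch of a typical call (file names are placeholders); the nodes_to_exclude list usually comes from get_keep_fp32_nodes, documented further down.

import onnx

from transformer_deploy.backends.ort_utils import convert_fp16

# hypothetical node names; in practice they are computed by get_keep_fp32_nodes()
nodes_to_exclude = ["MatMul_42", "Softmax_7"]
model_fp16 = convert_fp16(onnx_model="model.onnx", nodes_to_exclude=nodes_to_exclude)
onnx.save_model(model_fp16, "model_fp16.onnx")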

cpu_quantization(input_model_path, output_model_path) #

ONNX CPU-only dynamic quantization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_model_path | str | ONNX graph (float) to quantize | required |
| output_model_path | str | where to save the quantized model | required |
Source code in src/transformer_deploy/backends/ort_utils.py
def cpu_quantization(input_model_path: str, output_model_path: str) -> None:
    """
    ONNX CPU only dynamic quantization.

    :param input_model_path: ONNX graph (float) to quantize
    :param output_model_path: where to save quantized model
    """
    quantize_dynamic(
        model_input=Path(input_model_path),
        model_output=Path(output_model_path),
        op_types_to_quantize=["MatMul", "Attention"],
        weight_type=QuantType.QInt8,
        per_channel=True,
        reduce_range=True,
        extra_options={"WeightSymmetric": False, "MatMulConstBOnly": True},
    )
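
A minimal sketch of a call (paths are placeholders); per the options above, only MatMul and Attention weights are quantized to INT8.

from transformer_deploy.backends.ort_utils import cpu_quantization

cpu_quantization(input_model_path="model.onnx", output_model_path="model_quant.onnx")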

create_model_for_provider(path, provider_to_use, nb_threads=12, nb_instances=0, optimization_level=<GraphOptimizationLevel.ORT_ENABLE_EXTENDED: 2>, enable_profiling=False, log_severity=2) #

Create an ONNX Runtime inference session.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | path to the ONNX file, or the model serialized to a string | required |
| provider_to_use | Union[str, List] | provider(s) to use for inference | required |
| nb_threads | int | intra_op_num_threads to use. You may want to try different values; more cores do not always provide the best performance. | 12 |
| nb_instances | int | inter_op_num_threads to use, to execute multiple subgraphs in parallel when possible | 0 |
| optimization_level | GraphOptimizationLevel | expected level of ONNX Runtime optimization. For GPU and NLP, ORT_ENABLE_EXTENDED is the one providing kernel fusion of element-wise operations; ORT_ENABLE_ALL is intended for CPU inference. See https://onnxruntime.ai/docs/performance/graph-optimizations.html#layout-optimizations | GraphOptimizationLevel.ORT_ENABLE_EXTENDED |
| enable_profiling | bool | have ONNX Runtime log the execution time of each kernel | False |
| log_severity | int | log severity level. 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal | 2 |

Returns:

| Type | Description |
| --- | --- |
| InferenceSession | ONNX Runtime inference session |

Source code in src/transformer_deploy/backends/ort_utils.py
def create_model_for_provider(
    path: str,
    provider_to_use: Union[str, List],
    nb_threads: int = multiprocessing.cpu_count(),
    nb_instances: int = 0,
    optimization_level: GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
    enable_profiling: bool = False,
    log_severity: int = 2,
) -> InferenceSession:
    """
    Create an ONNX Runtime instance.
    :param path: path to ONNX file or serialized to string model
    :param provider_to_use: provider to use for inference
    :param nb_threads: intra_op_num_threads to use. You may want to try different values,
        more cores do not always provide the best performance.
    :param nb_instances: inter_op_num_threads to use, to execute multiple subgraphs in parallel when possible.
    :param optimization_level: expected level of ONNX Runtime optimization. For GPU and NLP, extended is the one
        providing kernel fusion of element wise operations. Enable all level is for CPU inference.
        see https://onnxruntime.ai/docs/performance/graph-optimizations.html#layout-optimizations
    :param enable_profiling: let Onnx Runtime log each kernel time.
    :param log_severity: log severity level. 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal.
    :return: ONNX Runtime inference session
    """
    options = SessionOptions()
    options.graph_optimization_level = optimization_level
    options.enable_profiling = enable_profiling
    options.log_severity_level = log_severity
    if isinstance(provider_to_use, str):
        provider_to_use = [provider_to_use]
    if provider_to_use == ["CPUExecutionProvider"]:
        options.execution_mode = ExecutionMode.ORT_SEQUENTIAL if nb_instances <= 1 else ExecutionMode.ORT_PARALLEL
        options.intra_op_num_threads = nb_threads
        if nb_instances > 1:
            options.inter_op_num_threads = nb_instances
    return InferenceSession(path, options, providers=provider_to_use)
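
A minimal sketch, assuming a local model.onnx file and a machine with a CUDA GPU (both are assumptions); switch the provider to CPUExecutionProvider for CPU inference.

from transformer_deploy.backends.ort_utils import create_model_for_provider

session = create_model_for_provider(
    path="model.onnx",  # placeholder path
    provider_to_use="CUDAExecutionProvider",  # or "CPUExecutionProvider"
)
print([i.name for i in session.get_inputs()])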

find_node_fp32(graph, output_nodes) #

Identify nodes whose outputs contain values outside the FP16 range.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| graph | Dict[str, str] | graph as an adjacency dict (output tensor name -> producing node name) | required |
| output_nodes | Dict[str, torch.Tensor] | output tensor of each node | required |

Returns:

| Type | Description |
| --- | --- |
| List[str] | list of nodes producing outputs outside the FP16 range |

Source code in src/transformer_deploy/backends/ort_utils.py
def find_node_fp32(graph: Dict[str, str], output_nodes: Dict[str, torch.Tensor]) -> List[str]:
    """
    Identify out of range values in node outputs.
    :param graph: graph as an adjacency dict (output tensor name -> producing node name)
    :param output_nodes: output of each node
    :return: list of nodes producing outputs outside fp16 tensor
    """
    keep_fp32 = list()
    min_float16 = torch.finfo(torch.float16).min
    max_float16 = torch.finfo(torch.float16).max
    resolution = 5.96e-08  # smallest positive subnormal value representable in FP16
    for k, tensor in output_nodes.items():
        if tensor.dtype != torch.float32:
            continue
        # out of FP16 range
        if (
            torch.any(tensor > max_float16)
            or torch.any(tensor < min_float16)
            or torch.any((tensor < resolution) & (tensor > -resolution) & (tensor != 0))  # too small for FP16 resolution
        ):
            keep_fp32.append(graph[k])
    return keep_fp32
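
A small hedged illustration with made-up names: graph maps an output tensor name to the node that produced it (as returned by get_io_to_node_mapping), and output_nodes holds the corresponding tensors (as returned by inference_onnx_binding).

import torch

from transformer_deploy.backends.ort_utils import find_node_fp32

graph = {"linear_out": "MatMul_12"}  # hypothetical: output tensor name -> producing node name
output_nodes = {"linear_out": torch.tensor([1e6], dtype=torch.float32)}  # 1e6 exceeds FP16 max (~65504)
print(find_node_fp32(graph=graph, output_nodes=output_nodes))  # -> ['MatMul_12']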

get_io_to_node_mapping(onnx_model) #

Extract output->node and input->node mappings

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| onnx_model | ModelProto | ONNX model | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Dict[str, str], Dict[str, str]] | 2 mappings, (input -> node, output -> node) |

Source code in src/transformer_deploy/backends/ort_utils.py
def get_io_to_node_mapping(onnx_model: ModelProto) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Extract output->node and input->node mappings
    :param onnx_model: ONNX model
    :return: 2 mappings, (i->node, o->node)
    """
    output_mapping: Dict[str, str] = dict()
    input_mapping: Dict[str, str] = dict()
    for node in onnx_model.graph.node:  # type: NodeProto
        assert len(node.output) == 1
        output_node = node.output[0]
        output_mapping[output_node] = node.name
        for i in node.input:
            input_mapping[i] = node.name

    return input_mapping, output_mapping
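
A minimal sketch (placeholder path): the two dicts let you navigate from tensor names back to node names, which is what the FP16 tooling below relies on.

import onnx

from transformer_deploy.backends.ort_utils import get_io_to_node_mapping

model = onnx.load_model("model.onnx", load_external_data=False)  # placeholder path
input_mapping, output_mapping = get_io_to_node_mapping(onnx_model=model)
# input_mapping: tensor name -> name of a node consuming it
# output_mapping: tensor name -> name of the node producing it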

get_keep_fp32_nodes(onnx_model_path, get_input, early_stop=100, device='cuda') #

Find the list of nodes to keep in FP32 to avoid out-of-range values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| onnx_model_path | str | ONNX model path | required |
| get_input | Callable[[], Dict[str, torch.Tensor]] | generates inputs to test the model; its output should change from call to call | required |
| early_stop | int | keep testing until early_stop consecutive runs add no new node to keep in FP32 | 100 |
| device | str | where to run the inference | 'cuda' |

Returns:

| Type | Description |
| --- | --- |
| List[str] | list of names of nodes to keep in FP32 |

Source code in src/transformer_deploy/backends/ort_utils.py
def get_keep_fp32_nodes(
    onnx_model_path: str,
    get_input: Callable[[], Dict[str, torch.Tensor]],
    early_stop: int = 100,
    device: str = "cuda",
) -> List[str]:
    """
    Find the list of nodes to keep in FP32 to avoid out of range values
    :param onnx_model_path: ONNX model path
    :param get_input: generate input to test the model. Output should change from call to call
    :param early_stop: will test until `early_stop` tests are done without any new node to keep in FP32
    :param device: where to run the inference
    :return: list of names of nodes to keep in FP32
    """
    # do not load weights on LLM (>2Gb), we only need to modify the computation graph
    onnx_model: ModelProto = onnx.load_model(f=onnx_model_path, load_external_data=False)
    onnx_model_fp32_all_nodes = add_output_nodes(model=onnx_model)
    path_onnx_model_fp32_all_nodes = onnx_model_path + "_all_nodes.onnx"
    onnx.save_model(proto=onnx_model_fp32_all_nodes, f=path_onnx_model_fp32_all_nodes, save_as_external_data=False)
    provider = "CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"
    ort_model_fp32_all_nodes = create_model_for_provider(path_onnx_model_fp32_all_nodes, provider)
    ort_binding = ort_model_fp32_all_nodes.io_binding()
    input_mapping, output_mapping = get_io_to_node_mapping(onnx_model=onnx_model)
    # list all nodes which have an output out of the FP16 range
    keep_fp32_nodes = list()
    no_new_node_counter = 0
    while no_new_node_counter < early_stop:
        inputs = get_input()
        outputs: Dict[str, torch.Tensor] = inference_onnx_binding(
            model_onnx=ort_model_fp32_all_nodes, inputs=inputs, device=device, binding=ort_binding, clone_tensor=False
        )
        keep_node_io = find_node_fp32(graph=output_mapping, output_nodes=outputs)

        nodes_to_add = [n for n in keep_node_io if n not in keep_fp32_nodes]
        keep_fp32_nodes += nodes_to_add
        if len(nodes_to_add) == 0:
            no_new_node_counter += 1
        else:
            no_new_node_counter = 0

    if device == "cuda":
        torch.cuda.empty_cache()
    # I/O names that can't be found in the graph
    nodes_to_skip = (
        [n.name for n in onnx_model.graph.input]
        + [n.name for n in onnx_model.graph.output]
        + [n.name for n in onnx_model.graph.initializer]
    )

    # for each node to keep in FP32, we keep its children in FP32 too as they will receive FP32 values as input
    map_children = defaultdict(list)
    for node in onnx_model.graph.node:
        for o in node.output:
            if o in nodes_to_skip:
                continue
            child = input_mapping[o]
            map_children[node.name].append(child)
    keep_fp32_nodes += [c for k in keep_fp32_nodes if k in map_children for c in map_children[k]]
    return keep_fp32_nodes
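
A hedged end-to-end sketch of the mixed precision workflow (model path, input name, shapes and vocabulary size are all placeholders): detect nodes that overflow FP16 using random inputs, then convert everything else.

import onnx
import torch

from transformer_deploy.backends.ort_utils import convert_fp16, get_keep_fp32_nodes


def get_input():
    # random inputs, different at each call as required by get_keep_fp32_nodes
    return {"input_ids": torch.randint(low=0, high=30000, size=(4, 128), dtype=torch.int64, device="cuda")}


keep_fp32 = get_keep_fp32_nodes(onnx_model_path="model.onnx", get_input=get_input, early_stop=100, device="cuda")
model_fp16 = convert_fp16(onnx_model="model.onnx", nodes_to_exclude=keep_fp32)
onnx.save_model(model_fp16, "model_fp16.onnx")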

inference_onnx_binding(model_onnx, inputs, device, device_id=0, binding=None, clone_tensor=True) #

Performs inference with ONNX Runtime in an optimized way. In particular, it avoids any ONNX Runtime output tensor copy: ONNX Runtime remains the owner of the underlying array and will overwrite its content if you run another inference. To avoid any issue, keep clone_tensor set to True (the default). For best performance and lowest memory footprint, if you know what you are doing, set clone_tensor to False.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_onnx | InferenceSession | ONNX Runtime inference session | required |
| inputs | Dict[str, torch.Tensor] | input torch tensors | required |
| device | str | where to run the inference. One of [cpu, cuda] | required |
| device_id | int | ID of the device to run the inference on, to be used when there are multiple GPUs, etc. | 0 |
| binding | Optional[IOBinding] | previously generated IO binding; it will be reset | None |
| clone_tensor | bool | clone the PyTorch tensors to avoid their content being overwritten by ONNX Runtime at the next inference call | True |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, torch.Tensor] | a dict {output name: output tensor} |

Source code in src/transformer_deploy/backends/ort_utils.py
def inference_onnx_binding(
    model_onnx: InferenceSession,
    inputs: Dict[str, torch.Tensor],
    device: str,
    device_id: int = 0,
    binding: Optional[IOBinding] = None,
    clone_tensor: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Performs inference on ONNX Runtime in an optimized way.
    In particular, it avoids any Onnx Runtime output tensor copy.
    It means that Onnx Runtime is still owner of the array, and it will overwrite its content if you do another
    inference. To avoid any issue, just set clone_tensor to True (default).
    For best performance and lowest memory footprint, if you know what you are doing, set clone_tensor to False.

    :param model_onnx: ONNX model
    :param inputs: input torch tensor
    :param device: where to run the inference. One of [cpu, cuda]
    :param device_id: ID of the device where to run the inference, to be used when there are multiple GPUs, etc.
    :param binding: previously generated binding IO, will be reset.
    :param clone_tensor: clone Pytorch tensor to avoid its content being overwritten by Onnx Runtime
        at the next inference call.
    :return: a dict {output name: output tensor}
    """
    assert isinstance(device, str)
    assert device in ["cpu", "cuda"], f"unexpected inference device: '{device}'"
    if binding is None:
        binding: IOBinding = model_onnx.io_binding()
    else:
        binding.clear_binding_inputs()
        binding.clear_binding_outputs()
    for input_onnx in model_onnx.get_inputs():
        if input_onnx.name not in inputs:  # some inputs may be optional
            continue
        tensor: torch.Tensor = inputs[input_onnx.name]
        tensor = tensor.detach()
        if tensor.dtype in [torch.int64, torch.long]:
            # int32 mandatory as input of bindings, int64 not supported
            tensor = tensor.type(dtype=torch.int32)
        tensor = tensor.contiguous()
        binding.bind_input(
            name=input_onnx.name,
            device_type=device,
            device_id=device_id,
            element_type=torch_to_numpy_dtype_dict[tensor.dtype],
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )
        inputs[input_onnx.name] = tensor

    for out in model_onnx.get_outputs():
        binding.bind_output(
            name=out.name,
            device_type=device,
            device_id=device_id,
        )
    binding.synchronize_inputs()
    model_onnx.run_with_iobinding(binding)
    binding.synchronize_outputs()
    outputs = dict()
    assert len(model_onnx.get_outputs()) == len(
        binding.get_outputs()
    ), f"{len(model_onnx.get_outputs())} != {len(binding.get_outputs())}"
    for out, t in zip(model_onnx.get_outputs(), binding.get_outputs()):
        outputs[out.name] = to_pytorch(t, clone_tensor=clone_tensor)
    return outputs
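
A minimal sketch (placeholder path and input name), assuming a CUDA device; the session is created with the helper documented above.

import torch

from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding

session = create_model_for_provider(path="model.onnx", provider_to_use="CUDAExecutionProvider")
inputs = {"input_ids": torch.randint(0, 30000, (1, 128), dtype=torch.int64, device="cuda")}
outputs = inference_onnx_binding(model_onnx=session, inputs=inputs, device="cuda")
# outputs: dict of output name -> torch.Tensor (cloned by default, safe to keep around)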

optimize_onnx(onnx_path, onnx_optim_model_path, fp16, use_cuda, num_attention_heads=0, hidden_size=0, architecture='bert') #

ONNX Runtime transformer graph optimization. Performs some operator fusion (merging several graph nodes into a single one) and may convert some nodes to reduced precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| onnx_path | str | ONNX input path | required |
| onnx_optim_model_path | str | where to save the optimized model | required |
| fp16 | bool | use mixed precision (faster inference) | required |
| use_cuda | bool | perform optimization on GPU | required |
| num_attention_heads | int | number of attention heads of the model (0 -> try to detect) | 0 |
| hidden_size | int | hidden layer size of the model (0 -> try to detect) | 0 |
| architecture | str | model architecture to optimize. One of [bert, bart, gpt2] | 'bert' |
Source code in src/transformer_deploy/backends/ort_utils.py
def optimize_onnx(
    onnx_path: str,
    onnx_optim_model_path: str,
    fp16: bool,
    use_cuda: bool,
    num_attention_heads: int = 0,
    hidden_size: int = 0,
    architecture: str = "bert",
) -> None:
    """
    ONNX Runtime transformer graph optimization.
    Performs some operator fusion (merge several nodes of the graph in a single one)
    and may convert some nodes to reduced precision.
    :param onnx_path: ONNX input path
    :param onnx_optim_model_path: where to save optimized model
    :param fp16: use mixed precision (faster inference)
    :param use_cuda: perform optimization on GPU
    :param num_attention_heads: number of attention heads of a model (0 -> try to detect)
    :param hidden_size: hidden layer size of a model (0 -> try to detect)
    :param architecture: model architecture to optimize. One of [bert, bart, gpt2]
    """
    assert architecture in ["bert", "bart", "gpt2"], f"unsupported architecture: {architecture}"
    opt_level = 1 if architecture == "bert" else 0
    optimization_options = FusionOptions(model_type=architecture)
    optimization_options.enable_gelu_approximation = False  # additional optimization
    optimized_model: BertOnnxModel = optimizer.optimize_model(
        input=onnx_path,
        model_type=architecture,
        use_gpu=use_cuda,
        opt_level=opt_level,
        num_heads=num_attention_heads,  # automatic detection with 0 may not work with opset 13 or distilbert models
        hidden_size=hidden_size,  # automatic detection with 0
        optimization_options=optimization_options,
    )
    if fp16:
        # use_symbolic_shape_infer set to False because it doesn't work with ONNX package versions after 1.10.2
        optimized_model.convert_float_to_float16(use_symbolic_shape_infer=False)  # FP32 -> FP16
    logging.info(f"optimizations applied: {optimized_model.get_fused_operator_statistics()}")
    optimized_model.save_model_to_file(onnx_optim_model_path)
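
A minimal sketch for a BERT-base-like model (paths are placeholders; 12 heads and a hidden size of 768 are the usual BERT-base values, adjust to your model or leave 0 for auto-detection).

from transformer_deploy.backends.ort_utils import optimize_onnx

optimize_onnx(
    onnx_path="model.onnx",
    onnx_optim_model_path="model_optim_fp16.onnx",
    fp16=True,
    use_cuda=True,
    num_attention_heads=12,
    hidden_size=768,
    architecture="bert",
)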

to_pytorch(ort_tensor, clone_tensor) #

Convert an OrtValue output by ONNX Runtime to a PyTorch tensor. The conversion can be done in a zero-copy way, depending on the clone_tensor parameter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| ort_tensor | OrtValue | output from ONNX Runtime | required |
| clone_tensor | bool | ONNX Runtime owns the storage array and will overwrite it on the next inference; cloning guarantees the data won't change | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | PyTorch tensor |

Source code in src/transformer_deploy/backends/ort_utils.py
def to_pytorch(ort_tensor: OrtValue, clone_tensor: bool) -> torch.Tensor:
    """
    Convert OrtValue output by Onnx Runtime to Pytorch tensor.
    The process can be done in a zero copy way (depending on the clone parameter).
    :param ort_tensor: output from Onnx Runtime
    :param clone_tensor: Onnx Runtime owns the storage array and will write on the next inference.
        By cloning you guarantee that the data won't change.
    :return: Pytorch tensor
    """
    if ort_tensor.device_name().lower() == "cuda":
        np_type = ort_to_numpy_dtype_dict[ort_tensor.data_type()]
        fake_owner = 1
        # size not used anywhere, so just put 0
        memory = cp.cuda.UnownedMemory(ort_tensor.data_ptr(), 0, fake_owner)
        memory_ptr = cp.cuda.MemoryPointer(memory, 0)
        # make sure you interpret the array shape/dtype/strides correctly
        cp_array = cp.ndarray(shape=ort_tensor.shape(), memptr=memory_ptr, dtype=np_type)
        # cloning required otherwise ORT will recycle the storage array and put new values into it if new inf is done.
        torch_tensor = torch.from_dlpack(cp_array.toDlpack())
        if clone_tensor:
            torch_tensor = torch_tensor.clone()
        return torch_tensor
    else:
        np_tensor = ort_tensor.numpy()
        return torch.from_numpy(np_tensor)
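
This function is mostly used internally by inference_onnx_binding, but a small CPU-side sketch shows its contract (the OrtValue here is built from a NumPy array, which is standard ONNX Runtime API):

import numpy as np
from onnxruntime import OrtValue

from transformer_deploy.backends.ort_utils import to_pytorch

ort_value = OrtValue.ortvalue_from_numpy(np.ones((2, 3), dtype=np.float32))  # CPU-backed OrtValue
tensor = to_pytorch(ort_tensor=ort_value, clone_tensor=True)  # torch.Tensor of shape (2, 3)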

use_external_data(path) #

Check if a model uses external data

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | ONNX model path | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True if any initializer (model weight) is stored in an external file |

Source code in src/transformer_deploy/backends/ort_utils.py
def use_external_data(path: str) -> bool:
    """
    Check if a model uses external data
    :param path: Onnx model path
    :return: True if any initializer (model weight) is stored in an external file
    """
    model = onnx.load_model(f=path, load_external_data=False)
    for i in model.graph.initializer:
        if i.HasField("data_location") and i.data_location == onnx.TensorProto.EXTERNAL:
            return True
    return False
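
A minimal sketch (placeholder path). This check matters because protobuf files are capped at 2 GB, so large models store their weights in external files and must be handled accordingly.

from transformer_deploy.backends.ort_utils import use_external_data

if use_external_data(path="model.onnx"):
    print("model weights are stored in external data files")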