Ort utils
All the tooling to ease ONNX Runtime usage.
add_output_nodes(model)
Set each node as an output node, for debugging purposes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | ModelProto | ONNX model in protobuf format | required |
Returns:
Type | Description |
---|---|
ModelProto | modified ONNX model |
Source code in src/transformer_deploy/backends/ort_utils.py
def add_output_nodes(model: ModelProto) -> ModelProto:
"""
    Set each node as an output node, for debugging purposes.
:param model: ONNX model in protobuf format
:return: modified ONNX model
"""
model = copy.deepcopy(model)
output_nodes = list()
for n in model.graph.node:
for output_name in n.output:
output_nodes.append(onnx.ValueInfoProto(name=output_name))
    # clear output array (protobuf way...)
while model.graph.output:
model.graph.output.pop()
model.graph.output.extend(output_nodes)
return model
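Example (a minimal sketch; the model paths are placeholders): expose every intermediate tensor as a model output so it can be fetched at inference time for debugging.

```python
import onnx

from transformer_deploy.backends.ort_utils import add_output_nodes

# "model.onnx" / "model_debug.onnx" are placeholder paths
model = onnx.load_model("model.onnx")
debug_model = add_output_nodes(model=model)  # every node output becomes a graph output
onnx.save_model(debug_model, "model_debug.onnx")
```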
convert_fp16(onnx_model, nodes_to_exclude)
Convert an ONNX model to FP16, while keeping an excluded list of nodes in FP32.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
onnx_model | str | original FP32 model | required |
nodes_to_exclude | List[str] | nodes that should stay in FP32 | required |
Returns:
Type | Description |
---|---|
ModelProto | mostly FP16 model |
Source code in src/transformer_deploy/backends/ort_utils.py
def convert_fp16(onnx_model: str, nodes_to_exclude: List[str]) -> ModelProto:
"""
    Convert an ONNX model to FP16, while keeping an excluded list of nodes in FP32.
:param onnx_model: original FP32 model
:param nodes_to_exclude: nodes that should stay in FP32
:return: mostly FP16 model
"""
# add value info related to each node, required for the conversion
output_path = onnx_model + "_shape_inference.onnx"
infer_shapes_path(model_path=onnx_model, output_path=output_path)
model_fp16 = onnx.load_model(output_path)
model_fp16 = convert_float_to_float16(model=model_fp16, keep_io_types=False, node_block_list=nodes_to_exclude)
# clean casting nodes before returning the model
wrapped_fp16_model = OnnxModel(model_fp16)
fusion_utils = FusionUtils(wrapped_fp16_model)
fusion_utils.remove_cascaded_cast_nodes()
fusion_utils.remove_useless_cast_nodes()
wrapped_fp16_model.topological_sort()
return wrapped_fp16_model.model
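Example (a minimal sketch; the paths are placeholders). With an empty exclusion list the tool converts everything it can; in practice the list usually comes from get_keep_fp32_nodes (see below).

```python
import onnx

from transformer_deploy.backends.ort_utils import convert_fp16

model_fp16 = convert_fp16(onnx_model="model_fp32.onnx", nodes_to_exclude=[])
onnx.save_model(model_fp16, "model_fp16.onnx")
```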
cpu_quantization(input_model_path, output_model_path)
ONNX CPU only dynamic quantization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_model_path | str | ONNX graph (float) to quantize | required |
output_model_path | str | where to save quantized model | required |
Source code in src/transformer_deploy/backends/ort_utils.py
def cpu_quantization(input_model_path: str, output_model_path: str) -> None:
"""
ONNX CPU only dynamic quantization.
:param input_model_path: ONNX graph (float) to quantize
:param output_model_path: where to save quantized model
"""
quantize_dynamic(
model_input=Path(input_model_path),
model_output=Path(output_model_path),
op_types_to_quantize=["MatMul", "Attention"],
weight_type=QuantType.QInt8,
per_channel=True,
reduce_range=True,
extra_options={"WeightSymmetric": False, "MatMulConstBOnly": True},
)
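Example (the paths are placeholders): quantize an FP32 graph to INT8 weights for CPU inference.

```python
from transformer_deploy.backends.ort_utils import cpu_quantization

cpu_quantization(input_model_path="model_fp32.onnx", output_model_path="model_int8.onnx")
```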
create_model_for_provider(path, provider_to_use, nb_threads=12, nb_instances=0, optimization_level=<GraphOptimizationLevel.ORT_ENABLE_EXTENDED: 2>, enable_profiling=False, log_severity=2)
Create an ONNX Runtime instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | path to an ONNX file, or a model serialized to a string | required |
provider_to_use | Union[str, List] | provider to use for inference | required |
nb_threads | int | intra_op_num_threads to use. You may want to try different values, as more cores do not always provide the best performance. | 12 |
nb_instances | int | inter_op_num_threads to use, to execute multiple subgraphs in parallel when possible. | 0 |
optimization_level | GraphOptimizationLevel | expected level of ONNX Runtime optimization. For GPU and NLP, the extended level provides kernel fusion of element-wise operations. The enable-all level is meant for CPU inference. See https://onnxruntime.ai/docs/performance/graph-optimizations.html#layout-optimizations | <GraphOptimizationLevel.ORT_ENABLE_EXTENDED: 2> |
enable_profiling | bool | let ONNX Runtime log the execution time of each kernel. | False |
log_severity | int | log severity level. 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal. | 2 |
Returns:
Type | Description |
---|---|
InferenceSession | ONNX Runtime inference session |
Source code in src/transformer_deploy/backends/ort_utils.py
def create_model_for_provider(
path: str,
provider_to_use: Union[str, List],
nb_threads: int = multiprocessing.cpu_count(),
nb_instances: int = 0,
optimization_level: GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
enable_profiling: bool = False,
log_severity: int = 2,
) -> InferenceSession:
"""
Create an ONNX Runtime instance.
:param path: path to ONNX file or serialized to string model
:param provider_to_use: provider to use for inference
    :param nb_threads: intra_op_num_threads to use. You may want to try different values,
        as more cores do not always provide the best performance.
    :param nb_instances: inter_op_num_threads to use, to execute multiple subgraphs in parallel when possible.
    :param optimization_level: expected level of ONNX Runtime optimization. For GPU and NLP, the extended level
        provides kernel fusion of element-wise operations. The enable-all level is meant for CPU inference.
        See https://onnxruntime.ai/docs/performance/graph-optimizations.html#layout-optimizations
    :param enable_profiling: let ONNX Runtime log the execution time of each kernel.
    :param log_severity: log severity level. 0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal.
:return: ONNX Runtime inference session
"""
options = SessionOptions()
options.graph_optimization_level = optimization_level
options.enable_profiling = enable_profiling
options.log_severity_level = log_severity
if isinstance(provider_to_use, str):
provider_to_use = [provider_to_use]
if provider_to_use == ["CPUExecutionProvider"]:
options.execution_mode = ExecutionMode.ORT_SEQUENTIAL if nb_instances <= 1 else ExecutionMode.ORT_PARALLEL
options.intra_op_num_threads = nb_threads
if nb_instances > 1:
options.inter_op_num_threads = nb_instances
return InferenceSession(path, options, providers=provider_to_use)
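Example (the model path is a placeholder): build a GPU session and list its input names.

```python
from transformer_deploy.backends.ort_utils import create_model_for_provider

session = create_model_for_provider(
    path="model.onnx",
    provider_to_use="CUDAExecutionProvider",  # use "CPUExecutionProvider" for CPU inference
)
print([node.name for node in session.get_inputs()])
```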
find_node_fp32(graph, output_nodes)
Identify out of range values in node outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | Dict[str, str] | graph as an adjacency dict (tensor name -> producing node name) | required |
output_nodes | Dict[str, torch.Tensor] | output of each node | required |
Returns:
Type | Description |
---|---|
List[str] | list of nodes producing outputs outside the FP16 range |
Source code in src/transformer_deploy/backends/ort_utils.py
def find_node_fp32(graph: Dict[str, str], output_nodes: Dict[str, torch.Tensor]) -> List[str]:
"""
Identify out of range values in node outputs.
    :param graph: graph as an adjacency dict (tensor name -> producing node name)
    :param output_nodes: output of each node
    :return: list of nodes producing outputs outside the FP16 range
"""
keep_fp32 = list()
min_float16 = torch.finfo(torch.float16).min
max_float16 = torch.finfo(torch.float16).max
    resolution = 5.96e-08  # smallest positive subnormal value representable in FP16
for k, tensor in output_nodes.items():
if tensor.dtype != torch.float32:
continue
# out of FP16 range
        if (
            torch.any(tensor > max_float16)
            or torch.any(tensor < min_float16)
            or torch.any((tensor < resolution) & (tensor > -resolution) & (tensor != 0))  # would underflow in FP16
        ):
keep_fp32.append(graph[k])
return keep_fp32
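Example (a synthetic sketch; node and tensor names are made up): the graph dict maps each output tensor name to the node producing it, and only producers of values that do not fit in FP16 are returned.

```python
import torch

from transformer_deploy.backends.ort_utils import find_node_fp32

graph = {"logits": "MatMul_1", "probs": "Softmax_1"}  # output tensor name -> producing node name
outputs = {
    "logits": torch.tensor([70000.0], dtype=torch.float32),  # above FP16 max (~65504)
    "probs": torch.tensor([0.5], dtype=torch.float32),  # representable in FP16
}
print(find_node_fp32(graph=graph, output_nodes=outputs))  # ['MatMul_1']
```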
get_io_to_node_mapping(onnx_model)
Extract output->node and input->node mappings
Parameters:
Name | Type | Description | Default |
---|---|---|---|
onnx_model | ModelProto | ONNX model | required |
Returns:
Type | Description |
---|---|
Tuple[Dict[str, str], Dict[str, str]] | 2 mappings, (input -> node, output -> node) |
Source code in src/transformer_deploy/backends/ort_utils.py
def get_io_to_node_mapping(onnx_model: ModelProto) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
Extract output->node and input->node mappings
:param onnx_model: ONNX model
    :return: 2 mappings, (input -> node, output -> node)
"""
output_mapping: Dict[str, str] = dict()
input_mapping: Dict[str, str] = dict()
for node in onnx_model.graph.node: # type: NodeProto
assert len(node.output) == 1
output_node = node.output[0]
output_mapping[output_node] = node.name
for i in node.input:
input_mapping[i] = node.name
return input_mapping, output_mapping
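Example (the model path is a placeholder). Note that the function asserts every node has exactly one output, so it may not apply to arbitrary graphs.

```python
import onnx

from transformer_deploy.backends.ort_utils import get_io_to_node_mapping

model = onnx.load_model(f="model.onnx", load_external_data=False)  # weights are not needed
input_mapping, output_mapping = get_io_to_node_mapping(onnx_model=model)
# output_mapping: tensor name -> name of the node producing it
# input_mapping: tensor name -> name of a node consuming it
```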
get_keep_fp32_nodes(onnx_model_path, get_input, early_stop=100, device='cuda')
Find the list of nodes to keep in FP32 to avoid out of range values
Parameters:
Name | Type | Description | Default |
---|---|---|---|
onnx_model_path | str | ONNX model path | required |
get_input | Callable[[], Dict[str, torch.Tensor]] | generates inputs to test the model. Each call should return different values | required |
early_stop | int | stop after early_stop consecutive tests without any new node to keep in FP32 | 100 |
device | str | where to run the inference | 'cuda' |
Returns:
Type | Description |
---|---|
List[str] | list of names of nodes to keep in FP32 |
Source code in src/transformer_deploy/backends/ort_utils.py
def get_keep_fp32_nodes(
onnx_model_path: str,
get_input: Callable[[], Dict[str, torch.Tensor]],
early_stop: int = 100,
device: str = "cuda",
) -> List[str]:
"""
Find the list of nodes to keep in FP32 to avoid out of range values
:param onnx_model_path: ONNX model path
    :param get_input: generates inputs to test the model; each call should return different values
:param early_stop: will test until `early_stop` tests are done without any new node to keep in FP32
:param device: where to run the inference
:return: list of names of nodes to keep in FP32
"""
# do not load weights on LLM (>2Gb), we only need to modify the computation graph
onnx_model: ModelProto = onnx.load_model(f=onnx_model_path, load_external_data=False)
onnx_model_fp32_all_nodes = add_output_nodes(model=onnx_model)
path_onnx_model_fp32_all_nodes = onnx_model_path + "_all_nodes.onnx"
onnx.save_model(proto=onnx_model_fp32_all_nodes, f=path_onnx_model_fp32_all_nodes, save_as_external_data=False)
provider = "CUDAExecutionProvider" if device == "cuda" else "CPUExecutionProvider"
ort_model_fp32_all_nodes = create_model_for_provider(path_onnx_model_fp32_all_nodes, provider)
ort_binding = ort_model_fp32_all_nodes.io_binding()
input_mapping, output_mapping = get_io_to_node_mapping(onnx_model=onnx_model)
# list all nodes which have an output out of the FP16 range
keep_fp32_nodes = list()
no_new_node_counter = 0
while no_new_node_counter < early_stop:
inputs = get_input()
outputs: Dict[str, torch.Tensor] = inference_onnx_binding(
model_onnx=ort_model_fp32_all_nodes, inputs=inputs, device=device, binding=ort_binding, clone_tensor=False
)
keep_node_io = find_node_fp32(graph=output_mapping, output_nodes=outputs)
nodes_to_add = [n for n in keep_node_io if n not in keep_fp32_nodes]
keep_fp32_nodes += nodes_to_add
if len(nodes_to_add) == 0:
no_new_node_counter += 1
else:
no_new_node_counter = 0
if device == "cuda":
torch.cuda.empty_cache()
# I/O names that can't be found in the graph
nodes_to_skip = (
[n.name for n in onnx_model.graph.input]
+ [n.name for n in onnx_model.graph.output]
+ [n.name for n in onnx_model.graph.initializer]
)
# for each node to keep in FP32, we keep its children in FP32 too as they will receive FP32 values as input
map_children = defaultdict(list)
for node in onnx_model.graph.node:
for o in node.output:
if o in nodes_to_skip:
continue
child = input_mapping[o]
map_children[node.name].append(child)
keep_fp32_nodes += [c for k in keep_fp32_nodes if k in map_children for c in map_children[k]]
return keep_fp32_nodes
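Example (a sketch assuming a CUDA device; the model paths and the input_ids input name are placeholders to adapt to your own model): search for the nodes to keep in FP32, then feed them to convert_fp16.

```python
import onnx
import torch

from transformer_deploy.backends.ort_utils import convert_fp16, get_keep_fp32_nodes


def get_random_input() -> dict:
    # each call returns a different random sequence (length and token IDs)
    seq_len = int(torch.randint(low=16, high=256, size=(1,)))
    return {"input_ids": torch.randint(low=0, high=30000, size=(1, seq_len), dtype=torch.int64, device="cuda")}


keep_fp32 = get_keep_fp32_nodes(onnx_model_path="model_fp32.onnx", get_input=get_random_input)
model_fp16 = convert_fp16(onnx_model="model_fp32.onnx", nodes_to_exclude=keep_fp32)
onnx.save_model(model_fp16, "model_fp16.onnx")
```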
inference_onnx_binding(model_onnx, inputs, device, device_id=0, binding=None, clone_tensor=True)
Performs inference with ONNX Runtime in an optimized way. In particular, it avoids copying the ONNX Runtime output tensors: ONNX Runtime remains the owner of the underlying arrays and will overwrite their content on the next inference. To avoid any issue, keep clone_tensor set to True (the default). For the best performance and lowest memory footprint, if you know what you are doing, set clone_tensor to False.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_onnx | InferenceSession | ONNX model | required |
inputs | Dict[str, torch.Tensor] | input torch tensors | required |
device | str | where to run the inference. One of [cpu, cuda] | required |
device_id | int | ID of the device where to run the inference, to be used when there are multiple GPUs, etc. | 0 |
binding | Optional[onnxruntime.capi.onnxruntime_inference_collection.IOBinding] | previously generated IO binding, will be reset. | None |
clone_tensor | bool | clone the PyTorch tensors to avoid their content being overwritten by ONNX Runtime at the next inference call. | True |
Returns:
Type | Description |
---|---|
Dict[str, torch.Tensor] | a dict {output name: output tensor} |
Source code in src/transformer_deploy/backends/ort_utils.py
def inference_onnx_binding(
model_onnx: InferenceSession,
inputs: Dict[str, torch.Tensor],
device: str,
device_id: int = 0,
binding: Optional[IOBinding] = None,
clone_tensor: bool = True,
) -> Dict[str, torch.Tensor]:
"""
Performs inference on ONNX Runtime in an optimized way.
In particular, it avoids any Onnx Runtime output tensor copy.
It means that Onnx Runtime is still owner of the array, and it will overwrite its content if you do another
inference. To avoid any issue, just set clone_tensor to True (default).
    For the best performance and lowest memory footprint, if you know what you are doing, set clone_tensor to False.
:param model_onnx: ONNX model
:param inputs: input torch tensor
:param device: where to run the inference. One of [cpu, cuda]
:param device_id: ID of the device where to run the inference, to be used when there are multiple GPUs, etc.
:param binding: previously generated binding IO, will be reset.
:param clone_tensor: clone Pytorch tensor to avoid its content being overwritten by Onnx Runtime
at the next inference call.
    :return: a dict {output name: output tensor}
"""
assert isinstance(device, str)
assert device in ["cpu", "cuda"], f"unexpected inference device: '{device}'"
if binding is None:
binding: IOBinding = model_onnx.io_binding()
else:
binding.clear_binding_inputs()
binding.clear_binding_outputs()
for input_onnx in model_onnx.get_inputs():
if input_onnx.name not in inputs: # some inputs may be optional
continue
tensor: torch.Tensor = inputs[input_onnx.name]
tensor = tensor.detach()
if tensor.dtype in [torch.int64, torch.long]:
# int32 mandatory as input of bindings, int64 not supported
tensor = tensor.type(dtype=torch.int32)
tensor = tensor.contiguous()
binding.bind_input(
name=input_onnx.name,
device_type=device,
device_id=device_id,
element_type=torch_to_numpy_dtype_dict[tensor.dtype],
shape=tuple(tensor.shape),
buffer_ptr=tensor.data_ptr(),
)
inputs[input_onnx.name] = tensor
for out in model_onnx.get_outputs():
binding.bind_output(
name=out.name,
device_type=device,
device_id=device_id,
)
binding.synchronize_inputs()
model_onnx.run_with_iobinding(binding)
binding.synchronize_outputs()
outputs = dict()
assert len(model_onnx.get_outputs()) == len(
binding.get_outputs()
), f"{len(model_onnx.get_outputs())} != {len(binding.get_outputs())}"
for out, t in zip(model_onnx.get_outputs(), binding.get_outputs()):
outputs[out.name] = to_pytorch(t, clone_tensor=clone_tensor)
return outputs
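Example (a sketch assuming a CUDA device; the model path and the input_ids input name are placeholders):

```python
import torch

from transformer_deploy.backends.ort_utils import create_model_for_provider, inference_onnx_binding

session = create_model_for_provider(path="model.onnx", provider_to_use="CUDAExecutionProvider")
inputs = {"input_ids": torch.randint(low=0, high=30000, size=(1, 16), dtype=torch.int64, device="cuda")}
outputs = inference_onnx_binding(model_onnx=session, inputs=inputs, device="cuda")
print({name: tensor.shape for name, tensor in outputs.items()})
```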
optimize_onnx(onnx_path, onnx_optim_model_path, fp16, use_cuda, num_attention_heads=0, hidden_size=0, architecture='bert')
ONNX Runtime transformer graph optimization. Performs some operator fusion (merging several graph nodes into a single one) and may convert some nodes to reduced precision.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
onnx_path | str | ONNX input path | required |
onnx_optim_model_path | str | where to save the optimized model | required |
fp16 | bool | use mixed precision (faster inference) | required |
use_cuda | bool | perform optimization on GPU | required |
num_attention_heads | int | number of attention heads of a model (0 -> try to detect) | 0 |
hidden_size | int | hidden layer size of a model (0 -> try to detect) | 0 |
architecture | str | model architecture to optimize. One of [bert, bart, gpt2] | 'bert' |
Source code in src/transformer_deploy/backends/ort_utils.py
def optimize_onnx(
onnx_path: str,
onnx_optim_model_path: str,
fp16: bool,
use_cuda: bool,
num_attention_heads: int = 0,
hidden_size: int = 0,
architecture: str = "bert",
) -> None:
"""
ONNX Runtime transformer graph optimization.
Performs some operator fusion (merge several nodes of the graph in a single one)
and may convert some nodes to reduced precision.
:param onnx_path: ONNX input path
:param onnx_optim_model_path: where to save optimized model
:param fp16: use mixed precision (faster inference)
    :param use_cuda: perform optimization on GPU
:param num_attention_heads: number of attention heads of a model (0 -> try to detect)
:param hidden_size: hidden layer size of a model (0 -> try to detect)
:param architecture: model architecture to optimize. One of [bert, bart, gpt2]
"""
assert architecture in ["bert", "bart", "gpt2"], f"unsupported architecture: {architecture}"
opt_level = 1 if architecture == "bert" else 0
optimization_options = FusionOptions(model_type=architecture)
optimization_options.enable_gelu_approximation = False # additional optimization
optimized_model: BertOnnxModel = optimizer.optimize_model(
input=onnx_path,
model_type=architecture,
use_gpu=use_cuda,
opt_level=opt_level,
num_heads=num_attention_heads, # automatic detection with 0 may not work with opset 13 or distilbert models
hidden_size=hidden_size, # automatic detection with 0
optimization_options=optimization_options,
)
if fp16:
        # use_symbolic_shape_infer set to False because it doesn't work after ONNX package v1.10.2
optimized_model.convert_float_to_float16(use_symbolic_shape_infer=False) # FP32 -> FP16
logging.info(f"optimizations applied: {optimized_model.get_fused_operator_statistics()}")
optimized_model.save_model_to_file(onnx_optim_model_path)
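Example (the paths are placeholders; num_attention_heads and hidden_size are values for a BERT-base style model, adapt them or leave them at 0 for automatic detection):

```python
from transformer_deploy.backends.ort_utils import optimize_onnx

optimize_onnx(
    onnx_path="model.onnx",
    onnx_optim_model_path="model_optim_fp16.onnx",
    fp16=True,
    use_cuda=True,
    num_attention_heads=12,
    hidden_size=768,
    architecture="bert",
)
```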
to_pytorch(ort_tensor, clone_tensor)
Convert an OrtValue output by ONNX Runtime to a PyTorch tensor. The process can be done in a zero-copy way (depending on the clone_tensor parameter).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ort_tensor | OrtValue | output from ONNX Runtime | required |
clone_tensor | bool | ONNX Runtime owns the storage array and will overwrite it at the next inference. By cloning, you guarantee that the data won't change. | required |
Returns:
Type | Description |
---|---|
Tensor | PyTorch tensor |
Source code in src/transformer_deploy/backends/ort_utils.py
def to_pytorch(ort_tensor: OrtValue, clone_tensor: bool) -> torch.Tensor:
"""
Convert OrtValue output by Onnx Runtime to Pytorch tensor.
    The process can be done in a zero-copy way (depending on the clone_tensor parameter).
    :param ort_tensor: output from ONNX Runtime
    :param clone_tensor: ONNX Runtime owns the storage array and will overwrite it at the next inference.
        By cloning, you guarantee that the data won't change.
:return: Pytorch tensor
"""
if ort_tensor.device_name().lower() == "cuda":
np_type = ort_to_numpy_dtype_dict[ort_tensor.data_type()]
fake_owner = 1
# size not used anywhere, so just put 0
memory = cp.cuda.UnownedMemory(ort_tensor.data_ptr(), 0, fake_owner)
memory_ptr = cp.cuda.MemoryPointer(memory, 0)
# make sure you interpret the array shape/dtype/strides correctly
cp_array = cp.ndarray(shape=ort_tensor.shape(), memptr=memory_ptr, dtype=np_type)
# cloning required otherwise ORT will recycle the storage array and put new values into it if new inf is done.
torch_tensor = torch.from_dlpack(cp_array.toDlpack())
if clone_tensor:
torch_tensor = torch_tensor.clone()
return torch_tensor
else:
np_tensor = ort_tensor.numpy()
return torch.from_numpy(np_tensor)
use_external_data(path)
Check if a model uses external data
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | str | Onnx model path | required |
Returns:
Type | Description |
---|---|
bool | True if any initializer (model weight) is stored in an external file |
Source code in src/transformer_deploy/backends/ort_utils.py
def use_external_data(path: str) -> bool:
"""
Check if a model uses external data
:param path: Onnx model path
    :return: True if any initializer (model weight) is stored in an external file
"""
model = onnx.load_model(f=path, load_external_data=False)
for i in model.graph.initializer:
if i.HasField("data_location") and i.data_location == onnx.TensorProto.EXTERNAL:
return True
return False
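Example (the model path is a placeholder): large models (above the 2 GB protobuf limit) usually store their weights in external files.

```python
from transformer_deploy.backends.ort_utils import use_external_data

if use_external_data(path="model.onnx"):
    print("weights are stored outside the .onnx file")
```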