Pytorch utils
Utils related to Pytorch inference.
convert_to_onnx(model_pytorch, output_path, inputs_pytorch, quantization, var_output_seq)
Convert a Pytorch model to an ONNX graph by tracing the provided inputs inside the Pytorch code. Pytorch sometimes fails to infer the output tensor shape of models. In the ONNX graph, some axis names may be generated, e.g. "Divoutput_dim_1", and there may be a warning: "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function." See https://discuss.pytorch.org/t/bidirectional-lstm-and-onnx-runtime-warnings/136374 for an example.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model_pytorch | Module | Pytorch model (transformers) | required |
output_path | str | where to save the ONNX file | required |
inputs_pytorch | Dict[str, torch.Tensor] | tensors, can be dummy data; shape is not important as we declare all axes as dynamic. Should be on the same device as the model (CPU or GPU) | required |
quantization | bool | model is quantized | required |
var_output_seq | bool | variable size sequence | required |
Source code in src/transformer_deploy/backends/pytorch_utils.py
```python
def convert_to_onnx(
    model_pytorch: torch.nn.Module,
    output_path: str,
    inputs_pytorch: Dict[str, torch.Tensor],
    quantization: bool,
    var_output_seq: bool,
) -> None:
    """
    Convert a Pytorch model to an ONNX graph by tracing the provided inputs inside the Pytorch code.
    Pytorch sometimes fails to infer the output tensor shape of models.
    In the ONNX graph, some axis names may be generated, e.g. "Divoutput_dim_1",
    and there may be a warning:
    ** "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference
    for the exported graph. Please consider adding it in symbolic function." **
    ex.: https://discuss.pytorch.org/t/bidirectional-lstm-and-onnx-runtime-warnings/136374
    :param model_pytorch: Pytorch model (transformers)
    :param output_path: where to save the ONNX file
    :param inputs_pytorch: tensors, can be dummy data; shape is not important as we declare all axes as dynamic.
        Should be on the same device as the model (CPU or GPU)
    :param quantization: model is quantized
    :param var_output_seq: variable size sequence
    """
    if quantization:
        try:
            from pytorch_quantization.nn import TensorQuantizer
        except ImportError:
            raise ImportError(
                "It seems that pytorch-quantization is not yet installed. "
                "It is required when you enable the quantization flag and use CUDA device."
                "Please find installation instructions on "
                "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization or use:\n"
                "pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\\&"
                "subdirectory=tools/pytorch-quantization/"
            )
        TensorQuantizer.use_fb_fake_quant = True
    if hasattr(model_pytorch, "config") and hasattr(model_pytorch.config, "use_cache"):
        use_cache = getattr(model_pytorch.config, "use_cache")
        setattr(model_pytorch.config, "use_cache", False)
    # dynamic axis == variable length axis
    dynamic_axis = dict()
    for k in inputs_pytorch.keys():
        if var_output_seq:
            # seq axis name is fixed to be matched with output seq axis name (for output shape prediction)
            dynamic_axis[k] = {0: "batch_size", 1: "sequence"}
        else:
            # if there is no specific requirement, each axis name is unique, which fixes some issues on the T5 model
            dynamic_axis[k] = {0: "batch_size", 1: f"sequence-{k}"}
    dynamic_axis["output"] = {0: "batch_size"}
    if var_output_seq:
        dynamic_axis["output"][1] = "sequence"
    # replace int64 input tensors by int32 -> for ONNX Runtime binding API and expected by TensorRT engine
    for k, v in inputs_pytorch.items():
        if not isinstance(v, torch.Tensor):
            continue
        if v.dtype in [torch.long, torch.int64]:
            inputs_pytorch[k] = v.type(torch.int32)
    with torch.no_grad():
        torch.onnx.export(
            model_pytorch,  # model to optimize
            args=tuple(inputs_pytorch.values()),  # tuple of multiple inputs
            f=output_path,  # output path / file object
            opset_version=13,  # the ONNX version to use, >= 13 supports channel quantized model
            do_constant_folding=True,  # simplify model (replace constant expressions)
            input_names=list(inputs_pytorch.keys()),  # input names
            output_names=["output"],  # output axis name, hard coded so only 1 output supported
            dynamic_axes=dynamic_axis,  # declare dynamic axes for each input / output
            training=TrainingMode.EVAL,  # always put the model in evaluation mode
            verbose=False,
        )
    if quantization:
        TensorQuantizer.use_fb_fake_quant = False
    if hasattr(model_pytorch, "config") and hasattr(model_pytorch.config, "use_cache"):
        setattr(model_pytorch.config, "use_cache", use_cache)
```
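A minimal usage sketch is shown below. The checkpoint name, output path, and dummy text are illustrative assumptions, not values taken from the library; any `transformers` model and tokenizer on the right device would work the same way.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from transformer_deploy.backends.pytorch_utils import convert_to_onnx

model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# dummy inputs: shapes do not matter because every axis is exported as dynamic
inputs_pytorch = dict(tokenizer("This is a dummy input", return_tensors="pt"))

convert_to_onnx(
    model_pytorch=model,
    output_path="model.onnx",  # illustrative output path
    inputs_pytorch=inputs_pytorch,
    quantization=False,
    var_output_seq=False,
)
```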
get_model_size(path)
Find the number of attention heads and the hidden layer size of a model.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
path | str | path to model | required |

Returns:

Type | Description |
---|---|
Tuple[int, int] | tuple of # of attention heads and hidden layer size (0 if not found) |
Source code in src/transformer_deploy/backends/pytorch_utils.py
```python
def get_model_size(path: str) -> Tuple[int, int]:
    """
    Find the number of attention heads and the hidden layer size of a model
    :param path: path to model
    :return: tuple of # of attention heads and hidden layer size (0 if not found)
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path=path)
    num_attention_heads = getattr(config, "num_attention_heads", 0)
    hidden_size = getattr(config, "hidden_size", 0)
    return num_attention_heads, hidden_size
```
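A short illustration of the returned tuple; the model name is used here only as an example, any Hugging Face model path or hub identifier works:

```python
from transformer_deploy.backends.pytorch_utils import get_model_size

# values are read from the Hugging Face AutoConfig; 0 is returned for any missing field
num_attention_heads, hidden_size = get_model_size(path="bert-base-uncased")
print(num_attention_heads, hidden_size)  # 12 768 for a BERT-base configuration
```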
infer_classification_pytorch(model, run_on_cuda)
Perform Pytorch inference for a classification task.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model | PreTrainedModel | Pytorch model (transformers) | required |
run_on_cuda | bool | True if it should be run on GPU | required |

Returns:

Type | Description |
---|---|
Callable[[Dict[str, torch.Tensor]], torch.Tensor] | a function to perform inference |
Source code in src/transformer_deploy/backends/pytorch_utils.py
```python
def infer_classification_pytorch(
    model: PreTrainedModel, run_on_cuda: bool
) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]:
    """
    Perform Pytorch inference for a classification task
    :param model: Pytorch model (transformers)
    :param run_on_cuda: True if it should be run on GPU
    :return: a function to perform inference
    """

    def infer(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        model_output = model(**inputs).logits.detach()  # noqa: F821
        if run_on_cuda:
            torch.cuda.synchronize()
        return model_output

    return infer
```
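A hedged usage sketch: the checkpoint name is an assumption, and the model and tensors are moved to GPU only when CUDA is available so that `run_on_cuda` matches the actual device.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from transformer_deploy.backends.pytorch_utils import infer_classification_pytorch

model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
run_on_cuda = torch.cuda.is_available()
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
if run_on_cuda:
    model = model.cuda()

infer = infer_classification_pytorch(model=model, run_on_cuda=run_on_cuda)
inputs = dict(tokenizer("ONNX export is fast", return_tensors="pt"))
if run_on_cuda:
    inputs = {k: v.cuda() for k, v in inputs.items()}
logits = infer(inputs)  # tensor of shape (batch_size, num_labels)
```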
infer_feature_extraction_pytorch(model, run_on_cuda)
Perform Pytorch inference for a feature extraction task.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model | PreTrainedModel | Pytorch model (sentence-transformers) | required |
run_on_cuda | bool | True if it should be run on GPU | required |

Returns:

Type | Description |
---|---|
Callable[[Dict[str, torch.Tensor]], torch.Tensor] | a function to perform inference |
Source code in src/transformer_deploy/backends/pytorch_utils.py
```python
def infer_feature_extraction_pytorch(
    model: PreTrainedModel, run_on_cuda: bool
) -> Callable[[Dict[str, torch.Tensor]], torch.Tensor]:
    """
    Perform Pytorch inference for a feature extraction task
    :param model: Pytorch model (sentence-transformers)
    :param run_on_cuda: True if it should be run on GPU
    :return: a function to perform inference
    """

    def infer(inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        model_output = model(**inputs).detach()  # noqa: F821
        if run_on_cuda:
            torch.cuda.synchronize()
        return model_output

    return infer
```
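Note that the returned closure calls `model(**inputs).detach()`, so the model's `forward` must return a plain tensor rather than a `transformers` output object. The wrapper below is a hypothetical illustration of such a model (mean pooling over token embeddings); the class and checkpoint name are assumptions for the sketch, not part of the library.

```python
import torch
from transformers import AutoModel, AutoTokenizer

from transformer_deploy.backends.pytorch_utils import infer_feature_extraction_pytorch


class MeanPoolingModel(torch.nn.Module):
    """Hypothetical wrapper whose forward() returns one embedding vector per input."""

    def __init__(self, model_name: str):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        token_embeddings = self.transformer(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).float()
        return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)  # mean pooling over tokens


model_name = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative checkpoint
run_on_cuda = torch.cuda.is_available()
model = MeanPoolingModel(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
if run_on_cuda:
    model = model.cuda()

infer = infer_feature_extraction_pytorch(model=model, run_on_cuda=run_on_cuda)
encoded = tokenizer("ONNX export is fast", return_tensors="pt")
# keep only the tensors the wrapper's forward() accepts
inputs = {k: v for k, v in encoded.items() if k in ("input_ids", "attention_mask")}
if run_on_cuda:
    inputs = {k: v.cuda() for k, v in inputs.items()}
embeddings = infer(inputs)  # tensor of shape (batch_size, hidden_size)
```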