Nvidia GPU INT-8 quantization on any transformers model (encoder based)¶
For some context and explanations, please check our documentation here: https://els-rd.github.io/transformer-deploy/quantization/quantization_intro/.
Your machine should have Nvidia CUDA 11.X, TensorRT 8.2.1 and cuBLAS installed. These libraries have a reputation for being tricky to install; in my experience, if you follow the instructions from the Nvidia download pages and nothing else, they work out of the box. Using an Nvidia Docker image is a good option too.
#! pip3 install git+ssh://git@github.com/ELS-RD/transformer-deploy
#! pip3 install datasets sklearn
#! pip3 install git+ssh://git@github.com/NVIDIA/TensorRT#egg=pytorch-quantization\&subdirectory=tools/pytorch-quantization/
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting datasets
Downloading datasets-1.18.4-py3-none-any.whl (312 kB)
|████████████████████████████████| 312 kB 12.2 MB/s eta 0:00:01
Collecting sklearn
Downloading sklearn-0.0.tar.gz (1.1 kB)
Collecting dill
Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
|████████████████████████████████| 86 kB 70.0 MB/s eta 0:00:01
Requirement already satisfied: requests>=2.19.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from datasets) (2.27.1)
Requirement already satisfied: numpy>=1.17 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from datasets) (1.21.5)
Requirement already satisfied: packaging in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from datasets) (21.3)
Collecting pyarrow!=4.0.0,>=3.0.0
Downloading pyarrow-7.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
|████████████████████████████████| 26.7 MB 20.6 MB/s eta 0:00:01
Collecting xxhash
Downloading xxhash-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (211 kB)
|████████████████████████████████| 211 kB 59.3 MB/s eta 0:00:01
Collecting fsspec[http]>=2021.05.0
Downloading fsspec-2022.2.0-py3-none-any.whl (134 kB)
|████████████████████████████████| 134 kB 73.4 MB/s eta 0:00:01
Collecting responses<0.19
Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting pandas
Downloading pandas-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.7 MB)
|████████████████████████████████| 11.7 MB 90.8 MB/s eta 0:00:01
Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from datasets) (0.4.0)
Collecting aiohttp
Downloading aiohttp-3.8.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.2 MB)
|████████████████████████████████| 1.2 MB 100.2 MB/s eta 0:00:01
Collecting multiprocess
Downloading multiprocess-0.70.12.2-py39-none-any.whl (128 kB)
|████████████████████████████████| 128 kB 78.2 MB/s eta 0:00:01
Requirement already satisfied: tqdm>=4.62.1 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from datasets) (4.62.3)
Requirement already satisfied: pyyaml in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.0.1)
Requirement already satisfied: filelock in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.4.2)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from packaging->datasets) (3.0.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (1.26.8)
Requirement already satisfied: certifi>=2017.4.17 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (2021.10.8)
Requirement already satisfied: idna<4,>=2.5 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (3.3)
Requirement already satisfied: charset-normalizer~=2.0.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (2.0.11)
Requirement already satisfied: scikit-learn in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from sklearn) (1.0.2)
Requirement already satisfied: attrs>=17.3.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from aiohttp->datasets) (21.4.0)
Collecting multidict<7.0,>=4.5
Downloading multidict-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)
|████████████████████████████████| 114 kB 96.7 MB/s eta 0:00:01
Collecting aiosignal>=1.1.2
Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting frozenlist>=1.1.1
Downloading frozenlist-1.3.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (156 kB)
|████████████████████████████████| 156 kB 86.6 MB/s eta 0:00:01
Collecting yarl<2.0,>=1.0
Downloading yarl-1.7.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (304 kB)
|████████████████████████████████| 304 kB 95.1 MB/s eta 0:00:01
Collecting async-timeout<5.0,>=4.0.0a3
Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)
Requirement already satisfied: python-dateutil>=2.8.1 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from pandas->datasets) (2.8.2)
Collecting pytz>=2020.1
Downloading pytz-2021.3-py2.py3-none-any.whl (503 kB)
|████████████████████████████████| 503 kB 68.0 MB/s eta 0:00:01
Requirement already satisfied: six>=1.5 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)
Requirement already satisfied: joblib>=0.11 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from scikit-learn->sklearn) (1.1.0)
Requirement already satisfied: scipy>=1.1.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from scikit-learn->sklearn) (1.8.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages (from scikit-learn->sklearn) (3.1.0)
Building wheels for collected packages: sklearn
Building wheel for sklearn (setup.py) ... done
Created wheel for sklearn: filename=sklearn-0.0-py2.py3-none-any.whl size=1309 sha256=3ac0638e4032964c7ba9d014645ce13935d501fa40fff5bf77431ac4cffa5e14
Stored in directory: /tmp/pip-ephem-wheel-cache-m57vxdol/wheels/e4/7b/98/b6466d71b8d738a0c547008b9eb39bf8676d1ff6ca4b22af1c
Successfully built sklearn
Installing collected packages: multidict, frozenlist, yarl, async-timeout, aiosignal, pytz, fsspec, dill, aiohttp, xxhash, responses, pyarrow, pandas, multiprocess, sklearn, datasets
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 datasets-1.18.4 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.2.0 multidict-6.0.2 multiprocess-0.70.12.2 pandas-1.4.1 pyarrow-7.0.0 pytz-2021.3 responses-0.18.0 sklearn-0.0 xxhash-3.0.0 yarl-1.7.2
WARNING: You are using pip version 21.1.2; however, version 22.0.4 is available.
You should consider upgrading via the '/home/geantvert/.local/share/virtualenvs/fast_transformer/bin/python -m pip install --upgrade pip' command.
Check that the GPU is enabled and usable.
! nvidia-smi
Wed Mar  9 19:14:33 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...   On  | 00000000:03:00.0  On |                  N/A |
| 48%   46C    P8    41W / 350W |    366MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1895      G   /usr/lib/xorg/Xorg                182MiB |
|    0   N/A  N/A      7706      G   /usr/bin/gnome-shell               44MiB |
|    0   N/A  N/A      8916      G   ...on/Bin/AgentConnectix.bin        4MiB |
|    0   N/A  N/A     10507      G   ...884829189831830011,131072       53MiB |
|    0   N/A  N/A   1025897      G   ...AAAAAAAAA= --shared-files       39MiB |
|    0   N/A  N/A   3867171      G   ...947156.log --shared-files       37MiB |
+-----------------------------------------------------------------------------+
import logging
import os
from collections import OrderedDict
from typing import Dict, List
from typing import OrderedDict as OD
from typing import Union
import datasets
import numpy as np
import tensorrt as trt
import torch
import transformers
from datasets import load_dataset, load_metric
from tensorrt.tensorrt import IExecutionContext, Logger, Runtime
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
IntervalStrategy,
PreTrainedModel,
PreTrainedTokenizer,
Trainer,
TrainingArguments,
)
from transformer_deploy.backends.ort_utils import (
cpu_quantization,
create_model_for_provider,
optimize_onnx,
)
from transformer_deploy.backends.pytorch_utils import convert_to_onnx
from transformer_deploy.backends.trt_utils import build_engine, get_binding_idxs, infer_tensorrt
from transformer_deploy.benchmarks.utils import print_timings, track_infer_time
from transformer_deploy.QDQModels.calibration_utils import QATCalibrate
Set logging to error level to ease the readability of this notebook on Github.
log_level = logging.ERROR
logging.getLogger().setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
transformers.logging.set_verbosity_error()
Preprocess data¶
This part is inspired by an official notebook from Hugging Face.
There is nothing special to do. Define the task:
model_name = "roberta-base"
task = "mnli"
num_labels = 3
batch_size = 32
max_seq_len = 256
validation_key = "validation_matched"
timings: Dict[str, List[float]] = dict()
runtime: Runtime = trt.Runtime(trt_logger)
profile_index = 0
Preprocess data (task specific):
def preprocess_function(examples):
return tokenizer(
examples["premise"], examples["hypothesis"], truncation=True, padding="max_length", max_length=max_seq_len
)
def compute_metrics(eval_pred):
predictions, labels = eval_pred
if task != "stsb":
predictions = np.argmax(predictions, axis=1)
else:
predictions = predictions[:, 0]
return metric.compute(predictions=predictions, references=labels)
def convert_tensor(data: OD[str, List[List[int]]], output: str) -> OD[str, Union[np.ndarray, torch.Tensor]]:
input: OD[str, Union[np.ndarray, torch.Tensor]] = OrderedDict()
for k in ["input_ids", "attention_mask", "token_type_ids"]:
if k in data:
v = data[k]
if output == "torch":
value = torch.tensor(v, dtype=torch.long, device="cuda")
elif output == "np":
value = np.asarray(v, dtype=np.int32)
else:
raise Exception(f"unknown output type: {output}")
input[k] = value
return input
def measure_accuracy(infer, tensor_type: str) -> float:
outputs = list()
for start_index in range(0, len(encoded_dataset[validation_key]), batch_size):
end_index = start_index + batch_size
data = encoded_dataset[validation_key][start_index:end_index]
inputs: OD[str, np.ndarray] = convert_tensor(data=data, output=tensor_type)
output = infer(inputs)[0]
if tensor_type == "torch":
output = output.detach().cpu().numpy()
output = np.argmax(output, axis=1).astype(int).tolist()
outputs.extend(output)
return np.mean(np.array(outputs) == np.array(validation_labels))
def get_trainer(model: PreTrainedModel) -> Trainer:
trainer = Trainer(
model,
args,
train_dataset=encoded_dataset["train"],
eval_dataset=encoded_dataset[validation_key],
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
transformers.logging.set_verbosity_error()
return trainer
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
dataset = load_dataset("glue", task)
metric = load_metric("glue", task)
encoded_dataset = dataset.map(preprocess_function, batched=True)
validation_labels = [item["label"] for item in encoded_dataset[validation_key]]
nb_step = 1000
strategy = IntervalStrategy.STEPS
args = TrainingArguments(
f"{model_name}-{task}",
evaluation_strategy=strategy,
eval_steps=nb_step,
logging_steps=nb_step,
save_steps=nb_step,
save_strategy=strategy,
learning_rate=1e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size * 2,
num_train_epochs=1,
fp16=True,
group_by_length=True,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
report_to=[],
)
(Standard) fine-tuning model¶
Now that our data are ready, we can download and fine-tune the pretrained model.
model_fp16: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
trainer = get_trainer(model_fp16)
transformers.logging.set_verbosity_error()
trainer.train()
print(trainer.evaluate())
model_fp16.save_pretrained("model_trained_fp16")
[INFO|trainer.py:457] 2022-03-09 19:15:29,964 >> Using amp half precision backend
/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
{'loss': 0.6865, 'learning_rate': 9.1875814863103e-06, 'epoch': 0.08}
{'eval_loss': 0.4732713997364044, 'eval_accuracy': 0.8161996943453897, 'eval_runtime': 18.5941, 'eval_samples_per_second': 527.856, 'eval_steps_per_second': 8.282, 'epoch': 0.08}
{'loss': 0.4992, 'learning_rate': 8.372718383311604e-06, 'epoch': 0.16}
{'eval_loss': 0.42749494314193726, 'eval_accuracy': 0.8335201222618441, 'eval_runtime': 18.4454, 'eval_samples_per_second': 532.112, 'eval_steps_per_second': 8.349, 'epoch': 0.16}
{'loss': 0.4684, 'learning_rate': 7.557855280312908e-06, 'epoch': 0.24}
{'eval_loss': 0.41849666833877563, 'eval_accuracy': 0.8358634742740703, 'eval_runtime': 18.4936, 'eval_samples_per_second': 530.726, 'eval_steps_per_second': 8.327, 'epoch': 0.24}
{'loss': 0.4457, 'learning_rate': 6.743807040417211e-06, 'epoch': 0.33}
{'eval_loss': 0.3834249973297119, 'eval_accuracy': 0.8508405501782985, 'eval_runtime': 18.4657, 'eval_samples_per_second': 531.527, 'eval_steps_per_second': 8.34, 'epoch': 0.33}
{'loss': 0.4302, 'learning_rate': 5.9289439374185145e-06, 'epoch': 0.41}
{'eval_loss': 0.3902352452278137, 'eval_accuracy': 0.8499235863474274, 'eval_runtime': 18.5403, 'eval_samples_per_second': 529.388, 'eval_steps_per_second': 8.306, 'epoch': 0.41}
{'loss': 0.4227, 'learning_rate': 5.114080834419818e-06, 'epoch': 0.49}
{'eval_loss': 0.38703474402427673, 'eval_accuracy': 0.8512480896586857, 'eval_runtime': 18.5155, 'eval_samples_per_second': 530.097, 'eval_steps_per_second': 8.317, 'epoch': 0.49}
{'loss': 0.4165, 'learning_rate': 4.30003259452412e-06, 'epoch': 0.57}
{'eval_loss': 0.36435264348983765, 'eval_accuracy': 0.8603158430973, 'eval_runtime': 18.5529, 'eval_samples_per_second': 529.029, 'eval_steps_per_second': 8.301, 'epoch': 0.57}
{'loss': 0.4132, 'learning_rate': 3.4851694915254244e-06, 'epoch': 0.65}
{'eval_loss': 0.35739392042160034, 'eval_accuracy': 0.8606214977075904, 'eval_runtime': 18.5392, 'eval_samples_per_second': 529.418, 'eval_steps_per_second': 8.307, 'epoch': 0.65}
{'loss': 0.4015, 'learning_rate': 2.670306388526728e-06, 'epoch': 0.73}
{'eval_loss': 0.36193060874938965, 'eval_accuracy': 0.8631686194600102, 'eval_runtime': 18.5487, 'eval_samples_per_second': 529.149, 'eval_steps_per_second': 8.302, 'epoch': 0.73}
{'loss': 0.3948, 'learning_rate': 1.8554432855280313e-06, 'epoch': 0.81}
{'eval_loss': 0.3555583655834198, 'eval_accuracy': 0.8639836984207845, 'eval_runtime': 18.5639, 'eval_samples_per_second': 528.713, 'eval_steps_per_second': 8.296, 'epoch': 0.81}
{'loss': 0.3967, 'learning_rate': 1.0413950456323338e-06, 'epoch': 0.9}
{'eval_loss': 0.35024452209472656, 'eval_accuracy': 0.866225165562914, 'eval_runtime': 18.5496, 'eval_samples_per_second': 529.121, 'eval_steps_per_second': 8.302, 'epoch': 0.9}
{'loss': 0.3979, 'learning_rate': 2.2734680573663624e-07, 'epoch': 0.98}
{'eval_loss': 0.34961390495300293, 'eval_accuracy': 0.8664289353031075, 'eval_runtime': 18.5647, 'eval_samples_per_second': 528.693, 'eval_steps_per_second': 8.295, 'epoch': 0.98}
{'train_runtime': 2621.6937, 'train_samples_per_second': 149.789, 'train_steps_per_second': 4.681, 'train_loss': 0.4466879855855969, 'epoch': 1.0}
{'eval_loss': 0.34961390495300293, 'eval_accuracy': 0.8664289353031075, 'eval_runtime': 18.571, 'eval_samples_per_second': 528.511, 'eval_steps_per_second': 8.292, 'epoch': 1.0}
{'eval_loss': 0.34961390495300293, 'eval_accuracy': 0.8664289353031075, 'eval_runtime': 18.571, 'eval_samples_per_second': 528.511, 'eval_steps_per_second': 8.292, 'epoch': 1.0}
Add quantization support to any model¶
The idea is to take the source code of a specific model and automatically add QDQ
nodes. QDQ nodes are placed before and after each operation we want to quantize; these nodes store the information needed to map numbers between high precision and low precision.
If you want to know more, check our documentation on: https://els-rd.github.io/transformer-deploy/quantization/quantization_ast/
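For intuition, a QDQ pair boils down to a fake quantization step. The minimal sketch below is illustrative only (it is not transformer-deploy's actual patching code, and the amax value is an arbitrary assumption; in practice it comes from calibration). It shows what the automatically inserted nodes compute with Nvidia's pytorch-quantization library:

import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# a TensorQuantizer fake-quantizes a tensor: quantize to INT-8 levels, then dequantize back
# amax=2.0 is an arbitrary scale for the example; calibration normally provides it
quantizer = quant_nn.TensorQuantizer(QuantDescriptor(num_bits=8, amax=2.0))
x = torch.randn(4, 8)
x_qdq = quantizer(x)  # same shape and dtype as x, values rounded to 256 discrete levels
print((x - x_qdq).abs().max())  # the (small) quantization error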
for percentile in [99.9, 99.99, 99.999, 99.9999]:
with QATCalibrate(method="histogram", percentile=percentile) as qat:
model_q: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
"model_trained_fp16", num_labels=num_labels
)
model_q = model_q.cuda()
qat.setup_model_qat(model_q) # prepare quantizer to any model
with torch.no_grad():
for start_index in range(0, 128, batch_size):
end_index = start_index + batch_size
data = encoded_dataset["train"][start_index:end_index]
input_torch = {
k: torch.tensor(v, dtype=torch.long, device="cuda")
for k, v in data.items()
if k in ["input_ids", "attention_mask", "token_type_ids"]
}
model_q(**input_torch)
trainer = get_trainer(model_q)
print(f"percentile: {percentile}")
print(trainer.evaluate())
[INFO|trainer.py:457] 2022-03-09 20:04:30,756 >> Using amp half precision backend
percentile: 99.9
{'eval_loss': 0.4834432005882263, 'eval_accuracy': 0.8129393785022924, 'eval_runtime': 46.3749, 'eval_samples_per_second': 211.645, 'eval_steps_per_second': 3.321}
{'eval_loss': 0.4834432005882263, 'eval_accuracy': 0.8129393785022924, 'eval_runtime': 46.3749, 'eval_samples_per_second': 211.645, 'eval_steps_per_second': 3.321}
[INFO|trainer.py:457] 2022-03-09 20:08:47,875 >> Using amp half precision backend
percentile: 99.99
{'eval_loss': 0.3795105814933777, 'eval_accuracy': 0.8555272542027509, 'eval_runtime': 46.4677, 'eval_samples_per_second': 211.222, 'eval_steps_per_second': 3.314}
{'eval_loss': 0.3795105814933777, 'eval_accuracy': 0.8555272542027509, 'eval_runtime': 46.4677, 'eval_samples_per_second': 211.222, 'eval_steps_per_second': 3.314}
[INFO|trainer.py:457] 2022-03-09 20:13:04,907 >> Using amp half precision backend
percentile: 99.999
{'eval_loss': 0.38251793384552, 'eval_accuracy': 0.8548140601120734, 'eval_runtime': 46.4646, 'eval_samples_per_second': 211.236, 'eval_steps_per_second': 3.314}
{'eval_loss': 0.38251793384552, 'eval_accuracy': 0.8548140601120734, 'eval_runtime': 46.4646, 'eval_samples_per_second': 211.236, 'eval_steps_per_second': 3.314}
[INFO|trainer.py:457] 2022-03-09 20:17:16,064 >> Using amp half precision backend
percentile: 99.9999
{'eval_loss': 0.9809716939926147, 'eval_accuracy': 0.5092205807437595, 'eval_runtime': 46.4947, 'eval_samples_per_second': 211.099, 'eval_steps_per_second': 3.312}
{'eval_loss': 0.9809716939926147, 'eval_accuracy': 0.5092205807437595, 'eval_runtime': 46.4947, 'eval_samples_per_second': 211.099, 'eval_steps_per_second': 3.312}
As you can see, the chosen percentile value has a high impact on the final accuracy.
For the rest of the notebook, we apply the 99.999 percentile.
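For reference, computing the scales at a given percentile follows the standard pytorch-quantization calibration pattern. The sketch below only illustrates what a helper like QATCalibrate presumably does internally (the actual implementation may differ): switch the quantizers to calibration mode, forward a few batches, then derive each amax from the collected histogram.

import torch
from pytorch_quantization import nn as quant_nn

def calibrate_percentile(model: torch.nn.Module, percentile: float, calib_batches) -> None:
    # switch every TensorQuantizer to calibration mode (collect statistics, no quantization)
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            module.disable_quant()
            module.enable_calib()
    with torch.no_grad():
        for batch in calib_batches:
            model(**batch)
    # compute each scale (amax) at the requested percentile, then re-enable quantization
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.load_calib_amax("percentile", percentile=percentile)
            module.enable_quant()
            module.disable_calib()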
with QATCalibrate(method="histogram", percentile=99.999) as qat:
model_q: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
"model_trained_fp16", num_labels=num_labels
)
model_q = model_q.cuda()
qat.setup_model_qat(model_q) # prepare quantizer to any model
with torch.no_grad():
for start_index in range(0, 128, batch_size):
end_index = start_index + batch_size
data = encoded_dataset["train"][start_index:end_index]
input_torch = {
k: torch.tensor(v, dtype=torch.long, device="cuda")
for k, v in data.items()
if k in ["input_ids", "attention_mask", "token_type_ids"]
}
model_q(**input_torch)
trainer = get_trainer(model_q)
print(trainer.evaluate())
[INFO|trainer.py:457] 2022-03-09 20:21:33,073 >> Using amp half precision backend
{'eval_loss': 0.38251793384552, 'eval_accuracy': 0.8548140601120734, 'eval_runtime': 46.4748, 'eval_samples_per_second': 211.19, 'eval_steps_per_second': 3.314}
{'eval_loss': 0.38251793384552, 'eval_accuracy': 0.8548140601120734, 'eval_runtime': 46.4748, 'eval_samples_per_second': 211.19, 'eval_steps_per_second': 3.314}
Per layer quantization analysis¶
Below we run a sensitivity analysis by enabling quantization of one layer at a time and measuring the accuracy. That way, we can detect whether quantizing a specific layer costs more accuracy than the others.
from pytorch_quantization import nn as quant_nn
for i in range(12):
layer_name = f"layer.{i}"
print(layer_name)
for name, module in model_q.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if layer_name in name:
module.enable_quant()
else:
module.disable_quant()
trainer.evaluate()
print("----")
layer.0
{'eval_loss': 0.34956735372543335, 'eval_accuracy': 0.86571574121243, 'eval_runtime': 20.7064, 'eval_samples_per_second': 474.009, 'eval_steps_per_second': 7.437}
----
layer.1
{'eval_loss': 0.3523275852203369, 'eval_accuracy': 0.8649006622516556, 'eval_runtime': 25.4843, 'eval_samples_per_second': 385.14, 'eval_steps_per_second': 6.043}
----
layer.2
{'eval_loss': 0.356509268283844, 'eval_accuracy': 0.8622516556291391, 'eval_runtime': 20.7496, 'eval_samples_per_second': 473.021, 'eval_steps_per_second': 7.422}
----
layer.3
{'eval_loss': 0.36036217212677, 'eval_accuracy': 0.8617422312786551, 'eval_runtime': 20.7815, 'eval_samples_per_second': 472.296, 'eval_steps_per_second': 7.41}
----
layer.4
{'eval_loss': 0.35000357031822205, 'eval_accuracy': 0.8643912379011717, 'eval_runtime': 20.7921, 'eval_samples_per_second': 472.053, 'eval_steps_per_second': 7.407}
----
layer.5
{'eval_loss': 0.354992538690567, 'eval_accuracy': 0.8644931227712684, 'eval_runtime': 20.7938, 'eval_samples_per_second': 472.016, 'eval_steps_per_second': 7.406}
----
layer.6
{'eval_loss': 0.35205718874931335, 'eval_accuracy': 0.8645950076413652, 'eval_runtime': 20.7918, 'eval_samples_per_second': 472.061, 'eval_steps_per_second': 7.407}
----
layer.7
{'eval_loss': 0.35065746307373047, 'eval_accuracy': 0.8655119714722364, 'eval_runtime': 20.8011, 'eval_samples_per_second': 471.849, 'eval_steps_per_second': 7.403}
----
layer.8
{'eval_loss': 0.3491470217704773, 'eval_accuracy': 0.8659195109526235, 'eval_runtime': 20.8112, 'eval_samples_per_second': 471.621, 'eval_steps_per_second': 7.4}
----
layer.9
{'eval_loss': 0.3492998480796814, 'eval_accuracy': 0.8659195109526235, 'eval_runtime': 20.8695, 'eval_samples_per_second': 470.303, 'eval_steps_per_second': 7.379}
----
layer.10
{'eval_loss': 0.3501480221748352, 'eval_accuracy': 0.866225165562914, 'eval_runtime': 20.8698, 'eval_samples_per_second': 470.296, 'eval_steps_per_second': 7.379}
----
layer.11
{'eval_loss': 0.3497083783149719, 'eval_accuracy': 0.866225165562914, 'eval_runtime': 20.9345, 'eval_samples_per_second': 468.843, 'eval_steps_per_second': 7.356}
----
It seems that quantizing layers 2 to 6 has the largest impact on accuracy.
Operator quantization analysis¶
Below we run a sensitivity analysis by enabling quantization of one operator type at a time and measuring the accuracy. That way, we can detect whether a specific operator has a larger cost on accuracy. On Roberta we only quantize matmul and LayerNorm, so we test both candidates.
for op in ["matmul", "layernorm"]:
for name, module in model_q.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if op in name:
module.enable_quant()
else:
module.disable_quant()
print(op)
trainer.evaluate()
print("----")
matmul
{'eval_loss': 0.3494793176651001, 'eval_accuracy': 0.8654100866021396, 'eval_runtime': 26.892, 'eval_samples_per_second': 364.978, 'eval_steps_per_second': 5.727}
----
layernorm
{'eval_loss': 0.3585323095321655, 'eval_accuracy': 0.8587875700458482, 'eval_runtime': 24.0982, 'eval_samples_per_second': 407.293, 'eval_steps_per_second': 6.391}
----
It appears that LayerNorm quantization has a significant accuracy cost.
Our goal is to disable quantization for as few operations as possible while preserving as much accuracy as possible. Therefore, we will try to disable quantization only for LayerNorm on layers 2 to 6.
disable_layer_names = ["layer.2", "layer.3", "layer.4", "layer.6"]
for name, module in model_q.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if any([f"{l}.output.layernorm" in name for l in disable_layer_names]):
print(f"disable {name}")
module.disable_quant()
else:
module.enable_quant()
trainer.evaluate()
disable roberta.encoder.layer.2.output.layernorm_quantizer_0
disable roberta.encoder.layer.2.output.layernorm_quantizer_1
disable roberta.encoder.layer.3.output.layernorm_quantizer_0
disable roberta.encoder.layer.3.output.layernorm_quantizer_1
disable roberta.encoder.layer.4.output.layernorm_quantizer_0
disable roberta.encoder.layer.4.output.layernorm_quantizer_1
disable roberta.encoder.layer.6.output.layernorm_quantizer_0
disable roberta.encoder.layer.6.output.layernorm_quantizer_1
{'eval_loss': 0.3617263436317444, 'eval_accuracy': 0.8614365766683647, 'eval_runtime': 46.0379, 'eval_samples_per_second': 213.194, 'eval_steps_per_second': 3.345}
{'eval_loss': 0.3617263436317444, 'eval_accuracy': 0.8614365766683647, 'eval_runtime': 46.0379, 'eval_samples_per_second': 213.194, 'eval_steps_per_second': 3.345}
By just disabling quantization for a single operator on a few layers, we keep most of the speed-up from quantization while recovering a good part of the lost accuracy (from 0.8548 to 0.8614). It's also possible to run the analysis per individual quantizer to get a finer granularity, but it's a bit slow to run (see the sketch below).
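A sketch of such a per-quantizer analysis, following the same pattern as the loops above (it is slow: one evaluation per quantizer, and remember to restore the enable/disable state you want afterwards):

quantizer_names = [
    name for name, module in model_q.named_modules() if isinstance(module, quant_nn.TensorQuantizer)
]
for target in quantizer_names:
    # enable a single quantizer at a time, disable all the others
    for name, module in model_q.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if name == target:
                module.enable_quant()
            else:
                module.disable_quant()
    print(target, trainer.evaluate()["eval_accuracy"])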
If we stop here, this is called Post Training Quantization (PTQ). Below, we will try to recover even more accuracy.
Quantization Aware Training (QAT)¶
We retrain the model with 1/10 to 1/100 of the original learning rate. Our goal is to recover most of the original accuracy.
args.learning_rate = 1e-7
trainer = get_trainer(model_q)
trainer.train()
print(trainer.evaluate())
model_q.save_pretrained("model-qat")
[INFO|trainer.py:457] 2022-03-09 20:28:11,049 >> Using amp half precision backend
/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
{'loss': 0.3628, 'learning_rate': 9.1867666232073e-08, 'epoch': 0.08}
{'eval_loss': 0.37072187662124634, 'eval_accuracy': 0.8610290371879776, 'eval_runtime': 46.0646, 'eval_samples_per_second': 213.071, 'eval_steps_per_second': 3.343, 'epoch': 0.08}
{'loss': 0.3159, 'learning_rate': 8.371903520208604e-08, 'epoch': 0.16}
{'eval_loss': 0.3802173137664795, 'eval_accuracy': 0.8589913397860418, 'eval_runtime': 46.1087, 'eval_samples_per_second': 212.866, 'eval_steps_per_second': 3.34, 'epoch': 0.16}
{'loss': 0.3038, 'learning_rate': 7.557855280312907e-08, 'epoch': 0.24}
{'eval_loss': 0.38381442427635193, 'eval_accuracy': 0.8597045338767193, 'eval_runtime': 46.1153, 'eval_samples_per_second': 212.836, 'eval_steps_per_second': 3.339, 'epoch': 0.24}
{'loss': 0.2981, 'learning_rate': 6.74299217731421e-08, 'epoch': 0.33}
{'eval_loss': 0.39187753200531006, 'eval_accuracy': 0.8617422312786551, 'eval_runtime': 46.0955, 'eval_samples_per_second': 212.927, 'eval_steps_per_second': 3.341, 'epoch': 0.33}
{'loss': 0.2979, 'learning_rate': 5.9289439374185136e-08, 'epoch': 0.41}
{'eval_loss': 0.39416825771331787, 'eval_accuracy': 0.8601120733571065, 'eval_runtime': 46.1939, 'eval_samples_per_second': 212.474, 'eval_steps_per_second': 3.334, 'epoch': 0.41}
{'loss': 0.3041, 'learning_rate': 5.114080834419817e-08, 'epoch': 0.49}
{'eval_loss': 0.39381393790245056, 'eval_accuracy': 0.8609271523178808, 'eval_runtime': 46.1092, 'eval_samples_per_second': 212.864, 'eval_steps_per_second': 3.34, 'epoch': 0.49}
{'loss': 0.3122, 'learning_rate': 4.3008474576271185e-08, 'epoch': 0.57}
{'eval_loss': 0.39094462990760803, 'eval_accuracy': 0.8600101884870097, 'eval_runtime': 46.1377, 'eval_samples_per_second': 212.733, 'eval_steps_per_second': 3.338, 'epoch': 0.57}
{'loss': 0.3297, 'learning_rate': 3.485984354628422e-08, 'epoch': 0.65}
{'eval_loss': 0.3906724452972412, 'eval_accuracy': 0.8616403464085584, 'eval_runtime': 46.2006, 'eval_samples_per_second': 212.443, 'eval_steps_per_second': 3.333, 'epoch': 0.65}
{'loss': 0.338, 'learning_rate': 2.671121251629726e-08, 'epoch': 0.73}
{'eval_loss': 0.38998448848724365, 'eval_accuracy': 0.8624554253693326, 'eval_runtime': 46.1088, 'eval_samples_per_second': 212.866, 'eval_steps_per_second': 3.34, 'epoch': 0.73}
{'loss': 0.3557, 'learning_rate': 1.8570730117340286e-08, 'epoch': 0.81}
{'eval_loss': 0.39177224040031433, 'eval_accuracy': 0.8593988792664289, 'eval_runtime': 49.4531, 'eval_samples_per_second': 198.471, 'eval_steps_per_second': 3.114, 'epoch': 0.81}
{'loss': 0.385, 'learning_rate': 1.0430247718383311e-08, 'epoch': 0.9}
{'eval_loss': 0.38946154713630676, 'eval_accuracy': 0.8593988792664289, 'eval_runtime': 51.1685, 'eval_samples_per_second': 191.817, 'eval_steps_per_second': 3.01, 'epoch': 0.9}
{'loss': 0.4133, 'learning_rate': 2.2816166883963493e-09, 'epoch': 0.98}
{'eval_loss': 0.3868817687034607, 'eval_accuracy': 0.8611309220580744, 'eval_runtime': 46.8425, 'eval_samples_per_second': 209.532, 'eval_steps_per_second': 3.288, 'epoch': 0.98}
{'train_runtime': 5020.8729, 'train_samples_per_second': 78.214, 'train_steps_per_second': 2.444, 'train_loss': 0.33702789975580366, 'epoch': 1.0}
{'eval_loss': 0.38998448848724365, 'eval_accuracy': 0.8624554253693326, 'eval_runtime': 47.1812, 'eval_samples_per_second': 208.028, 'eval_steps_per_second': 3.264, 'epoch': 1.0}
{'eval_loss': 0.38998448848724365, 'eval_accuracy': 0.8624554253693326, 'eval_runtime': 47.1812, 'eval_samples_per_second': 208.028, 'eval_steps_per_second': 3.264, 'epoch': 1.0}
Export a QDQ Pytorch model to ONNX¶
We need to enable fake quantization mode from Pytorch.
data = encoded_dataset["train"][1:3]
input_torch = convert_tensor(data, output="torch")
convert_to_onnx(
model_pytorch=model_q,
output_path="model_qat.onnx",
inputs_pytorch=input_torch,
quantization=True,
var_output_seq=False,
)
/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:285: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  inputs, amax.item() / bound, 0,
/home/geantvert/.local/share/virtualenvs/fast_transformer/lib/python3.9/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:291: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])
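For reference, the convert_to_onnx call above with quantization=True presumably boils down to something like the sketch below (illustrative, the exact transformer-deploy implementation may differ): TensorQuantizer modules are switched to PyTorch's fake-quantize operators so that the exporter emits QuantizeLinear/DequantizeLinear nodes, then a regular torch.onnx.export is performed.

from pytorch_quantization import nn as quant_nn

quant_nn.TensorQuantizer.use_fb_fake_quant = True  # export QDQ nodes as QuantizeLinear/DequantizeLinear
torch.onnx.export(
    model_q,
    args=tuple(input_torch.values()),
    f="model_qat.onnx",
    opset_version=13,  # per-channel QuantizeLinear/DequantizeLinear requires opset >= 13
    do_constant_folding=True,
    input_names=list(input_torch.keys()),
    output_names=["output"],
    dynamic_axes={name: {0: "batch", 1: "sequence"} for name in input_torch.keys()},
)
quant_nn.TensorQuantizer.use_fb_fake_quant = False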
del model_q
QATCalibrate.restore()
Benchmark¶
Convert ONNX graph to TensorRT engine¶
engine = build_engine(
runtime=runtime,
onnx_file_path="model_qat.onnx",
logger=trt_logger,
min_shape=(1, max_seq_len),
optimal_shape=(batch_size, max_seq_len),
max_shape=(batch_size, max_seq_len),
workspace_size=10000 * 1024 * 1024,
fp16=True,
int8=True,
)
# same as above, but from the terminal
# !/usr/src/tensorrt/bin/trtexec --onnx=model_qat.onnx --shapes=input_ids:32x256,attention_mask:32x256 --best --workspace=10000 --saveEngine="test.plan"
Prepare input and output buffers¶
context: IExecutionContext = engine.create_execution_context()
context.set_optimization_profile_async(
profile_index=profile_index, stream_handle=torch.cuda.current_stream().cuda_stream
)
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index) # type: List[int], List[int]
data = encoded_dataset["train"][0:batch_size]
input_torch: OD[str, torch.Tensor] = convert_tensor(data=data, output="torch")
input_np: OD[str, np.ndarray] = convert_tensor(data=data, output="np")
Inference on TensorRT¶
We first check that inference is working correctly:
tensorrt_output = infer_tensorrt(
context=context,
host_inputs=input_torch,
input_binding_idxs=input_binding_idxs,
output_binding_idxs=output_binding_idxs,
)
print(tensorrt_output)
[tensor([[ 1.2287,  1.3706, -2.4623],
        [ 2.4742, -0.7816, -1.8427],
        [ 2.4837, -0.3966, -2.2666],
        [ 2.9077, -0.3062, -2.9778],
        [ 2.3437,  0.0488, -2.6377],
        [ 3.7914, -1.1918, -3.1387],
        [-3.6134,  2.7432,  0.8490],
        [ 3.6679, -1.5787, -2.6408],
        [ 1.0155, -1.2787,  0.4250],
        [-3.4514, -0.4434,  4.2748],
        [ 3.5201, -1.1297, -2.8988],
        [-3.0225, -0.4062,  3.8606],
        [-2.7311,  3.5470, -0.4632],
        [-2.0741,  1.6613,  0.5798],
        [-0.4047, -0.8650,  1.6144],
        [ 2.8432, -1.3301, -1.8994],
        [ 3.7722, -0.9103, -3.3070],
        [-2.4204, -2.1432,  4.6537],
        [-3.1179, -1.3207,  4.6400],
        [-1.8794,  4.1075, -1.8630],
        [ 3.7726, -1.2056, -3.0701],
        [ 1.8645,  1.9744, -3.8743],
        [-3.1448, -1.2497,  4.5782],
        [ 3.5385, -0.2421, -3.6629],
        [ 3.7501, -1.6469, -2.7108],
        [-0.6568,  0.9046, -0.0228],
        [-3.2998,  0.0867,  3.3673],
        [-2.1030,  4.0461, -1.6705],
        [-3.7080,  0.4164,  3.4332],
        [ 3.6850, -0.9984, -3.0304],
        [ 3.4525, -0.5405, -3.2981],
        [ 3.6128, -0.9298, -3.0746]], device='cuda:0')]
Measure the accuracy:
infer_trt = lambda inputs: infer_tensorrt(
context=context,
host_inputs=inputs,
input_binding_idxs=input_binding_idxs,
output_binding_idxs=output_binding_idxs,
)
measure_accuracy(infer=infer_trt, tensor_type="torch")
0.8618441161487519
Latency measures:
time_buffer = list()
for _ in range(100):
with track_infer_time(time_buffer):
_ = infer_tensorrt(
context=context,
host_inputs=input_torch,
input_binding_idxs=input_binding_idxs,
output_binding_idxs=output_binding_idxs,
)
print_timings(name="TensorRT (INT-8)", timings=time_buffer)
[TensorRT (INT-8)] mean=18.04ms, sd=2.02ms, min=16.67ms, max=28.88ms, median=17.27ms, 95p=20.04ms, 99p=28.86ms
del engine, context
baseline_model = AutoModelForSequenceClassification.from_pretrained("model_trained_fp16", num_labels=num_labels)
baseline_model = baseline_model.cuda()
baseline_model = baseline_model.eval()
data = encoded_dataset["train"][0:batch_size]
input_torch: OD[str, torch.Tensor] = convert_tensor(data=data, output="torch")
with torch.inference_mode():
for _ in range(30):
_ = baseline_model(**input_torch)
torch.cuda.synchronize()
time_buffer = list()
for _ in range(100):
with track_infer_time(time_buffer):
_ = baseline_model(**input_torch)
torch.cuda.synchronize()
print_timings(name="Pytorch (FP32)", timings=time_buffer)
[Pytorch (FP32)] mean=82.28ms, sd=7.72ms, min=75.77ms, max=120.13ms, median=79.73ms, 95p=97.02ms, 99p=110.68ms
with torch.inference_mode():
with torch.cuda.amp.autocast():
for _ in range(30):
_ = baseline_model(**input_torch)
torch.cuda.synchronize()
time_buffer = []
for _ in range(100):
with track_infer_time(time_buffer):
_ = baseline_model(**input_torch)
torch.cuda.synchronize()
print_timings(name="Pytorch (FP16)", timings=time_buffer)
del baseline_model
[Pytorch (FP16)] mean=58.73ms, sd=1.88ms, min=55.56ms, max=69.69ms, median=58.03ms, 95p=62.28ms, 99p=64.42ms
CPU execution¶
baseline_model = AutoModelForSequenceClassification.from_pretrained("model_trained_fp16", num_labels=num_labels)
baseline_model = baseline_model.eval()
data = encoded_dataset["train"][0:batch_size]
input_torch: OD[str, torch.Tensor] = convert_tensor(data=data, output="torch")
input_torch_cpu = {k: v.to("cpu") for k, v in input_torch.items()}
torch.set_num_threads(os.cpu_count())
with torch.inference_mode():
for _ in range(3):
_ = baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
time_buffer = list()
for _ in range(10):
with track_infer_time(time_buffer):
_ = baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
print_timings(name="Pytorch (FP32) - CPU", timings=time_buffer)
[Pytorch (FP32) - CPU] mean=5141.10ms, sd=654.18ms, min=4095.59ms, max=6229.84ms, median=5347.27ms, 95p=5958.43ms, 99p=6175.56ms
with torch.inference_mode():
with torch.cuda.amp.autocast():
for _ in range(3):
_ = baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
time_buffer = []
for _ in range(10):
with track_infer_time(time_buffer):
_ = baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
print_timings(name="Pytorch (FP16) - CPU", timings=time_buffer)
del baseline_model
[Pytorch (FP16) - CPU] mean=4422.39ms, sd=170.45ms, min=4250.10ms, max=4744.97ms, median=4337.85ms, 95p=4727.72ms, 99p=4741.52ms
Below, we will perform dynamic quantization on CPU.
quantized_baseline_model = AutoModelForSequenceClassification.from_pretrained(
"model_trained_fp16", num_labels=num_labels
)
quantized_baseline_model = quantized_baseline_model.eval()
quantized_baseline_model = torch.quantization.quantize_dynamic(
quantized_baseline_model, {torch.nn.Linear}, dtype=torch.qint8
)
with torch.inference_mode():
for _ in range(3):
_ = quantized_baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
time_buffer = list()
for _ in range(10):
with track_infer_time(time_buffer):
_ = quantized_baseline_model(**input_torch_cpu)
torch.cuda.synchronize()
print_timings(name="Pytorch (INT-8) - CPU", timings=time_buffer)
[Pytorch (INT-8) - CPU] mean=3818.99ms, sd=137.98ms, min=3616.11ms, max=4049.00ms, median=3807.33ms, 95p=4024.45ms, 99p=4044.09ms
TensorRT baseline¶
Below we export our fine-tuned model; the purpose is only to check the performance in mixed precision (FP16, no quantization).
baseline_model = AutoModelForSequenceClassification.from_pretrained("model_trained_fp16", num_labels=num_labels)
baseline_model = baseline_model.cuda()
convert_to_onnx(
baseline_model, output_path="baseline.onnx", inputs_pytorch=input_torch, quantization=False, var_output_seq=False
)
del baseline_model
engine = build_engine(
runtime=runtime,
onnx_file_path="baseline.onnx",
logger=trt_logger,
min_shape=(batch_size, max_seq_len),
optimal_shape=(batch_size, max_seq_len),
max_shape=(batch_size, max_seq_len),
workspace_size=10000 * 1024 * 1024,
fp16=True,
int8=False,
)
input_torch: OD[str, torch.Tensor] = convert_tensor(data=data, output="torch")
context: IExecutionContext = engine.create_execution_context()
context.set_optimization_profile_async(
profile_index=profile_index, stream_handle=torch.cuda.current_stream().cuda_stream
)
input_binding_idxs, output_binding_idxs = get_binding_idxs(engine, profile_index) # type: List[int], List[int]
for _ in range(30):
_ = infer_tensorrt(
context=context,
host_inputs=input_torch,
input_binding_idxs=input_binding_idxs,
output_binding_idxs=output_binding_idxs,
)
time_buffer = list()
for _ in range(100):
with track_infer_time(time_buffer):
_ = infer_tensorrt(
context=context,
host_inputs=input_torch,
input_binding_idxs=input_binding_idxs,
output_binding_idxs=output_binding_idxs,
)
print_timings(name="TensorRT (FP16)", timings=time_buffer)
del engine, context
[TensorRT (FP16)] mean=32.36ms, sd=1.47ms, min=29.92ms, max=38.08ms, median=32.47ms, 95p=34.38ms, 99p=36.50ms
ONNX Runtime baseline¶
ONNX Runtime is the go-to inference solution from Microsoft.
The recent 1.10 version of ONNX Runtime (with TensorRT support) is still a bit buggy on transformer models, which is why we use version 1.9.0 for the measures below.
As before, CPU quantization is dynamic.
The helper functions below set ONNX Runtime to use all available cores and enable all possible optimizations.
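For reference, create_model_for_provider (used below) presumably amounts to something like this sketch (illustrative; the transformer-deploy helper may differ in its details):

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

def create_session(path: str, provider: str) -> InferenceSession:
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL  # enable all graph optimizations
    options.intra_op_num_threads = os.cpu_count()  # use every available core
    return InferenceSession(path, options, providers=[provider])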
optimize_onnx(
onnx_path="baseline.onnx",
onnx_optim_model_path="baseline-optimized.onnx",
fp16=True,
use_cuda=True,
num_attention_heads=12,
hidden_size=768,
architecture="bert",
)
failed in shape inference <class 'AssertionError'>
cpu_quantization(input_model_path="baseline.onnx", output_model_path="baseline-quantized.onnx")
Ignore MatMul due to non constant B: /[MatMul_108]
Ignore MatMul due to non constant B: /[MatMul_113]
Ignore MatMul due to non constant B: /[MatMul_210]
Ignore MatMul due to non constant B: /[MatMul_215]
Ignore MatMul due to non constant B: /[MatMul_312]
Ignore MatMul due to non constant B: /[MatMul_317]
Ignore MatMul due to non constant B: /[MatMul_414]
Ignore MatMul due to non constant B: /[MatMul_419]
Ignore MatMul due to non constant B: /[MatMul_516]
Ignore MatMul due to non constant B: /[MatMul_521]
Ignore MatMul due to non constant B: /[MatMul_618]
Ignore MatMul due to non constant B: /[MatMul_623]
Ignore MatMul due to non constant B: /[MatMul_720]
Ignore MatMul due to non constant B: /[MatMul_725]
Ignore MatMul due to non constant B: /[MatMul_822]
Ignore MatMul due to non constant B: /[MatMul_827]
Ignore MatMul due to non constant B: /[MatMul_924]
Ignore MatMul due to non constant B: /[MatMul_929]
Ignore MatMul due to non constant B: /[MatMul_1026]
Ignore MatMul due to non constant B: /[MatMul_1031]
Ignore MatMul due to non constant B: /[MatMul_1128]
Ignore MatMul due to non constant B: /[MatMul_1133]
Ignore MatMul due to non constant B: /[MatMul_1230]
Ignore MatMul due to non constant B: /[MatMul_1235]
labels = [item["label"] for item in encoded_dataset[validation_key]]
data = encoded_dataset[validation_key][0:batch_size]
inputs_onnx: OD[str, np.ndarray] = convert_tensor(data=data, output="np")
model = create_model_for_provider(path="baseline-optimized.onnx", provider_to_use="CUDAExecutionProvider")
output = model.run(None, inputs_onnx)
data = encoded_dataset["train"][0:batch_size]
inputs_onnx: OD[str, np.ndarray] = convert_tensor(data=data, output="np")
for provider, model_path, benchmark_name, warmup, nb_inference in [
("CUDAExecutionProvider", "baseline.onnx", "ONNX Runtime GPU (FP32)", 10, 100),
("CUDAExecutionProvider", "baseline-optimized.onnx", "ONNX Runtime GPU (FP16)", 10, 100),
("CPUExecutionProvider", "baseline.onnx", "ONNX Runtime CPU (FP32)", 3, 10),
("CPUExecutionProvider", "baseline-optimized.onnx", "ONNX Runtime CPU (FP16)", 3, 10),
("CPUExecutionProvider", "baseline-quantized.onnx", "ONNX Runtime CPU (INT-8)", 3, 10),
]:
model = create_model_for_provider(path=model_path, provider_to_use=provider)
for _ in range(warmup):
_ = model.run(None, inputs_onnx)
time_buffer = []
for _ in range(nb_inference):
with track_infer_time(time_buffer):
_ = model.run(None, inputs_onnx)
print_timings(name=benchmark_name, timings=time_buffer)
del model
[ONNX Runtime GPU (FP32)] mean=82.50ms, sd=11.53ms, min=74.12ms, max=118.73ms, median=77.62ms, 95p=112.51ms, 99p=117.81ms
[ONNX Runtime GPU (FP16)] mean=36.23ms, sd=4.01ms, min=33.47ms, max=55.75ms, median=35.35ms, 95p=46.89ms, 99p=53.95ms
[ONNX Runtime CPU (FP32)] mean=4608.31ms, sd=468.16ms, min=3954.09ms, max=5226.20ms, median=4531.47ms, 95p=5189.71ms, 99p=5218.90ms
[ONNX Runtime CPU (FP16)] mean=3987.95ms, sd=223.00ms, min=3755.70ms, max=4548.57ms, median=3907.74ms, 95p=4375.23ms, 99p=4513.90ms
[ONNX Runtime CPU (INT-8)] mean=3506.39ms, sd=125.07ms, min=3425.41ms, max=3872.95ms, median=3463.65ms, 95p=3713.08ms, 99p=3840.98ms
Measure the accuracy with the ONNX Runtime engine and the CUDA provider:
model = create_model_for_provider(path="baseline.onnx", provider_to_use="CUDAExecutionProvider")
infer_ort = lambda tokens: model.run(None, tokens)
measure_accuracy(infer=infer_ort, tensor_type="np")
0.8665308201732043
model = create_model_for_provider(path="baseline-optimized.onnx", provider_to_use="CUDAExecutionProvider")
infer_ort = lambda tokens: model.run(None, tokens)
measure_accuracy(infer=infer_ort, tensor_type="np")
0.8663270504330107
model = create_model_for_provider(path="baseline-quantized.onnx", provider_to_use="CPUExecutionProvider")
infer_ort = lambda tokens: model.run(None, tokens)
measure_accuracy(infer=infer_ort, tensor_type="np")
0.8604177279673968
del model