Accelerating the GPT-2 model
(and any decoder-based transformer model)
Two trends are ongoing in the NLP ecosystem: bigger language models and better text generation. These trends are game changers (zero-shot, etc.) and bring their own challenges: how do we perform inference with these models? At what cost? On GPU or CPU? In short, how do we leverage them in real life?
That's what we worked on recently, and below you will find the main lessons learned:
- memory IO is by far the main performance bottleneck
- the standard ONNX Runtime API should not be used, but there is an undocumented way of using another ONNX Runtime API which works well (a sketch of both paths follows this list)
- Nvidia TensorRT is always the fastest option on GPU, by a large margin (expected)
- caching the K/V token representations does not bring any inference speedup (unexpected)
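To give a flavor of the first two points, below is a minimal sketch contrasting the standard ONNX Runtime Python API, which copies every output back to host memory as NumPy arrays, with its IO binding interface, which keeps outputs in GPU memory. The model file name (gpt2.onnx) and the input/output names are assumptions for illustration only; this is not necessarily the exact API referred to above.

import numpy as np
import onnxruntime as ort

# Sketch only: assumes GPT-2 has already been exported to ONNX as "gpt2.onnx"
# with an input named "input_ids" and an output named "logits".
sess = ort.InferenceSession("gpt2.onnx", providers=["CUDAExecutionProvider"])
input_ids = np.array([[50256]], dtype=np.int64)  # dummy token ids for illustration

# Standard API: the (large) logits tensor is copied back to host memory at every call.
logits = sess.run(None, {"input_ids": input_ids})[0]

# IO binding: the output stays in GPU memory, avoiding the host round trip.
binding = sess.io_binding()
binding.bind_cpu_input("input_ids", input_ids)
binding.bind_output("logits", device_type="cuda")
sess.run_with_iobinding(binding)
logits_on_gpu = binding.get_outputs()[0]  # OrtValue living on the GPU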
First, let's recall how decoder models work...
Generative language models like GPT-2 produce text one token at a time. The model is auto-regressive, meaning that each generated token is fed back as input to produce the next one. There are mainly two blocks: the language model itself, which produces big tensors, and the decoding algorithm, which consumes those tensors and selects one or more tokens. Keep in mind that these blocks may live on different hardware… (spoiler: it's not a good idea)
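To make these two blocks concrete, here is a minimal greedy decoding loop written with the Hugging Face transformers library. It is a simplified sketch of the process described above (the prompt and the number of generated tokens are illustrative), not the code used for the measurements in this article.

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
with torch.inference_mode():
    for _ in range(10):  # generate 10 tokens, one at a time
        logits = model(input_ids).logits              # block 1, the language model: a big [batch, seq, vocab] tensor
        next_token = logits[:, -1, :].argmax(dim=-1)  # block 2, the decoding algorithm: here a simple greedy argmax
        input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)

print(tokenizer.decode(input_ids[0]))

In this sketch both blocks run on the same device; if they lived on different hardware, the large logits tensor would have to cross the device boundary at every generated token, which is exactly the memory IO cost highlighted above.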
# Diagram of the auto-regressive generation process described above
from IPython.display import Image
Image("../../resources/img/gpt2.png")