Patch
Simple to use wrapper to patch transformer models AST
add_qdq(modules_to_patch=None)
#
Add quantization support to each tested model by modifyin their AST.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
modules_to_patch |
Optional[List[transformer_deploy.QDQModels.ast_utils.PatchModule]] |
list of operator to target |
None |
Source code in src/transformer_deploy/QDQModels/patch.py
def add_qdq(modules_to_patch: Optional[List[PatchModule]] = None) -> None:
"""
Add quantization support to each tested model by modifyin their AST.
:param modules_to_patch: list of operator to target
"""
if modules_to_patch is None:
modules_to_patch = tested_models
for patch in modules_to_patch:
logging.info(f"add quantization to module {patch.module}")
patch_model(patch)
patch_model(patch)
#
Perform modifications to model to make it work with ONNX export and quantization.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
patch |
PatchModule |
an object containing all the information to perform a modification |
required |
Source code in src/transformer_deploy/QDQModels/patch.py
def patch_model(patch: PatchModule) -> None:
"""
Perform modifications to model to make it work with ONNX export and quantization.
:param patch: an object containing all the information to perform a modification
"""
add_quantization_to_model(module_path=patch.module, class_to_patch=None)
model_module = importlib.import_module(patch.module)
for target, (modified_object, object_name) in patch.monkey_patch.items():
source_code = inspect.getsource(modified_object)
source_code += f"\n{target} = {object_name}"
exec(source_code, model_module.__dict__, model_module.__dict__)
remove_qdq(modules_to_patch=None)
#
Restore AST of modified modules.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
modules_to_patch |
Optional[List[transformer_deploy.QDQModels.ast_utils.PatchModule]] |
list of operator to target |
None |
Source code in src/transformer_deploy/QDQModels/patch.py
def remove_qdq(modules_to_patch: Optional[List[PatchModule]] = None) -> None:
"""
Restore AST of modified modules.
:param modules_to_patch: list of operator to target
"""
if modules_to_patch is None:
modules_to_patch = tested_models
for patch in modules_to_patch:
logging.info(f"restore module {patch.module}")
patch.restore()