Ast utils
Contains the code to patch model AST in RAM.
PatchModule
dataclass
#
PatchModule(module: str, monkey_patch: Dict[str, Tuple[Callable, str]] =
Source code in src/transformer_deploy/QDQModels/ast_utils.py
@dataclass
class PatchModule:
module: str
monkey_patch: Dict[str, Tuple[Callable, str]] = field(default_factory=dict)
def print_code(self):
for class_name, cl in self.monkey_patch.items():
print("---------")
print(class_name)
inspect.getsource(cl)
def restore(self):
model_module = importlib.import_module(name=self.module)
importlib.reload(model_module)
add_init_quantizer(head_node, q_attr_names)
#
Add initialization of quantizer to init()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
head_node |
Module |
node related to a class to optimize |
required |
q_attr_names |
List[str] |
list of quantizer names to init |
required |
Returns:
Type | Description |
---|---|
Module |
modified ast tree |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def add_init_quantizer(head_node: ast.Module, q_attr_names: List[str]) -> ast.Module:
"""
Add initialization of quantizer to __init__()
:param head_node: node related to a class to optimize
:param q_attr_names: list of quantizer names to init
:return: modified ast tree
"""
for node in ast.walk(head_node): # type: ast.FunctionDef
if isinstance(node, ast.FunctionDef) and node.name == "__init__":
for name in q_attr_names:
quantizer = init_quantizer(name)
node.body.append(quantizer)
return head_node
add_qdq_to_class_name(head_node, new_class_name)
#
Change the name of the class to optimize (may help in debugging / error messages)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
head_node |
Module |
node related to the class to optimize |
required |
new_class_name |
str |
new name to use |
required |
Returns:
Type | Description |
---|---|
Module |
the modified ast tree |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def add_qdq_to_class_name(head_node: ast.Module, new_class_name: str) -> ast.Module:
"""
Change the name of the class to optimize (may help in debugging / error messages)
:param head_node: node related to the class to optimize
:param new_class_name: new name to use
:return: the modified ast tree
"""
for node in ast.walk(head_node): # type: ast.ClassDef
if isinstance(node, ast.ClassDef):
node.name = new_class_name
return head_node
add_quant_to_module(module_to_patch, new_module_name)
#
Modify a class to add quantization operations around each torch operation to optimize.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_to_patch |
type |
Pytorch module to patch |
required |
new_module_name |
str |
new name for the module |
required |
Returns:
Type | Description |
---|---|
Module |
modified ast tree |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def add_quant_to_module(module_to_patch: type, new_module_name: str) -> ast.Module:
"""
Modify a class to add quantization operations around each torch operation to optimize.
:param module_to_patch: Pytorch module to patch
:param new_module_name: new name for the module
:return: modified ast tree
"""
source_code = inspect.getsource(module_to_patch)
head = ast.parse(source_code)
head, nodes_to_add = patch_nodes(head)
add_init_quantizer(head_node=head, q_attr_names=nodes_to_add)
head = add_qdq_to_class_name(head_node=head, new_class_name=new_module_name)
return head
add_quantization_to_model(module_path, class_to_patch)
#
Add quantization support to a model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_path |
str |
model module to optimize |
required |
class_to_patch |
Optional[List[str]] |
name of modules to patch, if None it will be auto-detected. |
required |
Returns:
Type | Description |
---|---|
backup of original classes |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def add_quantization_to_model(
module_path: str,
class_to_patch: Optional[List[str]],
):
"""
Add quantization support to a model.
:param module_path: model module to optimize
:param class_to_patch: name of modules to patch, if None it will be auto-detected.
:return: backup of original classes
"""
model_module = importlib.import_module(name=module_path)
load_missing_imports(model_module)
if class_to_patch is None or len(class_to_patch) == 0:
class_to_patch = list_class_to_patch(model_module=model_module)
logging.info(f"modify class {', '.join(class_to_patch)}")
for class_name in class_to_patch:
module_to_patch = getattr(model_module, class_name)
head = add_quant_to_module(module_to_patch=module_to_patch, new_module_name=class_name)
head = ast.fix_missing_locations(head)
module_patched: code = compile(head, filename="<ast modif - transformer deploy>", mode="exec")
# execute the code in the module context so it overrides the original classes and leverage existing imports
exec(module_patched, model_module.__dict__, model_module.__dict__)
contains_op(node)
#
Check if a tree contains some operations to optimize.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
Head of the ast tree |
required |
Returns:
Type | Description |
---|---|
bool |
True if ast tree contains operations to optimize |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def contains_op(node: ast.AST) -> bool:
"""
Check if a tree contains some operations to optimize.
:param node: Head of the ast tree
:return: True if ast tree contains operations to optimize
"""
for node in ast.walk(node):
for op in op_to_quant:
if op.should_patch(node=node):
return True
return False
init_quantizer(name)
#
Generate quantization node initialization to add to the end of init()
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
generated name of the node |
required |
Returns:
Type | Description |
---|---|
Assign |
quantization init ast node |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def init_quantizer(name: str) -> ast.Assign:
"""
Generate quantization node initialization to add to the end of __init__()
:param name: generated name of the node
:return: quantization init ast node
"""
quant_linear = ast.Attribute(value=ast.Name(id="quant_nn", ctx=ast.Load()), attr="QuantLinear", ctx=ast.Load())
default_quant_desc_input = ast.Attribute(value=quant_linear, attr="default_quant_desc_input", ctx=ast.Load())
tensor_quant = ast.Name(id="TensorQuantizer", ctx=ast.Load())
quant_value = ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=name, ctx=ast.Store())
return ast.Assign(
targets=[quant_value],
value=ast.Call(func=tensor_quant, args=[default_quant_desc_input], keywords=[]),
)
list_class_to_patch(model_module)
#
List all classes which contain operations to be optimized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_module |
Pytorch module |
required |
Returns:
Type | Description |
---|---|
List[str] |
the list of module names to be optimized |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def list_class_to_patch(model_module) -> List[str]:
"""
List all classes which contain operations to be optimized.
:param model_module: Pytorch module
:return: the list of module names to be optimized
"""
module_names: List[str] = list()
module_source_code = inspect.getsource(model_module)
head_node = ast.parse(module_source_code)
for node in ast.walk(head_node):
if isinstance(node, ast.ClassDef) and contains_op(node=node):
module_names.append(node.name)
return module_names
load_missing_imports(model_module)
#
Execute some imports in the context of a module. Override Linear layer by its quantized version
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_module |
module to use for the imports |
required |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def load_missing_imports(model_module) -> None:
"""
Execute some imports in the context of a module.
Override Linear layer by its quantized version
:param model_module: module to use for the imports
"""
import_code = """
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn import TensorQuantizer
"""
# remove extra spaces
import_code = inspect.cleandoc(import_code)
# execute the code in the module context
exec(import_code, model_module.__dict__, model_module.__dict__)
patch_nodes(head_node)
#
Replace an operation to optimize by its optimized version. May have to generate some quantization node names.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
head_node |
Module |
ast node to modify |
required |
Returns:
Type | Description |
---|---|
Tuple[ast.Module, List[str]] |
the modified ast tree and the list of generated quantization nodes |
Source code in src/transformer_deploy/QDQModels/ast_utils.py
def patch_nodes(head_node: ast.Module) -> Tuple[ast.Module, List[str]]:
"""
Replace an operation to optimize by its optimized version.
May have to generate some quantization node names.
:param head_node: ast node to modify
:return: the modified ast tree and the list of generated quantization nodes
"""
q_attr_names: List[str] = list()
for node in ast.walk(head_node): # type: ast.Call
for op in op_to_quant:
if op.should_patch(node=node):
quant_names = op.patch(node=node, nb_quant_node=len(q_attr_names))
q_attr_names.extend(quant_names)
return head_node, q_attr_names