Ast operator patch
Contains code to match and patch specific AST patterns.
Patch2ArgsNode (PatchNode)
#
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class Patch2ArgsNode(PatchNode):
    """Wrap the 2 first positional arguments of a torch.<op>(...) call with quantizers."""

    def __init__(self, op: str):
        """
        Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q2(b))
        :param op: operator to match
        """
        self.torch_op_to_quantize = op

    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node should be patched
        :param node: node to check
        :return: return True if the node is a call to torch.<op> having at least 2 positional arguments
        """
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "torch"
            and node.func.attr == self.torch_op_to_quantize
            # patch() rewrites args[0] and args[1]: a call with fewer positional arguments
            # (for instance torch.op(a, other=b)) would make it fail with an IndexError
            and len(node.args) >= 2
        )

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Patch node by adding a quantizer node around each of the 2 first positional arguments
        :param node: node to patch, its arguments are replaced in place
        :param kwargs: additional parameters, nb_quant_node (number of existing quantizer node) is required
        :return: return list of generated quantizer node names, in argument order
        """
        assert "nb_quant_node" in kwargs, "missing nb_quant_node paramter"
        nb_quant_node: int = kwargs["nb_quant_node"]
        q_attr_names = list()
        for index in range(2):  # only apply transfo to the 2 first args
            arg = node.args[index]
            # each argument gets its own name, numbered after the already existing quantizers
            q_name = self.get_quant_name(nb_quant_node + len(q_attr_names))
            q_attr_names.append(q_name)
            node.args[index] = self._wrap_attr(q_name, arg)
        return q_attr_names
__init__(self, op)
special
#
Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q2(b)), each argument being wrapped by its own quantizer
Parameters:
Name | Type | Description | Default |
---|---|---|---|
op |
str |
operator to match |
required |
patch(self, node, **kwargs)
#
Patch node by adding quantizer nodes around the operator provided during the init
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to patch |
required |
kwargs |
additional parameters, like nb_quant_node for the number of existing quantizer node |
{} |
Returns:
Type | Description |
---|---|
List[str] |
return list of generated quantizer node names |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Wrap the 2 first positional arguments of the call with freshly named quantizers.
    :param node: call node to patch, its 2 first arguments are replaced in place
    :param kwargs: must contain nb_quant_node, the number of existing quantizer node
    :return: names of the 2 generated quantizers, in argument order
    """
    assert "nb_quant_node" in kwargs, "missing nb_quant_node paramter"
    first_id: int = kwargs["nb_quant_node"]
    generated: List[str] = []
    # only the 2 first arguments are quantized, any other argument is left untouched
    for offset in (0, 1):
        original_arg = node.args[offset]
        quantizer_name = self.get_quant_name(first_id + offset)
        generated.append(quantizer_name)
        node.args[offset] = self._wrap_attr(quantizer_name, original_arg)
    return generated
should_patch(self, node)
#
Check if a node should be patched
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to check |
required |
Returns:
Type | Description |
---|---|
bool |
return True if it matches the operator provided during the init |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
PatchAdd2ArgsNode (PatchNode)
#
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchAdd2ArgsNode(PatchNode):
    """Quantize both operands of an addition used as the single argument of a call."""

    def __init__(self, op: str):
        """
        Patch source code in the form x.op(a + b) to x.op(self.q1(a) + self.q2(b))
        :param op: name of the called attribute to match
        """
        self.torch_op_to_quantize = op

    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node should be patched
        :param node: node to check
        :return: True for a call of the attribute given to __init__ whose only argument is an addition
        """
        if not isinstance(node, ast.Call):
            return False
        callee = node.func
        if not isinstance(callee, ast.Attribute):
            return False
        if callee.attr != self.torch_op_to_quantize:
            return False
        call_args = node.args
        if not isinstance(call_args, list) or len(call_args) != 1:
            return False
        only_arg = call_args[0]
        return isinstance(only_arg, ast.BinOp) and isinstance(only_arg.op, ast.Add)

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Wrap the left and right operands of the addition with their own quantizer
        :param node: node to patch, the operands of its first argument are replaced in place
        :param kwargs: must contain nb_quant_node, the number of existing quantizer node
        :return: names of the generated quantizers, left operand first
        """
        assert "nb_quant_node" in kwargs, "missing nb_quant_node paramter"
        first_id: int = kwargs["nb_quant_node"]
        names = [self.get_quant_name(first_id), self.get_quant_name(first_id + 1)]
        addition = node.args[0]
        addition.left = self._wrap_attr(names[0], addition.left)
        addition.right = self._wrap_attr(names[1], addition.right)
        return names
__init__(self, op)
special
#
Patch source code in the form torch.op(a + b) to torch.op(self.q1(a) + self.q2(b)), each operand being wrapped by its own quantizer
Parameters:
Name | Type | Description | Default |
---|---|---|---|
op |
str |
operator to match |
required |
patch(self, node, **kwargs)
#
Patch node by adding quantizer nodes around the operator provided during the init
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to patch |
required |
kwargs |
additional parameters, like nb_quant_node for the number of existing quantizer node |
{} |
Returns:
Type | Description |
---|---|
List[str] |
return list of generated quantizer node names |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Wrap the left and right operands of the addition with their own quantizer.
    :param node: node to patch, the operands of its first argument are replaced in place
    :param kwargs: must contain nb_quant_node, the number of existing quantizer node
    :return: names of the generated quantizers, left operand first
    """
    assert "nb_quant_node" in kwargs, "missing nb_quant_node paramter"
    first_id: int = kwargs["nb_quant_node"]
    # names are numbered right after the already existing quantizers
    names = [self.get_quant_name(first_id), self.get_quant_name(first_id + 1)]
    addition = node.args[0]
    addition.left = self._wrap_attr(names[0], addition.left)
    addition.right = self._wrap_attr(names[1], addition.right)
    return names
should_patch(self, node)
#
Check if a node should be patched
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to check |
required |
Returns:
Type | Description |
---|---|
bool |
return True if it matches the operator provided during the init |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def should_patch(self, node: ast.AST) -> bool:
    """
    Check if a node should be patched.
    :param node: node to check
    :return: True for a call of the attribute given to __init__ whose only argument is an addition
    """
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return False
    if callee.attr != self.torch_op_to_quantize:
        return False
    call_args = node.args
    if not isinstance(call_args, list) or len(call_args) != 1:
        return False
    # the single argument has to be a binary "+" expression
    only_arg = call_args[0]
    return isinstance(only_arg, ast.BinOp) and isinstance(only_arg.op, ast.Add)
PatchLayer (PatchNode)
#
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchLayer(PatchNode):
    """Redirect calls of the form origin_module.origin_layer(...) to target_module.target_layer(...)."""

    def __init__(self, origin_module: str, origin_layer: str, target_module: str, target_layer: str):
        """
        Patch source code in the form a.b(...) to c.d(...)
        :param origin_module: module to patch
        :param origin_layer: layer/method to patch
        :param target_module: new module to use
        :param target_layer: new layer/method to use
        """
        self.origin_module, self.origin_layer = origin_module, origin_layer
        self.target_module, self.target_layer = target_module, target_layer

    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node should be patched
        :param node: node to check
        :return: True if the node is a call to origin_module.origin_layer
        """
        if not isinstance(node, ast.Call):
            return False
        callee = node.func
        if not (isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name)):
            return False
        return callee.value.id == self.origin_module and callee.attr == self.origin_layer

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Rename in place the called module and layer to the target ones
        :param node: node to patch
        :param kwargs: additional parameters, not used here
        :return: an empty list, no quantizer node is generated
        """
        callee = node.func
        callee.value.id = self.target_module
        callee.attr = self.target_layer
        return []
__init__(self, origin_module, origin_layer, target_module, target_layer)
special
#
Patch source code in the form a.b(...) to c.d(...)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
origin_module |
str |
module to patch |
required |
origin_layer |
str |
layer/method to patch |
required |
target_module |
str |
new module to use |
required |
target_layer |
str |
new layer/method to use |
required |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __init__(self, origin_module: str, origin_layer: str, target_module: str, target_layer: str):
    """
    Patch source code in the form a.b(...) to c.d(...)
    :param origin_module: module to patch
    :param origin_layer: layer/method to patch
    :param target_module: new module to use
    :param target_layer: new layer/method to use
    """
    # what to look for in the source code
    self.origin_module, self.origin_layer = origin_module, origin_layer
    # what to put in place of it
    self.target_module, self.target_layer = target_module, target_layer
patch(self, node, **kwargs)
#
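Patch node by replacing, in place, the origin module and layer of the call with the target module and layer provided during the init; no quantizer node is added, so the returned list is always empty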
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to patch |
required |
kwargs |
additional parameters, like nb_quant_node for the number of existing quantizer node |
{} |
Returns:
Type | Description |
---|---|
List[str] |
return list of generated quantizer node names |
should_patch(self, node)
#
Check if a node should be patched
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to check |
required |
Returns:
Type | Description |
---|---|
bool |
return True if the node is a call to the origin module and layer provided during the init |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
PatchNode
#
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchNode(object, metaclass=abc.ABCMeta):
    """
    Base class of the AST patchers.
    A patcher tells if a node matches its pattern (should_patch) and rewrites the matching node (patch).
    The metaclass is declared with the Python 3 syntax: a `__metaclass__` class attribute is ignored by
    Python 3, which would leave the abstract methods below unenforced.
    """

    # name of the operator to quantize, used by get_quant_name to build the quantizer names
    torch_op_to_quantize: str

    @abc.abstractmethod
    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node should be patched
        :param node: node to check
        :return: return True if it matches the operator provided during the __init__
        """
        raise NotImplementedError("to implement")

    @abc.abstractmethod
    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Patch node by adding quantizer nodes around the operator provided during the __init__
        :param node: node to patch
        :param kwargs: additional parameters, like nb_quant_node for the number of existing quantizer node
        :return: return list of generated quantizer node names
        """
        raise NotImplementedError("to implement")

    @staticmethod
    def _wrap_attr(quantizer_name: str, tensor_var: ast.expr) -> ast.Call:
        """
        Generate quantization wrapping each attribute of a torch operation to optimize (matmul, add, etc.)
        The generated expression is the call self.<quantizer_name>(<tensor_var>).
        :param quantizer_name: generated quantization name
        :param tensor_var: the variable to wrap
        :return: the ast tree to replace the original variable
        """
        return ast.Call(
            func=ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=quantizer_name, ctx=ast.Load()),
            args=[tensor_var],
            keywords=[],
        )

    def get_quant_name(self, node_id: int) -> str:
        """
        Build the name of a quantizer from the lower-cased operator name and a node id
        :param node_id: id of the quantizer node
        :return: name in the form <op>_quantizer_<node_id>
        """
        return f"{self.torch_op_to_quantize.lower()}_quantizer_{node_id}"
__metaclass__ (type)
#
Metaclass for defining Abstract Base Classes (ABCs).
Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class ABCMeta(type):
    """Metaclass used to declare Abstract Base Classes (ABCs).

    A class created with this metaclass is an ABC. It can be subclassed
    directly, in which case it behaves as a mix-in class. Unrelated
    concrete classes (built-in ones included) and unrelated ABCs can
    also be registered as 'virtual subclasses': issubclass() then
    reports them and their descendants as subclasses of the registering
    ABC, although that ABC does not appear in their MRO (Method
    Resolution Order) and its method implementations are not callable
    from them (not even via super()).
    """

    def __new__(mcls, name, bases, namespace, **kwargs):
        """Create the class, then initialize its ABC state."""
        created = super().__new__(mcls, name, bases, namespace, **kwargs)
        _abc_init(created)
        return created

    def register(cls, subclass):
        """Declare `subclass` as a virtual subclass of this ABC.

        The subclass is returned, so this can be used as a class decorator.
        """
        return _abc_register(cls, subclass)

    def __instancecheck__(cls, instance):
        """Implement isinstance(instance, cls)."""
        return _abc_instancecheck(cls, instance)

    def __subclasscheck__(cls, subclass):
        """Implement issubclass(subclass, cls)."""
        return _abc_subclasscheck(cls, subclass)

    def _dump_registry(cls, file=None):
        """Print the ABC registry and caches (debug helper)."""
        print(f"Class: {cls.__module__}.{cls.__qualname__}", file=file)
        print(f"Inv. counter: {get_cache_token()}", file=file)
        registry, cache, negative_cache, negative_cache_version = _get_dump(cls)
        dumped = (
            ("_abc_registry", registry),
            ("_abc_cache", cache),
            ("_abc_negative_cache", negative_cache),
            ("_abc_negative_cache_version", negative_cache_version),
        )
        for label, value in dumped:
            print(f"{label}: {value!r}", file=file)

    def _abc_registry_clear(cls):
        """Empty the registry (for debugging or testing)."""
        _reset_registry(cls)

    def _abc_caches_clear(cls):
        """Empty the caches (for debugging or testing)."""
        _reset_caches(cls)
patch(self, node, **kwargs)
#
Patch node by adding quantizer nodes around the operator provided during the init
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to patch |
required |
kwargs |
additional parameters, like nb_quant_node for the number of existing quantizer node |
{} |
Returns:
Type | Description |
---|---|
List[str] |
return list of generated quantizer node names |
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
@abc.abstractmethod
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Patch node by adding quantizer nodes around the operator provided during the __init__
    :param node: node to patch
    :param kwargs: additional parameters, like nb_quant_node for the number of existing quantizer node
    :return: return list of generated quantizer node names
    :raises NotImplementedError: always, the implementation has to be provided by the subclass
    """
    # NotImplementedError is still an Exception, so code catching Exception keeps working
    raise NotImplementedError("to implement")
should_patch(self, node)
#
Check if a node should be patched
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node |
AST |
node to check |
required |
Returns:
Type | Description |
---|---|
bool |
return True if it matches the operator provided during the init |