Skip to content

AST operator patch

Contains code to match and patch specific AST patterns.

Patch2ArgsNode (PatchNode) #

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class Patch2ArgsNode(PatchNode):
    def __init__(self, op: str):
        """
        Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q2(b)),
        where q1 and q2 are two distinct quantizer attributes named by `get_quant_name`.
        :param op: name of the `torch` function to match
        """
        self.torch_op_to_quantize = op

    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node is a call `torch.<op>(...)` with at least two positional arguments.
        :param node: node to check
        :return: True if `patch` can be applied to this node
        """
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "torch"
            and node.func.attr == self.torch_op_to_quantize
            # `patch` wraps args[0] and args[1]: calls passing operands by keyword
            # (or with a single positional arg) would otherwise crash it with IndexError
            and len(node.args) >= 2
        )

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Wrap the first two positional arguments of the call in quantizer calls, in place.
        :param node: call node accepted by `should_patch`
        :param kwargs: must contain `nb_quant_node`, the number of quantizer nodes already generated,
            used to give the new quantizers unique names
        :return: names of the two generated quantizer attributes, in argument order
        """
        assert "nb_quant_node" in kwargs, "missing nb_quant_node parameter"
        nb_quant_node: int = kwargs["nb_quant_node"]
        # only the first two args are operands; any extra positional args are left untouched
        q_attr_names = [self.get_quant_name(nb_quant_node + offset) for offset in range(2)]
        for index, q_name in enumerate(q_attr_names):
            node.args[index] = self._wrap_attr(q_name, node.args[index])
        return q_attr_names

__init__(self, op) special #

Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q2(b)), where q1 and q2 are two distinct quantizers.

Parameters:

Name Type Description Default
op str

operator to match

required
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __init__(self, op: str):
    """
    Patch source code in the form torch.op(a, b) to torch.op(self.q1(a), self.q2(b)),
    where q1 and q2 are two distinct quantizer attributes.
    :param op: name of the `torch` function to match
    """
    self.torch_op_to_quantize = op

patch(self, node, **kwargs) #

Patch node by adding quantizer nodes around the operator provided during the init

Parameters:

Name Type Description Default
node AST

node to patch

required
kwargs

additional parameters; nb_quant_node (required) is the number of quantizer nodes already created, used to number the new ones

{}

Returns:

Type Description
List[str]

return list of generated quantizer node names

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Wrap the first two positional arguments of the call in quantizer calls, in place.
    :param node: call node accepted by `should_patch`
    :param kwargs: must contain `nb_quant_node`, the number of quantizer nodes already generated,
        used to give the new quantizers unique names
    :return: names of the two generated quantizer attributes, in argument order
    """
    assert "nb_quant_node" in kwargs, "missing nb_quant_node parameter"
    nb_quant_node: int = kwargs["nb_quant_node"]
    # only the first two args are operands; any extra positional args are left untouched
    q_attr_names = [self.get_quant_name(nb_quant_node + offset) for offset in range(2)]
    for index, q_name in enumerate(q_attr_names):
        node.args[index] = self._wrap_attr(q_name, node.args[index])
    return q_attr_names

should_patch(self, node) #

Check if a node should be patched

Parameters:

Name Type Description Default
node AST

node to check

required

Returns:

Type Description
bool

return True if it matches the operator provided during the init

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def should_patch(self, node: ast.AST) -> bool:
    """
    Check if a node is a call `torch.<op>(...)` with at least two positional arguments.
    :param node: node to check
    :return: True if `patch` can be applied to this node
    """
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id == "torch"
        and node.func.attr == self.torch_op_to_quantize
        # `patch` wraps args[0] and args[1]: calls passing operands by keyword
        # (or with a single positional arg) would otherwise crash it with IndexError
        and len(node.args) >= 2
    )

PatchAdd2ArgsNode (PatchNode) #

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchAdd2ArgsNode(PatchNode):
    def __init__(self, op: str):
        """
        Patch source code in the form x.op(a + b) to x.op(self.q1(a) + self.q2(b)),
        where q1 and q2 are two distinct quantizer attributes named by `get_quant_name`.
        The receiver `x` is not checked: any `<expr>.op(a + b)` call matches, not only `torch.op(...)`.
        :param op: name of the called attribute to match
        """
        self.torch_op_to_quantize = op

    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node is a call `<expr>.<op>(a + b)` whose single positional argument is an addition.
        :param node: node to check
        :return: True if it matches the operator provided during the init
        """
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == self.torch_op_to_quantize
            and isinstance(node.args, list)
            and len(node.args) == 1
            and isinstance(node.args[0], ast.BinOp)
            and isinstance(node.args[0].op, ast.Add)
        )

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Wrap both operands of the addition passed to the call in quantizer calls, in place.
        :param node: call node accepted by `should_patch`
        :param kwargs: must contain `nb_quant_node`, the number of quantizer nodes already generated,
            used to give the new quantizers unique names
        :return: names of the generated quantizer attributes, [left operand, right operand]
        """
        assert "nb_quant_node" in kwargs, "missing nb_quant_node parameter"
        nb_quant_node: int = kwargs["nb_quant_node"]
        # `+` is left-associative, so for `a + b + c` the left operand is the sub-sum `a + b`
        addition = node.args[0]
        left_name = self.get_quant_name(nb_quant_node)
        right_name = self.get_quant_name(nb_quant_node + 1)
        addition.left = self._wrap_attr(left_name, addition.left)
        addition.right = self._wrap_attr(right_name, addition.right)
        return [left_name, right_name]

__init__(self, op) special #

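Patch source code in the form x.op(a + b) to x.op(self.q1(a) + self.q2(b)), where q1 and q2 are two distinct quantizers; the receiver x is not checked, so any x.op(a + b) call matches, not only torch.op(...).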

Parameters:

Name Type Description Default
op str

operator to match

required
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __init__(self, op: str):
    """
    Patch source code in the form x.op(a + b) to x.op(self.q1(a) + self.q2(b)),
    where q1 and q2 are two distinct quantizer attributes.
    The receiver `x` is not checked: any `<expr>.op(a + b)` call matches, not only `torch.op(...)`.
    :param op: name of the called attribute to match
    """
    self.torch_op_to_quantize = op

patch(self, node, **kwargs) #

Patch node by adding quantizer nodes around the operator provided during the init

Parameters:

Name Type Description Default
node AST

node to patch

required
kwargs

additional parameters; nb_quant_node (required) is the number of quantizer nodes already created, used to number the new ones

{}

Returns:

Type Description
List[str]

return list of generated quantizer node names

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Wrap both operands of the addition passed to the call in quantizer calls, in place.
    :param node: call node accepted by `should_patch`
    :param kwargs: must contain `nb_quant_node`, the number of quantizer nodes already generated,
        used to give the new quantizers unique names
    :return: names of the generated quantizer attributes, [left operand, right operand]
    """
    assert "nb_quant_node" in kwargs, "missing nb_quant_node parameter"
    nb_quant_node: int = kwargs["nb_quant_node"]
    # `+` is left-associative, so for `a + b + c` the left operand is the sub-sum `a + b`
    addition = node.args[0]
    left_name = self.get_quant_name(nb_quant_node)
    right_name = self.get_quant_name(nb_quant_node + 1)
    addition.left = self._wrap_attr(left_name, addition.left)
    addition.right = self._wrap_attr(right_name, addition.right)
    return [left_name, right_name]

should_patch(self, node) #

Check if a node should be patched

Parameters:

Name Type Description Default
node AST

node to check

required

Returns:

Type Description
bool

return True if it matches the operator provided during the init

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def should_patch(self, node: ast.AST) -> bool:
    """
    Tell whether `node` is a call `<expr>.<op>(a + b)`: an attribute call named after the
    configured operator, taking exactly one positional argument that is an addition.
    :param node: candidate AST node
    :return: True when the node should be rewritten
    """
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
        return False
    if node.func.attr != self.torch_op_to_quantize:
        return False
    call_args = node.args
    if not isinstance(call_args, list) or len(call_args) != 1:
        return False
    only_arg = call_args[0]
    return isinstance(only_arg, ast.BinOp) and isinstance(only_arg.op, ast.Add)

PatchLayer (PatchNode) #

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchLayer(PatchNode):
    def __init__(self, origin_module: str, origin_layer: str, target_module: str, target_layer: str):
        """
        Rename calls `origin_module.origin_layer(...)` into `target_module.target_layer(...)`.
        Only the callee changes; the call arguments are kept as they are.
        :param origin_module: name of the module (plain identifier) to look for
        :param origin_layer: layer/method called on that module
        :param target_module: module name written in place of origin_module
        :param target_layer: layer/method name written in place of origin_layer
        """
        # pattern searched in the source
        self.origin_module, self.origin_layer = origin_module, origin_layer
        # replacement pattern
        self.target_module, self.target_layer = target_module, target_layer

    def should_patch(self, node: ast.AST) -> bool:
        """
        Tell whether `node` is a call `origin_module.origin_layer(...)` where the module is a plain name.
        :param node: candidate AST node
        :return: True when the node should be renamed
        """
        if not isinstance(node, ast.Call):
            return False
        callee = node.func
        if not (isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name)):
            return False
        return callee.value.id == self.origin_module and callee.attr == self.origin_layer

    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Rename the called module and layer in place; kwargs are accepted for interface
        compatibility but unused.
        :param node: call node accepted by `should_patch`
        :return: an empty list, since a plain rename creates no quantizer
        """
        callee = node.func
        callee.value.id = self.target_module
        callee.attr = self.target_layer
        return []

__init__(self, origin_module, origin_layer, target_module, target_layer) special #

Patch source code in the form a.b(...) to c.d(...)

Parameters:

Name Type Description Default
origin_module str

module to patch

required
origin_layer str

layer/method to patch

required
target_module str

new module to use

required
target_layer str

new layer/method to use

required
Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __init__(self, origin_module: str, origin_layer: str, target_module: str, target_layer: str):
    """
    Store the call pattern to find (`origin_module.origin_layer`) and its replacement
    (`target_module.target_layer`).
    :param origin_module: name of the module to look for
    :param origin_layer: layer/method called on that module
    :param target_module: module name written in place of origin_module
    :param target_layer: layer/method name written in place of origin_layer
    """
    # pattern searched in the source
    self.origin_module, self.origin_layer = origin_module, origin_layer
    # replacement pattern
    self.target_module, self.target_layer = target_module, target_layer

patch(self, node, **kwargs) #

Patch node by renaming the called module and layer to the targets provided during the init; no quantizer node is added

Parameters:

Name Type Description Default
node AST

node to patch

required
kwargs

additional parameters (accepted for interface compatibility, unused by this class)

{}

Returns:

Type Description
List[str]

always an empty list, since no quantizer node is generated

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Rename the called module and layer in place; kwargs are accepted for interface
    compatibility but unused.
    :param node: call node accepted by `should_patch`
    :return: an empty list, since a plain rename creates no quantizer
    """
    callee = node.func
    callee.value.id = self.target_module
    callee.attr = self.target_layer
    return []

should_patch(self, node) #

Check if a node should be patched

Parameters:

Name Type Description Default
node AST

node to check

required

Returns:

Type Description
bool

return True if the node is a call origin_module.origin_layer(...) matching the values provided during the init

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def should_patch(self, node: ast.AST) -> bool:
    """
    Tell whether `node` is a call `origin_module.origin_layer(...)` where the module is a plain name.
    :param node: candidate AST node
    :return: True when the node should be renamed
    """
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    if not (isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name)):
        return False
    return callee.value.id == self.origin_module and callee.attr == self.origin_layer

PatchNode #

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
class PatchNode(metaclass=abc.ABCMeta):
    """
    Base class of AST rewriting rules: `should_patch` tells whether a node matches,
    `patch` rewrites it in place and returns the names of the quantizer attributes it introduced.
    """

    # `metaclass=` (not the Python 2 `__metaclass__` attribute, which Python 3 ignores)
    # is required for the @abc.abstractmethod markers below to be enforced at instantiation.

    # Operator name used to build quantizer attribute names in `get_quant_name`;
    # subclasses that generate quantizers must set it.
    torch_op_to_quantize: str

    @abc.abstractmethod
    def should_patch(self, node: ast.AST) -> bool:
        """
        Check if a node should be patched
        :param node: node to check
        :return: return True if it matches the operator provided during the __init__
        """
        raise NotImplementedError("to implement")

    @abc.abstractmethod
    def patch(self, node: ast.AST, **kwargs) -> List[str]:
        """
        Patch node by adding quantizer nodes around the operator provided during the __init__
        :param node: node to patch
        :param kwargs: additional parameters, like nb_quant_node for the number of existing quantizer nodes
        :return: return list of generated quantizer node names
        """
        raise NotImplementedError("to implement")

    @staticmethod
    def _wrap_attr(quantizer_name: str, tensor_var: ast.expr) -> ast.Call:
        """
        Generate quantization wrapping each attribute of a torch operation to optimize (matmul, add, etc.)
        :param quantizer_name: generated quantization name
        :param tensor_var: the variable to wrap
        :return: the ast tree `self.<quantizer_name>(<tensor_var>)` to replace the original variable
        """
        return ast.Call(
            func=ast.Attribute(value=ast.Name(id="self", ctx=ast.Load()), attr=quantizer_name, ctx=ast.Load()),
            args=[tensor_var],
            keywords=[],
        )

    def get_quant_name(self, node_id: int) -> str:
        """
        Build a quantizer attribute name from the operator name and a counter.
        :param node_id: index of the quantizer, should be unique within the patched module
        :return: name in the form `<lowercased op>_quantizer_<node_id>`
        """
        return f"{self.torch_op_to_quantize.lower()}_quantizer_{node_id}"

__metaclass__ (type) #

Metaclass for defining Abstract Base Classes (ABCs).

Use this metaclass to create an ABC. An ABC can be subclassed directly, and then acts as a mix-in class. You can also register unrelated concrete classes (even built-in classes) and unrelated ABCs as 'virtual subclasses' -- these and their descendants will be considered subclasses of the registering ABC by the built-in issubclass() function, but the registering ABC won't show up in their MRO (Method Resolution Order) nor will method implementations defined by the registering ABC be callable (not even via super()).

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
# NOTE(review): this listing appears to be CPython's own `abc.ABCMeta` (the variant backed by the C
# `_abc` module). The doc generator presumably rendered it because `PatchNode` sets
# `__metaclass__ = abc.ABCMeta`. The `_abc_*`, `_get_dump`, `_reset_*` and `get_cache_token`
# helpers it calls are not defined in this view; changes belong upstream, not here.
class ABCMeta(type):
    """Metaclass for defining Abstract Base Classes (ABCs).

    Use this metaclass to create an ABC.  An ABC can be subclassed
    directly, and then acts as a mix-in class.  You can also register
    unrelated concrete classes (even built-in classes) and unrelated
    ABCs as 'virtual subclasses' -- these and their descendants will
    be considered subclasses of the registering ABC by the built-in
    issubclass() function, but the registering ABC won't show up in
    their MRO (Method Resolution Order) nor will method
    implementations defined by the registering ABC be callable (not
    even via super()).
    """
    def __new__(mcls, name, bases, namespace, **kwargs):
        """Create the class via `type.__new__`, then initialise its ABC state with `_abc_init`."""
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        _abc_init(cls)
        return cls

    def register(cls, subclass):
        """Register a virtual subclass of an ABC.

        Returns the subclass, to allow usage as a class decorator.
        """
        return _abc_register(cls, subclass)

    def __instancecheck__(cls, instance):
        """Override for isinstance(instance, cls)."""
        return _abc_instancecheck(cls, instance)

    def __subclasscheck__(cls, subclass):
        """Override for issubclass(subclass, cls)."""
        return _abc_subclasscheck(cls, subclass)

    def _dump_registry(cls, file=None):
        """Debug helper to print the ABC registry."""
        print(f"Class: {cls.__module__}.{cls.__qualname__}", file=file)
        print(f"Inv. counter: {get_cache_token()}", file=file)
        # `_get_dump` returns a 4-tuple, unpacked across two lines below
        (_abc_registry, _abc_cache, _abc_negative_cache,
         _abc_negative_cache_version) = _get_dump(cls)
        print(f"_abc_registry: {_abc_registry!r}", file=file)
        print(f"_abc_cache: {_abc_cache!r}", file=file)
        print(f"_abc_negative_cache: {_abc_negative_cache!r}", file=file)
        print(f"_abc_negative_cache_version: {_abc_negative_cache_version!r}",
              file=file)

    def _abc_registry_clear(cls):
        """Clear the registry (for debugging or testing)."""
        _reset_registry(cls)

    def _abc_caches_clear(cls):
        """Clear the caches (for debugging or testing)."""
        _reset_caches(cls)

__instancecheck__(cls, instance) special #

Override for isinstance(instance, cls).

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __instancecheck__(cls, instance):
    """Override for isinstance(instance, cls)."""
    # Delegates entirely to `_abc_instancecheck` (not defined in this view).
    return _abc_instancecheck(cls, instance)

__new__(mcls, name, bases, namespace, **kwargs) special staticmethod #

Create and return a new object. See help(type) for accurate signature.

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __new__(mcls, name, bases, namespace, **kwargs):
    """Create the class via the parent metaclass, then initialise its ABC state with `_abc_init`."""
    cls = super().__new__(mcls, name, bases, namespace, **kwargs)
    # `_abc_init` is not defined in this view; presumably it sets up the ABC registry/caches — verify upstream.
    _abc_init(cls)
    return cls

__subclasscheck__(cls, subclass) special #

Override for issubclass(subclass, cls).

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def __subclasscheck__(cls, subclass):
    """Override for issubclass(subclass, cls)."""
    # Delegates entirely to `_abc_subclasscheck` (not defined in this view).
    return _abc_subclasscheck(cls, subclass)

register(cls, subclass) #

Register a virtual subclass of an ABC.

Returns the subclass, to allow usage as a class decorator.

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
def register(cls, subclass):
    """Register a virtual subclass of an ABC.

    Returns the subclass, to allow usage as a class decorator.
    """
    # Delegates entirely to `_abc_register` (not defined in this view).
    return _abc_register(cls, subclass)

patch(self, node, **kwargs) #

Patch node by adding quantizer nodes around the operator provided during the init

Parameters:

Name Type Description Default
node AST

node to patch

required
kwargs

additional parameters, like nb_quant_node for the number of quantizer nodes already created

{}

Returns:

Type Description
List[str]

return list of generated quantizer node names

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
@abc.abstractmethod
def patch(self, node: ast.AST, **kwargs) -> List[str]:
    """
    Patch node by adding quantizer nodes around the operator provided during the __init__
    :param node: node to patch
    :param kwargs: additional parameters, like nb_quant_node for the number of existing quantizer nodes
    :return: return list of generated quantizer node names
    :raises NotImplementedError: always; concrete subclasses must override this method
    """
    # NotImplementedError (a subclass of Exception) is the conventional signal for a missing override
    raise NotImplementedError("to implement")

should_patch(self, node) #

Check if a node should be patched

Parameters:

Name Type Description Default
node AST

node to check

required

Returns:

Type Description
bool

return True if it matches the operator provided during the init

Source code in src/transformer_deploy/QDQModels/ast_operator_patch.py
@abc.abstractmethod
def should_patch(self, node: ast.AST) -> bool:
    """
    Check if a node should be patched
    :param node: node to check
    :return: return True if it matches the operator provided during the __init__
    :raises NotImplementedError: always; concrete subclasses must override this method
    """
    # NotImplementedError (a subclass of Exception) is the conventional signal for a missing override
    raise NotImplementedError("to implement")