Skip to content

ST utils

Utilities related to sentence-transformers.

STransformerWrapper (Module)

Wraps a sentence-transformers model to provide a forward function that accepts multiple inputs, as expected by the ONNX export tool.

Source code in src/transformer_deploy/backends/st_utils.py
class STransformerWrapper(nn.Module):
    """
    Wrap sentence-transformers model to provide a forward function with multiple inputs as expected by ONNX export tool.

    Positional inputs are interpreted as ``(input_ids, attention_mask)`` or
    ``(input_ids, token_type_ids, attention_mask)``; keyword inputs are forwarded as-is.
    The wrapped model's ``forward`` is called with the resulting dict and only the
    ``"sentence_embedding"`` entry of its output is returned.
    """

    def __init__(self, model: "SentenceTransformer"):
        """
        :param model: sentence-transformers model to wrap; its ``forward`` is called with ``input=<dict of inputs>``
        """
        super().__init__()
        self.model = model

    def forward(self, *kargs, **kwargs):
        """
        Build the input dict expected by sentence-transformers and return the sentence embedding.

        :param kargs: either ``(input_ids, attention_mask)`` or ``(input_ids, token_type_ids, attention_mask)``
        :param kwargs: named model inputs; when any are given they take precedence and positional inputs are ignored
        :return: the ``"sentence_embedding"`` entry of the wrapped model output
        :raises AssertionError: if the number of inputs is not 2 or 3
        """
        if kwargs:
            # keyword inputs take precedence over positional ones (historical behavior of this wrapper)
            inputs = dict(kwargs)
            nb_inputs = len(inputs)
        else:
            nb_inputs = len(kargs)
            inputs = dict()
            # only map positional args when the count is valid; previously >3 args were silently truncated
            if nb_inputs in (2, 3):
                inputs["input_ids"] = kargs[0]
                # attention_mask is always the last positional input
                inputs["attention_mask"] = kargs[-1]
                if nb_inputs == 3:
                    inputs["token_type_ids"] = kargs[1]
        # explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type
        if not 2 <= nb_inputs <= 3:
            raise AssertionError(f"unexpected number of inputs: {nb_inputs}")
        outputs = self.model.forward(input=inputs)
        return outputs["sentence_embedding"]

forward(self, *kargs, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note:: Although the recipe for the forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Source code in src/transformer_deploy/backends/st_utils.py
def forward(self, *kargs, **kwargs):
    """
    Build the input dict expected by sentence-transformers and return the sentence embedding.

    :param kargs: either ``(input_ids, attention_mask)`` or ``(input_ids, token_type_ids, attention_mask)``
    :param kwargs: named model inputs; when any are given they take precedence and positional inputs are ignored
    :return: the ``"sentence_embedding"`` entry of ``self.model.forward(input=...)``
    :raises AssertionError: if the number of inputs is not 2 or 3
    """
    if kwargs:
        # keyword inputs take precedence over positional ones (historical behavior of this wrapper)
        inputs = dict(kwargs)
        nb_inputs = len(inputs)
    else:
        nb_inputs = len(kargs)
        inputs = dict()
        # only map positional args when the count is valid; previously >3 args were silently truncated
        if nb_inputs in (2, 3):
            inputs["input_ids"] = kargs[0]
            # attention_mask is always the last positional input
            inputs["attention_mask"] = kargs[-1]
            if nb_inputs == 3:
                inputs["token_type_ids"] = kargs[1]
    # explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type
    if not 2 <= nb_inputs <= 3:
        raise AssertionError(f"unexpected number of inputs: {nb_inputs}")
    outputs = self.model.forward(input=inputs)
    return outputs["sentence_embedding"]

load_sentence_transformers(path)

Load a sentence-transformers model and wrap it so that it behaves like any other transformers model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| path | str | path to the model | required |

Returns:

| Type | Description |
|------|-------------|
| STransformerWrapper | wrapped sentence-transformers model |

Source code in src/transformer_deploy/backends/st_utils.py
def load_sentence_transformers(path: str) -> STransformerWrapper:
    """
    Load sentence-transformers model and wrap it to make it behave like any other transformers model
    :param path: path to the model
    :return: wrapped sentence-transformers model
    :raises ImportError: if the sentence-transformers library is not installed
    """
    try:
        # imported lazily so the rest of the module works without this optional dependency
        from sentence_transformers import SentenceTransformer
    except ImportError as e:
        # ImportError subclasses Exception, so callers catching the previous generic Exception still work;
        # chaining keeps the original import failure visible in the traceback
        raise ImportError(
            "sentence-transformers library is not present, please install it: pip install sentence-transformers"
        ) from e
    model: SentenceTransformer = SentenceTransformer(path)
    return STransformerWrapper(model=model)