st_utils

Utilities related to sentence-transformers.
STransformerWrapper (Module)

Wraps a sentence-transformers model to provide a forward function with multiple inputs, as expected by the ONNX export tool.
Source code in src/transformer_deploy/backends/st_utils.py
class STransformerWrapper(nn.Module):
    """
    Wrap sentence-transformers model to provide a forward function with multiple inputs as expected by ONNX export tool.

    The wrapped model's ``forward`` takes a single feature dict (``input=``);
    this wrapper accepts the inputs as separate arguments, builds that dict,
    and returns only the ``"sentence_embedding"`` entry of the model output.
    """

    def __init__(self, model: "SentenceTransformer"):
        """
        :param model: the sentence-transformers model to wrap
        """
        super().__init__()
        self.model = model

    def forward(self, *kargs, **kwargs):
        """
        Run the wrapped model and return its sentence embedding.

        Inputs are accepted in one of two forms:

        * positionally, as ``(input_ids, attention_mask)`` or
          ``(input_ids, token_type_ids, attention_mask)``;
        * by keyword (2 or 3 entries), passed to the model under the given names.

        If any keyword argument is given, the keyword form is used and
        positional arguments are ignored.

        :param kargs: positional model inputs (see above)
        :param kwargs: named model inputs (see above)
        :return: the ``"sentence_embedding"`` entry of the model output
        :raises ValueError: if the inputs do not amount to 2 or 3 entries
        """
        if kwargs:
            # keyword inputs take precedence over positional ones
            inputs = dict(kwargs)
        elif 2 <= len(kargs) <= 3:
            # attention_mask is always the last positional argument,
            # token_type_ids (if any) sits in the middle
            inputs = {"input_ids": kargs[0], "attention_mask": kargs[-1]}
            if len(kargs) == 3:
                inputs["token_type_ids"] = kargs[1]
        else:
            # reject 0, 1 or >3 positional inputs instead of silently dropping some
            raise ValueError(f"unexpected number of inputs: {len(kargs)}")
        # explicit check (not assert) so validation survives `python -O`
        if not 2 <= len(inputs) <= 3:
            raise ValueError(f"unexpected number of inputs: {len(inputs)}")
        outputs = self.model.forward(input=inputs)
        return outputs["sentence_embedding"]
forward(self, *kargs, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
Source code in src/transformer_deploy/backends/st_utils.py
def forward(self, *kargs, **kwargs):
    """
    Run the wrapped model and return its sentence embedding.

    Inputs are accepted in one of two forms:

    * positionally, as ``(input_ids, attention_mask)`` or
      ``(input_ids, token_type_ids, attention_mask)``;
    * by keyword (2 or 3 entries), passed to the model under the given names.

    If any keyword argument is given, the keyword form is used and
    positional arguments are ignored.

    :param kargs: positional model inputs (see above)
    :param kwargs: named model inputs (see above)
    :return: the ``"sentence_embedding"`` entry of the model output
    :raises ValueError: if the inputs do not amount to 2 or 3 entries
    """
    if kwargs:
        # keyword inputs take precedence over positional ones
        inputs = dict(kwargs)
    elif 2 <= len(kargs) <= 3:
        # attention_mask is always the last positional argument,
        # token_type_ids (if any) sits in the middle
        inputs = {"input_ids": kargs[0], "attention_mask": kargs[-1]}
        if len(kargs) == 3:
            inputs["token_type_ids"] = kargs[1]
    else:
        # reject 0, 1 or >3 positional inputs instead of silently dropping some
        raise ValueError(f"unexpected number of inputs: {len(kargs)}")
    # explicit check (not assert) so validation survives `python -O`
    if not 2 <= len(inputs) <= 3:
        raise ValueError(f"unexpected number of inputs: {len(inputs)}")
    outputs = self.model.forward(input=inputs)
    return outputs["sentence_embedding"]
load_sentence_transformers(path)

Loads a sentence-transformers model and wraps it to make it behave like any other transformers model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str | path to the model | required |

Returns:

| Type | Description |
|---|---|
| STransformerWrapper | wrapped sentence-transformers model |
Source code in src/transformer_deploy/backends/st_utils.py
def load_sentence_transformers(path: str) -> STransformerWrapper:
    """
    Load sentence-transformers model and wrap it to make it behave like any other transformers model
    :param path: path to the model
    :return: wrapped sentence-transformers model
    :raises ImportError: if the sentence-transformers library is not installed
    """
    try:
        # imported inside the function, so the library is only needed when this loader is called
        from sentence_transformers import SentenceTransformer
    except ImportError as err:
        # ImportError (a subclass of Exception) keeps `except Exception` callers working,
        # and chaining preserves the original import failure for debugging
        raise ImportError(
            "sentence-transformers library is not present, please install it: pip install sentence-transformers"
        ) from err
    model: SentenceTransformer = SentenceTransformer(path)
    return STransformerWrapper(model=model)