Token classifier

This module is copy-pasted into the generated Triton configuration folder to perform inference.

AggregationStrategy (Enum) #

All the valid aggregation strategies for TokenClassificationPipeline

Source code in src/transformer_deploy/utils/token_classifier.py
class AggregationStrategy(Enum):
    """All the valid aggregation strategies for TokenClassificationPipeline"""

    NONE = "none"
    SIMPLE = "simple"
    FIRST = "first"
    AVERAGE = "average"
    MAX = "max"
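
The strategy only matters when a word is split into several sub-tokens whose predictions disagree. As a rough illustration (a minimal sketch with made-up scores, independent of the Triton code below), here is how FIRST, MAX and AVERAGE would each resolve a word split into two sub-tokens:

import numpy as np

# Hypothetical per-token probabilities for one word split into two sub-tokens.
# Columns stand for two labels, e.g. [O, B-ORG] (illustrative names only).
token_scores = np.array([
    [0.6, 0.4],  # first sub-token
    [0.1, 0.9],  # second sub-token
])

first = token_scores[0].argmax()              # FIRST: the first sub-token decides -> 0
best = token_scores[token_scores.max(axis=1).argmax()].argmax()  # MAX: highest-scoring token decides -> 1
average = token_scores.mean(axis=0).argmax()  # AVERAGE: mean over sub-tokens decides -> 1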

TritonPythonModel #

Source code in src/transformer_deploy/utils/token_classifier.py
class TritonPythonModel:
    tokenizer: PreTrainedTokenizer
    device: str
    # change aggregation strategy here
    aggregation_strategy = AggregationStrategy.FIRST
    ignore_labels = ["O"]

    def initialize(self, args: Dict[str, str]) -> None:
        """
        Initialize the tokenization process
        :param args: arguments from Triton config file
        """
        current_path: str = os.path.join(args["model_repository"], args["model_version"])
        target_model = args["model_name"].replace("_inference", "_model")
        self.device = "cpu" if args["model_instance_kind"] == "CPU" else "cuda"

        def inference_triton(
            input_ids: torch.Tensor,
            token_type_ids: torch.Tensor,
            attention_mask: torch.Tensor,
        ) -> np.ndarray:
            input_ids = input_ids.type(dtype=torch.int32)
            token_type_ids = token_type_ids.type(dtype=torch.int32)
            attention_mask = attention_mask.type(dtype=torch.int32)
            inputs = [
                pb_utils.Tensor.from_dlpack("input_ids", torch.to_dlpack(input_ids)),
                pb_utils.Tensor.from_dlpack("token_type_ids", torch.to_dlpack(token_type_ids)),
                pb_utils.Tensor.from_dlpack("attention_mask", torch.to_dlpack(attention_mask)),
            ]
            inference_request = pb_utils.InferenceRequest(
                model_name=target_model,
                requested_output_names=["output"],
                inputs=inputs,
            )
            inference_response = inference_request.exec()
            if inference_response.has_error():
                raise pb_utils.TritonModelException(inference_response.error().message())
            else:
                output = pb_utils.get_output_tensor_by_name(inference_response, "output")
                tensor: torch.Tensor = torch.from_dlpack(output.to_dlpack())
                return tensor.detach().cpu().numpy()

        self.model = inference_triton
        self.tokenizer = AutoTokenizer.from_pretrained(current_path)
        self.config = AutoConfig.from_pretrained(current_path)

    def execute(self, requests) -> "List[List[pb_utils.Tensor]]":
        """
        Parse and tokenize each request, then run token classification
        :param requests: 1 or more requests received by Triton server.
        :return: detected entities serialized to JSON, one response per request
        """
        responses = []
        # for loop for batch requests (disabled in our case)
        for request in requests:
            # binary data decoded back to strings
            query = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "TEXT").as_numpy().tolist()]
            tokens: BatchEncoding = self.tokenizer(
                text=query[0],
                return_tensors=TensorType.PYTORCH,
                return_attention_mask=True,
                return_special_tokens_mask=True,
                return_offsets_mapping=self.tokenizer.is_fast,
            )

            input_ids = tokens["input_ids"]
            token_type_ids = tokens["token_type_ids"]
            attention_mask = tokens["attention_mask"]

            if self.device == "cuda":
                input_ids = input_ids.to("cuda")
                token_type_ids = token_type_ids.to("cuda")
                attention_mask = attention_mask.to("cuda")

            output_seq: np.ndarray = self.model(input_ids, token_type_ids, attention_mask)

            logits = output_seq[0]
            sentence = query[0]
            input_ids = input_ids.cpu().numpy()[0]
            offset_mapping = tokens["offset_mapping"][0] if "offset_mapping" in tokens else None
            special_tokens_mask = tokens["special_tokens_mask"].numpy()[0]

            maxes = np.max(logits, axis=-1, keepdims=True)
            shifted_exp = np.exp(logits - maxes)
            scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

            pre_entities = self.gather_pre_entities(
                sentence,
                input_ids,
                scores,
                offset_mapping,
                special_tokens_mask,
                self.aggregation_strategy,
            )
            grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
            # Filter anything that is in self.ignore_labels
            entities = [
                entity
                for entity in grouped_entities
                if entity.get("entity", None) not in self.ignore_labels
                and entity.get("entity_group", None) not in self.ignore_labels
            ]

            tensor_output = [pb_utils.Tensor("output", np.array(json.dumps(entities), dtype=object))]
            responses.append(pb_utils.InferenceResponse(tensor_output))
        return responses

    def gather_pre_entities(
        self,
        sentence: str,
        input_ids: np.ndarray,
        scores: np.ndarray,
        offset_mapping: Optional[List[Tuple[int, int]]],
        special_tokens_mask: np.ndarray,
        aggregation_strategy: AggregationStrategy,
    ) -> List[dict]:
        """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
        pre_entities = []
        for idx, token_scores in enumerate(scores):
            # Filter out special tokens; they should only occur at the
            # sentence boundaries since we're not encoding pairs of
            # sentences, so we don't have to keep track of them.
            if special_tokens_mask[idx]:
                continue

            word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
            if offset_mapping is not None:
                start_ind, end_ind = offset_mapping[idx]
                if not isinstance(start_ind, int):
                    start_ind = int(start_ind.numpy())
                    end_ind = int(end_ind.numpy())
                word_ref = sentence[start_ind:end_ind]
                if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
                    # This is a word-aware BPE tokenizer: there is a correct
                    # way to fuse tokens.
                    is_subword = len(word) != len(word_ref)
                else:
                    # This is a fallback heuristic. It will most likely fail on any
                    # text + punctuation mixture that gets treated as a single "word".
                    # Non-word-aware models unfortunately cannot do better than this.
                    if aggregation_strategy in {
                        AggregationStrategy.FIRST,
                        AggregationStrategy.AVERAGE,
                        AggregationStrategy.MAX,
                    }:
                        warnings.warn(
                            "Tokenizer does not support real words, using fallback heuristic",
                            UserWarning,
                        )
                    is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False  # noqa: E203

                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
                    word = word_ref
                    is_subword = False
            else:
                start_ind = None
                end_ind = None
                is_subword = False

            pre_entity = {
                "word": word,
                "scores": token_scores,
                "start": start_ind,
                "end": end_ind,
                "index": idx,
                "is_subword": is_subword,
            }
            pre_entities.append(pre_entity)
        return pre_entities

    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        if aggregation_strategy in {
            AggregationStrategy.NONE,
            AggregationStrategy.SIMPLE,
        }:
            entities = []
            for pre_entity in pre_entities:
                entity_idx = pre_entity["scores"].argmax()
                score = pre_entity["scores"][entity_idx]
                entity = {
                    "entity": self.config.id2label[entity_idx],
                    "score": float(score),
                    "index": pre_entity["index"],
                    "word": pre_entity["word"],
                    "start": pre_entity["start"],
                    "end": pre_entity["end"],
                }
                entities.append(entity)
        else:
            entities = self.aggregate_words(pre_entities, aggregation_strategy)

        if aggregation_strategy == AggregationStrategy.NONE:
            return entities
        return self.group_entities(entities)

    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
        if aggregation_strategy == AggregationStrategy.FIRST:
            scores = entities[0]["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.MAX:
            max_entity = max(entities, key=lambda entity: entity["scores"].max())
            scores = max_entity["scores"]
            idx = scores.argmax()
            score = scores[idx]
            entity = self.config.id2label[idx]
        elif aggregation_strategy == AggregationStrategy.AVERAGE:
            scores = np.stack([entity["scores"] for entity in entities])
            average_scores = np.nanmean(scores, axis=0)
            entity_idx = average_scores.argmax()
            entity = self.config.id2label[entity_idx]
            score = average_scores[entity_idx]
        else:
            raise ValueError("Invalid aggregation_strategy")
        new_entity = {
            "entity": entity,
            "score": float(score),
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return new_entity

    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        """
        Override tokens from a given word that disagree to force agreement on word boundaries.
        Example: micro|soft| com|pany| tagged B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
        microsoft| company| tagged B-ENT I-ENT
        """
        if aggregation_strategy in {
            AggregationStrategy.NONE,
            AggregationStrategy.SIMPLE,
        }:
            raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")

        word_entities = []
        word_group = None
        for entity in entities:
            if word_group is None:
                word_group = [entity]
            elif entity["is_subword"]:
                word_group.append(entity)
            else:
                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
                word_group = [entity]
        # Last item
        word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
        return word_entities

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.
        Args:
            entities (`List[dict]`): The entities predicted by the pipeline.
        """
        # Take the tag of the first entity in the group, stripping its B-/I- prefix
        entity = entities[0]["entity"].split("-")[-1]
        scores = np.nanmean([entity["score"] for entity in entities])
        tokens = [entity["word"] for entity in entities]

        entity_group = {
            "entity_group": entity,
            "score": float(np.mean(scores)),
            "word": self.tokenizer.convert_tokens_to_string(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }
        return entity_group

    def get_tag(self, entity_name: str) -> Tuple[str, str]:
        if entity_name.startswith("B-"):
            bi = "B"
            tag = entity_name[2:]
        elif entity_name.startswith("I-"):
            bi = "I"
            tag = entity_name[2:]
        else:
            # It's not in B-, I- format
            # Default to I- for continuation.
            bi = "I"
            tag = entity_name
        return bi, tag

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.
        Args:
            entities (`List[dict]`): The entities predicted by the pipeline.
        """

        entity_groups = []
        entity_group_disagg = []

        for entity in entities:
            if not entity_group_disagg:
                entity_group_disagg.append(entity)
                continue

            # If the current entity is similar and adjacent to the previous entity,
            # append it to the disaggregated entity group
            # The split is meant to account for the "B" and "I" prefixes
            # Shouldn't merge if both entities are B-type
            bi, tag = self.get_tag(entity["entity"])
            last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])

            if tag == last_tag and bi != "B":
                # Same tag and not a B- prefix: continue the current group
                entity_group_disagg.append(entity)
            else:
                # If the current entity is different from the previous entity
                # aggregate the disaggregated entity group
                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                entity_group_disagg = [entity]
        if entity_group_disagg:
            # it's the last entity, add it to the entity groups
            entity_groups.append(self.group_sub_entities(entity_group_disagg))

        return entity_groups
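
A client only ever talks to the *_inference model: it sends one TEXT string and receives the entity list back as JSON. Below is a minimal sketch using tritonclient over HTTP; the server address and model name are illustrative and depend on your deployment:

import json

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")  # illustrative address

# a single UTF-8 string, shipped as a BYTES tensor of shape [1]
text = np.array([b"My name is Wolfgang and I live in Berlin"], dtype=object)
text_input = httpclient.InferInput("TEXT", [1], "BYTES")
text_input.set_data_from_numpy(text)

result = client.infer(
    model_name="transformer_onnx_inference",  # illustrative model name
    inputs=[text_input],
    outputs=[httpclient.InferRequestedOutput("output")],
)
entities = json.loads(result.as_numpy("output").item())
print(entities)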

aggregate_words(self, entities, aggregation_strategy) #

Override tokens from a given word that disagree to force agreement on word boundaries. Example: micro|soft| com|pany| tagged B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as microsoft| company| tagged B-ENT I-ENT

Source code in src/transformer_deploy/utils/token_classifier.py
def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
    """
    Override tokens from a given word that disagree to force agreement on word boundaries.
    Example: micro|soft| com|pany| tagged B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
    microsoft| company| tagged B-ENT I-ENT
    """
    if aggregation_strategy in {
        AggregationStrategy.NONE,
        AggregationStrategy.SIMPLE,
    }:
        raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")

    word_entities = []
    word_group = None
    for entity in entities:
        if word_group is None:
            word_group = [entity]
        elif entity["is_subword"]:
            word_group.append(entity)
        else:
            word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
            word_group = [entity]
    # Last item
    word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
    return word_entities
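
The grouping above relies only on the is_subword flag: every token that is not a subword opens a new word group, and each completed group is handed to aggregate_word. A minimal standalone sketch of the same walk, using hypothetical pre-entities:

# Hypothetical pre-entities for "microsoft company" (WordPiece-style pieces).
entities = [
    {"word": "micro", "is_subword": False},
    {"word": "##soft", "is_subword": True},
    {"word": "com", "is_subword": False},
    {"word": "##pany", "is_subword": True},
]

groups = []
current = None
for entity in entities:
    if current is None:
        current = [entity]        # first token always opens a group
    elif entity["is_subword"]:
        current.append(entity)    # continuation piece: same word
    else:
        groups.append(current)    # word boundary: flush and start over
        current = [entity]
groups.append(current)            # flush the last word
# groups -> [["micro", "##soft"], ["com", "##pany"]] (by "word" field)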

execute(self, requests) #

Parse and tokenize each request, then run token classification

Parameters:

requests: 1 or more requests received by the Triton server. Required.

Returns:

List[List[pb_utils.Tensor]]: detected entities serialized to JSON, one response per request.

Source code in src/transformer_deploy/utils/token_classifier.py
def execute(self, requests) -> "List[List[pb_utils.Tensor]]":
    """
    Parse and tokenize each request, then run token classification
    :param requests: 1 or more requests received by Triton server.
    :return: detected entities serialized to JSON, one response per request
    """
    responses = []
    # for loop for batch requests (disabled in our case)
    for request in requests:
        # binary data decoded back to strings
        query = [t.decode("UTF-8") for t in pb_utils.get_input_tensor_by_name(request, "TEXT").as_numpy().tolist()]
        tokens: BatchEncoding = self.tokenizer(
            text=query[0],
            return_tensors=TensorType.PYTORCH,
            return_attention_mask=True,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
        )

        input_ids = tokens["input_ids"]
        token_type_ids = tokens["token_type_ids"]
        attention_mask = tokens["attention_mask"]

        if self.device == "cuda":
            input_ids = input_ids.to("cuda")
            token_type_ids = token_type_ids.to("cuda")
            attention_mask = attention_mask.to("cuda")

        output_seq: np.ndarray = self.model(input_ids, token_type_ids, attention_mask)

        logits = output_seq[0]
        sentence = query[0]
        input_ids = input_ids.cpu().numpy()[0]
        offset_mapping = tokens["offset_mapping"][0] if "offset_mapping" in tokens else None
        special_tokens_mask = tokens["special_tokens_mask"].numpy()[0]

        maxes = np.max(logits, axis=-1, keepdims=True)
        shifted_exp = np.exp(logits - maxes)
        scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

        pre_entities = self.gather_pre_entities(
            sentence,
            input_ids,
            scores,
            offset_mapping,
            special_tokens_mask,
            self.aggregation_strategy,
        )
        grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
        # Filter anything that is in self.ignore_labels
        entities = [
            entity
            for entity in grouped_entities
            if entity.get("entity", None) not in self.ignore_labels
            and entity.get("entity_group", None) not in self.ignore_labels
        ]

        tensor_output = [pb_utils.Tensor("output", np.array(json.dumps(entities), dtype=object))]
        responses.append(pb_utils.InferenceResponse(tensor_output))
    return responses
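
The three lines that turn logits into scores implement a numerically stable softmax: subtracting the row-wise maximum before exponentiating prevents overflow without changing the result. A standalone sketch of the same computation:

import numpy as np

logits = np.array([[1000.0, 1001.0]])  # naive np.exp(logits) would overflow
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)   # exp([-1.0, 0.0]): always safe
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
# scores ~ [[0.269, 0.731]], identical to the softmax of the raw logits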

gather_pre_entities(self, sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy) #

Fuse various numpy arrays into dicts with all the information needed for aggregation

Source code in src/transformer_deploy/utils/token_classifier.py
def gather_pre_entities(
    self,
    sentence: str,
    input_ids: np.ndarray,
    scores: np.ndarray,
    offset_mapping: Optional[List[Tuple[int, int]]],
    special_tokens_mask: np.ndarray,
    aggregation_strategy: AggregationStrategy,
) -> List[dict]:
    """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
    pre_entities = []
    for idx, token_scores in enumerate(scores):
        # Filter out special tokens; they should only occur at the
        # sentence boundaries since we're not encoding pairs of
        # sentences, so we don't have to keep track of them.
        if special_tokens_mask[idx]:
            continue

        word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
        if offset_mapping is not None:
            start_ind, end_ind = offset_mapping[idx]
            if not isinstance(start_ind, int):
                start_ind = int(start_ind.numpy())
                end_ind = int(end_ind.numpy())
            word_ref = sentence[start_ind:end_ind]
            if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
            # This is a word-aware BPE tokenizer: there is a correct
            # way to fuse tokens.
                is_subword = len(word) != len(word_ref)
            else:
            # This is a fallback heuristic. It will most likely fail on any
            # text + punctuation mixture that gets treated as a single "word".
            # Non-word-aware models unfortunately cannot do better than this.
                if aggregation_strategy in {
                    AggregationStrategy.FIRST,
                    AggregationStrategy.AVERAGE,
                    AggregationStrategy.MAX,
                }:
                    warnings.warn(
                        "Tokenizer does not support real words, using fallback heuristic",
                        UserWarning,
                    )
                is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False  # noqa: E203

            if int(input_ids[idx]) == self.tokenizer.unk_token_id:
                word = word_ref
                is_subword = False
        else:
            start_ind = None
            end_ind = None
            is_subword = False

        pre_entity = {
            "word": word,
            "scores": token_scores,
            "start": start_ind,
            "end": end_ind,
            "index": idx,
            "is_subword": is_subword,
        }
        pre_entities.append(pre_entity)
    return pre_entities
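
To see what this method consumes, it helps to inspect a fast tokenizer's offsets and special-tokens mask directly. A minimal sketch, assuming a BERT-style checkpoint (bert-base-cased is used purely for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint
encoding = tokenizer(
    "microsoft company",
    return_offsets_mapping=True,
    return_special_tokens_mask=True,
)

for token, (start, end), special in zip(
    tokenizer.convert_ids_to_tokens(encoding["input_ids"]),
    encoding["offset_mapping"],
    encoding["special_tokens_mask"],
):
    # For continuation pieces such as "##soft", the token text length differs
    # from the original character span: that is the word-aware subword check.
    print(token, (start, end), "special" if special else "")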

group_entities(self, entities) #

Find and group together the adjacent tokens with the same entity predicted.

Parameters:

entities (List[dict]): The entities predicted by the pipeline.

Source code in src/transformer_deploy/utils/token_classifier.py
def group_entities(self, entities: List[dict]) -> List[dict]:
    """
    Find and group together the adjacent tokens with the same entity predicted.
    Args:
        entities (`List[dict]`): The entities predicted by the pipeline.
    """

    entity_groups = []
    entity_group_disagg = []

    for entity in entities:
        if not entity_group_disagg:
            entity_group_disagg.append(entity)
            continue

        # If the current entity is similar and adjacent to the previous entity,
        # append it to the disaggregated entity group
        # The split is meant to account for the "B" and "I" prefixes
        # Shouldn't merge if both entities are B-type
        bi, tag = self.get_tag(entity["entity"])
        last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])

        if tag == last_tag and bi != "B":
            # Same tag and not a B- prefix: continue the current group
            entity_group_disagg.append(entity)
        else:
            # If the current entity is different from the previous entity
            # aggregate the disaggregated entity group
            entity_groups.append(self.group_sub_entities(entity_group_disagg))
            entity_group_disagg = [entity]
    if entity_group_disagg:
        # it's the last entity, add it to the entity groups
        entity_groups.append(self.group_sub_entities(entity_group_disagg))

    return entity_groups
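
Concretely, get_tag lets the loop compare IOB prefixes: an I- token extends the running group when its tag matches, while any B- token (or a tag change) starts a new one. A minimal sketch on hypothetical labels:

def get_tag(name):  # same logic as TritonPythonModel.get_tag
    return (name[0], name[2:]) if name[:2] in ("B-", "I-") else ("I", name)

groups, current = [], []
for label in ["B-PER", "I-PER", "B-PER", "O"]:
    bi, tag = get_tag(label)
    if current and tag == get_tag(current[-1])[1] and bi != "B":
        current.append(label)       # same tag, continuation: extend the group
    else:
        if current:
            groups.append(current)  # tag change or B- prefix: flush
        current = [label]
groups.append(current)
print(groups)  # [['B-PER', 'I-PER'], ['B-PER'], ['O']]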

group_sub_entities(self, entities) #

Group together the adjacent tokens with the same entity predicted.

Parameters:

entities (List[dict]): The entities predicted by the pipeline.

Source code in src/transformer_deploy/utils/token_classifier.py
def group_sub_entities(self, entities: List[dict]) -> dict:
    """
    Group together the adjacent tokens with the same entity predicted.
    Args:
        entities (`List[dict]`): The entities predicted by the pipeline.
    """
    # Take the tag of the first entity in the group, stripping its B-/I- prefix
    entity = entities[0]["entity"].split("-")[-1]
    scores = np.nanmean([entity["score"] for entity in entities])
    tokens = [entity["word"] for entity in entities]

    entity_group = {
        "entity_group": entity,
        "score": float(np.mean(scores)),
        "word": self.tokenizer.convert_tokens_to_string(tokens),
        "start": entities[0]["start"],
        "end": entities[-1]["end"],
    }
    return entity_group
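
The word field is rebuilt from the grouped tokens by the tokenizer; for WordPiece vocabularies this strips the "##" continuation markers and re-inserts spaces between whole words. A tiny sketch, again assuming bert-base-cased purely for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint
print(tokenizer.convert_tokens_to_string(["micro", "##soft"]))  # -> "microsoft"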

initialize(self, args) #

Initialize the tokenization process

Parameters:

args (Dict[str, str]): arguments from the Triton config file. Required.

Source code in src/transformer_deploy/utils/token_classifier.py
def initialize(self, args: Dict[str, str]) -> None:
    """
    Initialize the tokenization process
    :param args: arguments from Triton config file
    """
    current_path: str = os.path.join(args["model_repository"], args["model_version"])
    target_model = args["model_name"].replace("_inference", "_model")
    self.device = "cpu" if args["model_instance_kind"] == "CPU" else "cuda"

    def inference_triton(
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> np.ndarray:
        input_ids = input_ids.type(dtype=torch.int32)
        token_type_ids = token_type_ids.type(dtype=torch.int32)
        attention_mask = attention_mask.type(dtype=torch.int32)
        inputs = [
            pb_utils.Tensor.from_dlpack("input_ids", torch.to_dlpack(input_ids)),
            pb_utils.Tensor.from_dlpack("token_type_ids", torch.to_dlpack(token_type_ids)),
            pb_utils.Tensor.from_dlpack("attention_mask", torch.to_dlpack(attention_mask)),
        ]
        inference_request = pb_utils.InferenceRequest(
            model_name=target_model,
            requested_output_names=["output"],
            inputs=inputs,
        )
        inference_response = inference_request.exec()
        if inference_response.has_error():
            raise pb_utils.TritonModelException(inference_response.error().message())
        else:
            output = pb_utils.get_output_tensor_by_name(inference_response, "output")
            tensor: torch.Tensor = torch.from_dlpack(output.to_dlpack())
            return tensor.detach().cpu().numpy()

    self.model = inference_triton
    self.tokenizer = AutoTokenizer.from_pretrained(current_path)
    self.config = AutoConfig.from_pretrained(current_path)
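
Triton calls initialize once per model instance with a dict of string arguments; the keys read above are standard Python-backend arguments. Illustrative values only (the names depend on your model repository layout):

# Hypothetical args dict passed by Triton to initialize()
args = {
    "model_repository": "/models/transformer_onnx_inference",  # illustrative path
    "model_version": "1",
    "model_name": "transformer_onnx_inference",
    "model_instance_kind": "GPU",
}
# tokenizer and config are then loaded from /models/transformer_onnx_inference/1,
# and inference requests are forwarded to the "transformer_onnx_model" model.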