From ea8c0b27f2837e621a1c99c5a350cb9eaac63265 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 30 Dec 2024 12:22:22 -0800 Subject: [PATCH] [tokenizers] Support import zero-shot-classification to model zoo --- .../djl_converter/huggingface_converter.py | 15 +++--- .../djl_converter/huggingface_models.py | 5 ++ .../python/djl_converter/model_converter.py | 4 ++ .../zero_shot_classification_converter.py | 52 +++++++++++++++++++ 4 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 extensions/tokenizers/src/main/python/djl_converter/zero_shot_classification_converter.py diff --git a/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py b/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py index cde63d373ce..fb126d11461 100644 --- a/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py +++ b/extensions/tokenizers/src/main/python/djl_converter/huggingface_converter.py @@ -18,7 +18,7 @@ import onnx from torch import nn -from transformers.modeling_outputs import TokenClassifierOutput +from transformers.modeling_outputs import TokenClassifierOutput, Seq2SeqSequenceClassifierOutput from djl_converter.safetensors_convert import convert_file import torch @@ -58,8 +58,10 @@ def forward(self, output = self.model(input_ids, attention_mask) else: output = self.model(input_ids, attention_mask, token_type_ids) - if isinstance(output, TokenClassifierOutput): - # TokenClassifierOutput may contains mix of Tensor and Tuple(Tensor) + if isinstance(output, TokenClassifierOutput) or isinstance( + output, Seq2SeqSequenceClassifierOutput): + # TokenClassifierOutput/Seq2SeqSequenceClassifierOutput + # may contains mix of Tensor and Tuple(Tensor) return {"logits": output["logits"]} return output @@ -303,14 +305,13 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str, # noinspection PyBroadException try: + wrapper = ModelWrapper(hf_pipeline.model, include_types) if include_types: 
script_module = torch.jit.trace( - ModelWrapper(hf_pipeline.model, include_types), - (input_ids, attention_mask, token_type_ids), + wrapper, (input_ids, attention_mask, token_type_ids), strict=False) else: - script_module = torch.jit.trace(ModelWrapper( - hf_pipeline.model, include_types), + script_module = torch.jit.trace(wrapper, (input_ids, attention_mask), strict=False) diff --git a/extensions/tokenizers/src/main/python/djl_converter/huggingface_models.py b/extensions/tokenizers/src/main/python/djl_converter/huggingface_models.py index 59d65a958be..2467cab6f3b 100644 --- a/extensions/tokenizers/src/main/python/djl_converter/huggingface_models.py +++ b/extensions/tokenizers/src/main/python/djl_converter/huggingface_models.py @@ -25,6 +25,7 @@ from djl_converter.sentence_similarity_converter import SentenceSimilarityConverter from djl_converter.text_classification_converter import TextClassificationConverter from djl_converter.token_classification_converter import TokenClassificationConverter +from djl_converter.zero_shot_classification_converter import ZeroShotClassificationConverter ARCHITECTURES_2_TASK = { "ForQuestionAnswering": "question-answering", @@ -42,6 +43,7 @@ "sentence-similarity": SentenceSimilarityConverter(), "text-classification": TextClassificationConverter(), "token-classification": TokenClassificationConverter(), + "zero-shot-classification": ZeroShotClassificationConverter(), } @@ -127,6 +129,9 @@ def list_models(self, args: Namespace) -> List[dict]: if not task: if "sentence-similarity" in model_info.tags: task = "sentence-similarity" + else: + if "zero-shot-classification" in model_info.tags: + task = "zero-shot-classification" if not task: logging.info( diff --git a/extensions/tokenizers/src/main/python/djl_converter/model_converter.py b/extensions/tokenizers/src/main/python/djl_converter/model_converter.py index a9f7b6d9057..11498509f41 100644 --- a/extensions/tokenizers/src/main/python/djl_converter/model_converter.py +++ 
b/extensions/tokenizers/src/main/python/djl_converter/model_converter.py @@ -63,6 +63,10 @@ def main(): if not task: if "sentence-similarity" in model_info.tags: task = "sentence-similarity" + else: + if "zero-shot-classification" in model_info.tags: + task = "zero-shot-classification" + if not task: logging.error( f"Unsupported model architecture: {arch} for {args.model_id}.") diff --git a/extensions/tokenizers/src/main/python/djl_converter/zero_shot_classification_converter.py b/extensions/tokenizers/src/main/python/djl_converter/zero_shot_classification_converter.py new file mode 100644 index 00000000000..63b4b28c633 --- /dev/null +++ b/extensions/tokenizers/src/main/python/djl_converter/zero_shot_classification_converter.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
import math

from djl_converter.huggingface_converter import HuggingfaceConverter


class ZeroShotClassificationConverter(HuggingfaceConverter):
    """Converter settings and output check for zero-shot-classification models.

    Sets the task, application and translator factory names on top of
    ``HuggingfaceConverter`` and compares a converted model's logits with
    the score produced by the original HuggingFace pipeline for one sample
    sentence and one candidate label.
    """

    def __init__(self):
        super().__init__()
        self.task = "zero-shot-classification"
        self.application = "nlp/zero_shot_classification"
        self.translator = "ai.djl.huggingface.translator.ZeroShotClassificationTranslatorFactory"
        # Sample text and candidate labels consumed by encode_inputs() and
        # verify_jit_output() below.
        self.inputs = "one day I will see the world"
        self.labels = ['travel']

    def encode_inputs(self, tokenizer):
        """Tokenize the sample text paired with a hypothesis sentence.

        The hypothesis is built from the first candidate label. The
        tokenizer is asked for PyTorch tensors (``return_tensors='pt'``)
        and its result is returned unchanged.
        """
        return tokenizer(self.inputs,
                         f"This example is {self.labels[0]}.",
                         return_tensors='pt')

    def verify_jit_output(self, hf_pipeline, encoding, out):
        """Check the converted model's output against the HF pipeline score.

        Args:
            hf_pipeline: callable invoked as ``hf_pipeline(text, labels)``;
                its result must expose a ``"scores"`` sequence.
            encoding: not used by this implementation.
            out: mapping holding the model's ``'logits'`` tensor.

        Returns:
            ``(True, None)`` when the scores agree within an absolute
            tolerance of 1e-3, otherwise ``(False, message)``.
        """
        logits = out['logits']
        # NOTE(review): assumes class id 0 is "contradiction" and id 2 is
        # "entailment" (the usual NLI head layout) -- confirm against the
        # model's config.label2id before relying on this for other models.
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        # Column 1 of the two selected columns corresponds to logit index 2.
        # .item() requires exactly one element, i.e. a single input row.
        score = probs[:, 1].item()

        pipeline_output = hf_pipeline(self.inputs, self.labels)
        expected = pipeline_output["scores"][0]

        if math.isclose(expected, score, abs_tol=1e-3):
            return True, None

        # Report both values so a mismatch can be diagnosed from the log.
        return False, (f"Unexpected inference result: expected score "
                       f"{expected}, got {score}")

    def get_extra_arguments(self, hf_pipeline, model_id: str,
                            temp_dir: str) -> dict:
        """Return the extra arguments for this task.

        None of the parameters are consulted; the same two entries are
        returned on every call. The values are strings, not booleans.
        Presumably they end up in the converted model's serving
        configuration -- confirm against ``HuggingfaceConverter``.
        """
        return {
            "padding": "true",
            "truncation": "only_first",
        }