
Proper sentence-transformers ONNX export support #1589

Merged

Conversation

fxmarty
Contributor

@fxmarty fxmarty commented Dec 12, 2023

As reported in #1519, naively mapping sentence-transformers models onto the transformers library supports only a subset of the sentence-transformers library.

This PR adds support for exporting the sentence_embedding output of sentence-transformers models.

Examples:

```
optimum-cli export onnx -m sentence-transformers/clip-ViT-B-32-multilingual-v1 clip_vit_multilingual_onnx
optimum-cli export onnx -m sentence-transformers/all-MiniLM-L6-v2 minilm_onnx
```
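
For illustration, a minimal sketch of querying the exported MiniLM model with onnxruntime. The file name `model.onnx` and the `sentence_embedding` output name follow the export described in this PR; everything else here is an assumption, not part of the diff:

```python
# Sketch: run the exported ONNX model and fetch the pooled sentence embedding.
# Assumes the export above produced minilm_onnx/model.onnx (hypothetical path).
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
session = ort.InferenceSession("minilm_onnx/model.onnx")

inputs = tokenizer(["An example sentence"], return_tensors="np")
(sentence_embedding,) = session.run(
    ["sentence_embedding"],
    {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
)
print(sentence_embedding.shape)  # (batch_size, hidden_size), e.g. (1, 384) for MiniLM-L6
```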

@fxmarty fxmarty requested review from mht-sharma and michaelbenayoun and removed request for mht-sharma December 12, 2023 15:11
@JingyaHuang
Contributor

Can those extra outputs be supported directly in transformers? I just find the changes a bit hacky, and they are causing errors in the optimum-neuron subpackage: aws-neuron/aws-neuron-sdk#808

@JingyaHuang
Contributor

Standardizing model attributes made debugging a bit misleading, e.g. it was surprising to get sentence-transformers-transformer as the model_type from a config where the model_type is marked as bert:

[screenshot of the config in question]

@fxmarty
Contributor Author

fxmarty commented Jan 9, 2024

I doubt this is feasible, as sentence-transformers adds quite a few features on top of transformers, for example the sentence embeddings.
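
For context (not part of this PR's diff): the sentence embedding is a pooled, and optionally normalized, version of the token embeddings. A simplified sketch of the common mean-pooling case; actual sentence-transformers models configure pooling per model, so this is illustrative only:

```python
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average the token embeddings over the sequence, ignoring padding positions.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
    return summed / counts  # shape: (batch_size, hidden_size)
```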

> Standardizing model attributes made debugging a bit misleading, e.g. it was surprising to get sentence-transformers-transformer as the model_type from a config where the model_type is marked as bert

I don't find this too hacky: the model_type bert refers to the bottleneck model in sentence-transformers' nn.Sequential. Some sentence-transformers models use a Transformer as the bottleneck, some use a CLIPModel, and the export differs depending on the architecture:

```python
class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Some bottleneck transformers models require a specific ONNX opset to be successfully
    # exported. We put a rather high opset here for the export to work for all architectures.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "token_embeddings": {0: "batch_size", 1: "sequence_length"},
            "sentence_embedding": {0: "batch_size"},
        }

    # we need to set output_attentions=True in the model input to avoid calling
    # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
    # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)
```

and

```python
class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "text_embeds": {0: "text_batch_size"},
            "image_embeds": {0: "image_batch_size"},
        }

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)
```
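
The two patchers referenced above are not shown in this snippet. Conceptually, they rewrap the model's forward so that the exported ONNX graph exposes the pooled outputs. A hypothetical simplification for the Transformer case, not the actual optimum implementation:

```python
# Hypothetical sketch of what the Transformer patcher does conceptually; the real
# SentenceTransformersTransformerPatcher in optimum differs in detail.
def patch_model_for_sentence_embedding(model):
    # "model" is assumed to be a loaded sentence_transformers.SentenceTransformer.
    def patched_forward(input_ids, attention_mask):
        # SentenceTransformer modules pass a feature dict through the module stack,
        # adding "token_embeddings" and, after pooling, "sentence_embedding".
        features = model({"input_ids": input_ids, "attention_mask": attention_mask})
        return {
            "token_embeddings": features["token_embeddings"],
            "sentence_embedding": features["sentence_embedding"],
        }

    model.forward = patched_forward
    return model
```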

Happy to refactor if needed, though.

```python
@require_torch
@require_vision
@require_sentence_transformers
@pytest.mark.timm_test
```
Contributor
Why is it a timm test? @fxmarty

Contributor Author

It is a typo.
