proper sentence-transformers onnx export support
fxmarty committed Dec 12, 2023
1 parent e840d21 commit 304a4ac
Showing 9 changed files with 212 additions and 43 deletions.
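With this change, Sentence Transformers checkpoints export through their own code path instead of being coerced to the transformers loader. A minimal sketch of the programmatic entry point; the checkpoint and output directory are illustrative:

```python
from optimum.exporters.onnx import main_export

# Export a Sentence Transformers model to ONNX. library_name may also be left
# unset, in which case it is inferred from the checkpoint's Hub metadata.
main_export(
    "sentence-transformers/all-MiniLM-L6-v2",
    output="st_minilm_onnx",
    task="feature-extraction",
    library_name="sentence_transformers",
)
```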
3 changes: 1 addition & 2 deletions optimum/exporters/onnx/__main__.py
@@ -264,8 +264,7 @@ def main_export(
_variant (`str`, defaults to `default`):
Specify the variant of the ONNX export to use.
library_name (`Optional[str]`, defaults to `None`):
The library of the model(`"tansformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect
the library name for the checkpoint.
The library of the model (`"tansformers"` or `"diffusers"` or `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
**kwargs_shapes (`Dict`):
42 changes: 42 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -63,6 +63,8 @@
from .model_patcher import (
FalconModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
VisionEncoderDecoderPatcher,
WavLMModelPatcher,
@@ -799,6 +801,32 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size"}}


class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"token_embeddings": {0: "batch_size", 1: "sequence_length"},
"sentence_embedding": {0: "batch_size"},
}

# SentenceTransformer models take a single dictionary of features as input rather than
# flat tensors, so the forward call is patched (see SentenceTransformersTransformerPatcher)
# to accept input_ids and attention_mask directly for the ONNX export.
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)

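The dynamic axes declared above can be exercised by running the exported graph directly. A minimal sketch with onnxruntime, assuming the model was exported to the hypothetical path st_minilm_onnx/model.onnx:

```python
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
session = ort.InferenceSession("st_minilm_onnx/model.onnx", providers=["CPUExecutionProvider"])

enc = tokenizer(["An example sentence."], return_tensors="np")
# Output order follows the config above: token_embeddings, then sentence_embedding.
token_embeddings, sentence_embedding = session.run(
    None, {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
)
print(token_embeddings.shape, sentence_embedding.shape)  # (1, seq_len, hidden), (1, hidden)
```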

class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
TEXT_CONFIG = "text_config"
VISION_CONFIG = "vision_config"
@@ -826,6 +854,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"text_embeds": {0: "text_batch_size"},
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)

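Note that text_batch_size and image_batch_size are declared as independent dynamic axes, so one call can embed a different number of captions and images. A hedged sketch, assuming a CLIP-style checkpoint exported to the hypothetical path st_clip_onnx/model.onnx:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("st_clip_onnx/model.onnx", providers=["CPUExecutionProvider"])

feeds = {
    "input_ids": np.ones((4, 16), dtype=np.int64),  # 4 captions
    "attention_mask": np.ones((4, 16), dtype=np.int64),
    "pixel_values": np.zeros((2, 3, 224, 224), dtype=np.float32),  # 2 images
}
# Output order follows the config above: text_embeds, then image_embeds.
text_embeds, image_embeds = session.run(None, feeds)
print(text_embeds.shape[0], image_embeds.shape[0])  # 4, 2
```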

class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
45 changes: 45 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -813,3 +813,48 @@ def patched_forward(
return filterd_outputs

self.patched_forward = patched_forward


class SentenceTransformersTransformerPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})

return result

self.patched_forward = patched_forward

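This patcher exists because SentenceTransformer.forward consumes a single dictionary of features, while torch.onnx.export traces flat tensor arguments. A sketch of the unpatched calling convention, assuming sentence-transformers is installed:

```python
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# forward takes one features dict and returns a dict of tensors; the patcher
# above re-exposes it as patched_forward(input_ids, attention_mask) so the
# exporter can bind named graph inputs.
features = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
}
out = model(features)
print(sorted(out.keys()))  # includes "sentence_embedding" and "token_embeddings"
```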

class SentenceTransformersCLIPPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
super().__init__(config, model, model_kwargs)

def patched_forward(input_ids, attention_mask, pixel_values):
vision_outputs = model[0].model.vision_model(pixel_values=pixel_values)
image_embeds = model[0].model.visual_projection(vision_outputs[1])

text_outputs = model[0].model.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
text_embeds = model[0].model.text_projection(text_outputs[1])

if len(model) > 1:
image_embeds = model[1:](image_embeds)
text_embeds = model[1:](text_embeds)

return {"text_embeds": text_embeds, "image_embeds": image_embeds}

self.patched_forward = patched_forward
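The indexing in this patcher mirrors how a Sentence Transformers CLIP pipeline is laid out: model[0] wraps transformers' CLIPModel (reachable as model[0].model), and any trailing modules are applied to the pooled embeddings. A sketch for inspecting that layout; the checkpoint name is illustrative:

```python
from sentence_transformers import SentenceTransformer

# SentenceTransformer subclasses torch.nn.Sequential, so the patcher's
# model[0] / model[1:] indexing walks these modules in order.
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32")
for idx, module in enumerate(model):
    print(idx, module.__class__.__name__)
```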
114 changes: 77 additions & 37 deletions optimum/exporters/tasks.py
@@ -195,10 +195,16 @@ class TasksManager:
"image-classification": "create_model",
}

_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
"feature-extraction": "SentenceTransformer",
"sentence-similarity": "SentenceTransformer",
}

_LIBRARY_TO_TASKS_TO_MODEL_LOADER_MAP = {
    "diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
    "sentence_transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
    "timm": _TIMM_TASKS_TO_MODEL_LOADERS,
    "transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
}

if is_tf_available():
@@ -254,9 +260,10 @@ class TasksManager:

# Reverse dictionaries str -> str, where several model loaders may map to the same task
_LIBRARY_TO_MODEL_LOADERS_TO_TASKS_MAP = {
    "diffusers": get_model_loaders_to_tasks(_DIFFUSERS_TASKS_TO_MODEL_LOADERS),
    "sentence_transformers": get_model_loaders_to_tasks(_SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
    "timm": get_model_loaders_to_tasks(_TIMM_TASKS_TO_MODEL_LOADERS),
    "transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_MODEL_LOADERS),
}
_LIBRARY_TO_TF_MODEL_LOADERS_TO_TASKS_MAP = {
"transformers": get_model_loaders_to_tasks(_TRANSFORMERS_TASKS_TO_TF_MODEL_LOADERS),
@@ -857,6 +864,16 @@ class TasksManager:
"semantic-segmentation",
onnx="SegformerOnnxConfig",
),
"sentence-transformers-clip": supported_tasks_mapping(
"feature-extraction",
"sentence-similarity",
onnx="SentenceTransformersCLIPOnnxConfig",
),
"sentence-transformers-transformer": supported_tasks_mapping(
"feature-extraction",
"sentence-similarity",
onnx="SentenceTransformersTransformerOnnxConfig",
),
"sew": supported_tasks_mapping(
"feature-extraction",
"automatic-speech-recognition",
@@ -1340,6 +1357,9 @@ def determine_framework(
):
# stable diffusion case
framework = "pt"
elif "config_sentence_transformers.json" in all_files:
# The Sentence Transformers library relies on PyTorch.
framework = "pt"
else:
if request_exception is not None:
raise RequestsConnectionError(
@@ -1544,6 +1564,9 @@ def infer_library_from_model(
if not full_model_path.is_dir():
model_info = huggingface_hub.model_info(model_name_or_path, revision=revision)
library_name = getattr(model_info, "library_name", None)
if library_name is not None:
    # The Hub tags these models as "sentence-transformers", while the Python
    # package name is "sentence_transformers".
    library_name = library_name.replace("-", "_")

if library_name is None:
all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
@@ -1564,17 +1587,16 @@ def infer_library_from_model(
library_name = "timm"
elif hasattr(model_config, "_diffusers_version"):
library_name = "diffusers"
elif any(file_path.startswith("sentence_") for file_path in all_files):
library_name = "sentence_transformers"
else:
library_name = "transformers"

if library_name is None:
    raise ValueError(
        "The library_name could not be automatically inferred. If using the command-line, please provide the argument --library {transformers,diffusers,timm,sentence_transformers}. Example: `--library diffusers`."
    )

return library_name

@classmethod
@@ -1633,6 +1655,17 @@ def standardize_model_attributes(
model_type = json.load(fp)["architecture"]

setattr(model.config, "model_type", model_type)
elif library_name == "sentence_transformers":
if "Transformer" in model[0].__class__.__name__:
model.config = model[0].auto_model.config
model.config.model_type = "sentence-transformers-transformer"
elif "CLIP" in model[0].__class__.__name__:
model.config = model[0].model.config
model.config.model_type = "sentence-transformers-clip"
else:
raise ValueError(
f"The export of a sentence-transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add the support."
)

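The effect of this branch is to give the loaded SentenceTransformer a top-level config whose model_type routes to the sentence-transformers-* ONNX configs registered above. A hedged sketch; the keyword defaults of the remaining arguments are assumed:

```python
from sentence_transformers import SentenceTransformer

from optimum.exporters.tasks import TasksManager

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
TasksManager.standardize_model_attributes(
    "sentence-transformers/all-MiniLM-L6-v2", model, library_name="sentence_transformers"
)
print(model.config.model_type)  # sentence-transformers-transformer
```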
@staticmethod
def get_all_tasks():
@@ -1733,39 +1766,46 @@ def get_model_from_task(

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
elif library_name == "sentence_transformers":
    cache_folder = model_kwargs.pop("cache_folder", None)
    use_auth_token = model_kwargs.pop("use_auth_token", None)
    model = model_class(
        model_name_or_path, device=device, cache_folder=cache_folder, use_auth_token=use_auth_token
    )
else:
    try:
        if framework == "pt":
            kwargs["torch_dtype"] = torch_dtype

            if isinstance(device, str):
                device = torch.device(device)
            elif device is None:
                device = torch.device("cpu")

            # TODO : fix EulerDiscreteScheduler loading to enable for SD models
            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
                with device:
                    # Initialize directly in the requested device, to save allocation time. Especially useful for large
                    # models to initialize on cuda device.
                    model = model_class.from_pretrained(model_name_or_path, **kwargs)
            else:
                model = model_class.from_pretrained(model_name_or_path, **kwargs).to(device)
        else:
            model = model_class.from_pretrained(model_name_or_path, **kwargs)
    except OSError:
        if framework == "pt":
            logger.info("Loading TensorFlow model in PyTorch before exporting.")
            kwargs["from_tf"] = True
            model = model_class.from_pretrained(model_name_or_path, **kwargs)
        else:
            logger.info("Loading PyTorch model in TensorFlow before exporting.")
            kwargs["from_pt"] = True
            model = model_class.from_pretrained(model_name_or_path, **kwargs)

TasksManager.standardize_model_attributes(
model_name_or_path, model, subfolder, revision, cache_dir, library_name
)

return model
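End to end, the sentence_transformers branch above means a task/checkpoint pair now resolves to a SentenceTransformer instance rather than a transformers auto class. A hedged usage sketch:

```python
from optimum.exporters.tasks import TasksManager

model = TasksManager.get_model_from_task(
    "feature-extraction",
    "sentence-transformers/all-MiniLM-L6-v2",
    library_name="sentence_transformers",
)
print(type(model).__name__)  # SentenceTransformer
```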

@staticmethod
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -36,6 +36,7 @@
is_onnx_available,
is_onnxruntime_available,
is_pydantic_available,
is_sentence_transformers_available,
is_timm_available,
is_torch_onnx_support_available,
require_numpy_strictly_lower,
7 changes: 6 additions & 1 deletion optimum/utils/import_utils.py
@@ -48,7 +48,8 @@
_accelerate_available = importlib.util.find_spec("accelerate") is not None
_diffusers_available = importlib.util.find_spec("diffusers") is not None
_auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None
_timm_available = importlib.util.find_spec("diffusers") is not None
_timm_available = importlib.util.find_spec("timm") is not None
_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None

torch_version = None
if is_torch_available():
@@ -107,6 +108,10 @@ def is_timm_available():
return _timm_available


def is_sentence_transformers_available():
return _sentence_transformers_available


def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
12 changes: 11 additions & 1 deletion optimum/utils/testing_utils.py
@@ -24,7 +24,13 @@

import torch

from . import (
    is_accelerate_available,
    is_auto_gptq_available,
    is_diffusers_available,
    is_sentence_transformers_available,
    is_timm_available,
)


# Used to test the hub
@@ -137,6 +143,10 @@ def require_timm(test_case):
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)


def require_sentence_transformers(test_case):
return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case)

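Used like the other require_* helpers, the decorator skips a test when the package is missing. A hypothetical test case for illustration:

```python
import unittest

from optimum.utils.testing_utils import require_sentence_transformers


class SentenceTransformersOnnxTest(unittest.TestCase):  # hypothetical test case
    @require_sentence_transformers
    def test_export(self):
        # Skipped with "test requires sentence-transformers" when the package is absent.
        ...
```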

def grid_parameters(
parameters: Dict[str, Iterable[Any]],
yield_dict: bool = False,
5 changes: 5 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -293,3 +293,8 @@
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
}

PYTORCH_SENTENCE_TRANSFORMERS_MODEL = {
    "sentence-transformers-clip": "sentence-transformers/clip-ViT-B-32",
    "sentence-transformers-transformer": "sentence-transformers/all-MiniLM-L6-v2",
}
}