ORTOptimizer support ORTModelForCausalLM #794

Merged
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -350,7 +350,7 @@ def _from_pretrained(
try:
decoder_with_past_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
[DECODER_WITH_PAST_ONNX_FILE_PATTERN],
"decoder_with_past_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
14 changes: 13 additions & 1 deletion optimum/onnxruntime/optimization.py
@@ -27,6 +27,7 @@
from ..utils import CONFIG_NAME, NormalizedConfigManager
from ..utils.save_utils import maybe_save_preprocessors
from .configuration import OptimizationConfig, ORTConfig
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import ORTModel
from .modeling_seq2seq import ORTModelForSeq2SeqLM
from .utils import ONNX_WEIGHTS_NAME, ORTConfigManager
@@ -83,6 +84,17 @@ def from_pretrained(
# Add the decoder with past key/values if present
if model_or_path.use_cache:
onnx_model_path.append(model_or_path.decoder_with_past_model_path)
elif isinstance(model_or_path, ORTModelForCausalLM):
if model_or_path.use_merged is True:
raise NotImplementedError(
"ORTOptimizer does not support ORTModelForCausalLM models that use a single ONNX for both the without/with past cases."
" Please pass an ORTModelForCausalLM that uses a separate ONNX for each without/with past cases. The can be done"
" by using `ORTModelForCausalLM.from_pretrained(..., from_transformers=True, use_merged=False)`, or by"
" using the option `--no-post-process` in the optimum-cli ONNX export tool."
)
onnx_model_path.append(model_or_path.decoder_model_path)
if model_or_path.use_cache:
onnx_model_path.append(model_or_path.decoder_with_past_model_path)
else:
onnx_model_path.append(model_or_path.model_path)
config = model_or_path.config
@@ -125,7 +137,7 @@ def optimize(
"""
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
ORTConfigManager.check_optimization_supported_model(self.model_type)
ORTConfigManager.check_optimization_supported_model(self.model_type, optimization_config)

self.config.save_pretrained(save_dir)
maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir)
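The snippet below is a minimal sketch of the flow this hunk enables: export an ORTModelForCausalLM with a separate ONNX file for the without/with past cases (the merged single-file variant is rejected above), build an ORTOptimizer from it, and optimize. The "gpt2" checkpoint and the output directory name are illustrative placeholders, not part of the PR.

from optimum.onnxruntime import AutoOptimizationConfig, ORTModelForCausalLM, ORTOptimizer

# Export with a separate ONNX file for each of the without/with past cases,
# which is what ORTOptimizer.from_pretrained now expects for causal LMs.
model = ORTModelForCausalLM.from_pretrained(
    "gpt2", from_transformers=True, use_cache=True, use_merged=False
)

optimizer = ORTOptimizer.from_pretrained(model)  # collects decoder and decoder-with-past ONNX paths

optimization_config = AutoOptimizationConfig.with_optimization_level("O2")
optimizer.optimize(save_dir="gpt2_onnx_optimized", optimization_config=optimization_config)

# Reload the optimized decoder for generation.
optimized_model = ORTModelForCausalLM.from_pretrained("gpt2_onnx_optimized", use_cache=True)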
16 changes: 12 additions & 4 deletions optimum/onnxruntime/utils.py
@@ -75,12 +75,15 @@ class ORTConfigManager:
"""

# Contribution note: Please add new models in alphabetical order
# TODO: for encoder-decoder models, validate if bert or gpt2 optimization is better
_conf = {
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
"codegen": "gpt2",
"deberta": "bert",
@@ -89,15 +92,18 @@ class ORTConfigManager:
"electra": "bert",
"gpt2": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
"t5": "t5",
"whisper": "whisper",
"t5": "bert",
"xlm-roberta": "bert",
}

@@ -116,8 +122,10 @@ def check_supported_model(cls, model_type: str):
)

@classmethod
def check_optimization_supported_model(cls, model_type: str):
supported_model_types_for_optimization = ["bert", "gpt2", "bart"]
def check_optimization_supported_model(cls, model_type: str, optimization_config):
# as of 1.14.0: https://github.com/microsoft/onnxruntime/blob/6ccaeddefa65ccac402a47fa4d9cad8229794bb2/onnxruntime/python/tools/transformers/optimizer.py#L39
supported_model_types_for_optimization = ["bert", "gpt2", "bart", "unet"]

if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise KeyError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {supported_model_types_for_optimization} are supported. "
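A rough illustration of how the updated check behaves, going only by the visible hunk: the architecture is looked up in `_conf`, and the mapped ONNX Runtime optimizer model type must be one of ["bert", "gpt2", "bart", "unet"], otherwise a KeyError is raised. This assumes the truncated remainder of the method imposes no further restrictions for the chosen config.

from optimum.onnxruntime import AutoOptimizationConfig
from optimum.onnxruntime.utils import ORTConfigManager

config = AutoOptimizationConfig.with_optimization_level("O1")

# "gpt_neox" and "bloom" now map to "gpt2" in _conf, so the check passes.
ORTConfigManager.check_optimization_supported_model("gpt_neox", config)

# "whisper" maps to "whisper", which is not in the supported set, so this raises.
try:
    ORTConfigManager.check_optimization_supported_model("whisper", config)
except KeyError as err:
    print(err)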
76 changes: 1 addition & 75 deletions tests/onnxruntime/test_modeling.py
@@ -54,6 +54,7 @@
from transformers.modeling_utils import no_init_weights
from transformers.onnx.utils import get_preprocessor
from transformers.testing_utils import get_gpu_count, require_torch_gpu
from utils_onnxruntime_tests import MODEL_NAMES, SEED

from optimum.exporters import TasksManager
from optimum.onnx.utils import has_onnx_input
@@ -96,81 +97,6 @@ def __exit__(self, type, value, traceback):
self.elapsed = (time.perf_counter() - self.elapsed) * 1e3


MODEL_NAMES = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
"audio_spectrogram_transformer": "Ericwang/tiny-random-ast",
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
"bert": "hf-internal-testing/tiny-random-BertModel",
"bart": "hf-internal-testing/tiny-random-bart",
# "big_bird": "hf-internal-testing/tiny-random-BigBirdModel",
# "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
"blenderbot_small": "hf-internal-testing/tiny-random-BlenderbotModel",
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
"data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"detr": "hf-internal-testing/tiny-random-detr",
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"longt5": "hf-internal-testing/tiny-random-LongT5Model",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"t5": "hf-internal-testing/tiny-random-t5",
"vit": "hf-internal-testing/tiny-random-vit",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
"whisper": "openai/whisper-tiny.en", # hf-internal-testing ones are broken
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
"wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer",
"wavlm": "hf-internal-testing/tiny-random-wavlm",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "hf-internal-testing/tiny-random-SEWDModel",
"speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech_sat": "hf-internal-testing/tiny-random-unispeech-sat",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
"vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
"trocr": "microsoft/trocr-small-handwritten",
}

SEED = 42


class ORTModelTestMixin(unittest.TestCase):
ARCH_MODEL_MAP = {}

179 changes: 178 additions & 1 deletion tests/onnxruntime/test_optimization.py
@@ -14,18 +14,65 @@

import gc
import os
import shutil
import tempfile
import unittest
from pathlib import Path
from typing import Dict

import onnx
import torch
from parameterized import parameterized
from transformers import AutoTokenizer
from utils_onnxruntime_tests import MODEL_NAMES

from optimum.onnxruntime import ORTConfig, ORTModelForSequenceClassification, ORTOptimizer
from optimum.exporters import TasksManager
from optimum.onnxruntime import AutoOptimizationConfig, ORTConfig, ORTModelForSequenceClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig
from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM
from optimum.onnxruntime.modeling_seq2seq import ORTModelForSeq2SeqLM
from optimum.utils.testing_utils import grid_parameters


class ORTOptimizerTestMixin(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.onnx_model_dirs = {}

def _setup(self, model_args: Dict):
"""
Exports the PyTorch models to ONNX ahead of time to avoid multiple exports during the tests.
We don't do the export in unittest's setUpClass, in order to still be able to run individual tests.
"""
model_arch = model_args["model_arch"]
model_arch_and_params = model_args["test_name"]

# TODO: this should actually be checked in ORTModel!
task = self.TASK
if "use_cache" in model_args and model_args["use_cache"] is True:
task = task + "-with-past"

if "use_cache" in model_args and task not in TasksManager.get_supported_tasks_for_model_type(
model_arch.replace("_", "-"), exporter="onnx"
):
self.skipTest("Unsupported export case")

if model_arch_and_params not in self.onnx_model_dirs:
# model_args will contain kwargs to pass to ORTModel.from_pretrained()
model_args.pop("test_name")
model_args.pop("model_arch")

model_id = MODEL_NAMES[model_arch]
onnx_model = self.ORTMODEL_CLASS.from_pretrained(model_id, **model_args, from_transformers=True)

model_dir = tempfile.mkdtemp(prefix=f"{model_arch_and_params}_{self.TASK}_")
onnx_model.save_pretrained(model_dir)
self.onnx_model_dirs[model_arch_and_params] = model_dir

@classmethod
def tearDownClass(cls):
for _, dir_path in cls.onnx_model_dirs.items():
shutil.rmtree(dir_path)


class ORTOptimizerTest(unittest.TestCase):
@@ -154,3 +201,133 @@ def test_optimization_fp16(self):

# Compare tensors outputs
self.assertTrue(torch.allclose(model_outputs.logits, optimized_model_outputs.logits, atol=1e-4))


class ORTOptimizerForSeq2SeqLMIntegrationTest(ORTOptimizerTestMixin):
TASK = "seq2seq-lm"
ORTMODEL_CLASS = ORTModelForSeq2SeqLM

SUPPORTED_ARCHITECTURES = [
"bart",
"blenderbot",
"blenderbot_small",
# "longt5",
"m2m_100",
"marian",
"mbart",
"mt5",
"pegasus",
"t5",
]

FULL_GRID = {
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True, False],
"optimization_level": ["O1", "O2", "O3", "O4"],
}

@parameterized.expand(grid_parameters(FULL_GRID))
def test_optimization_level(self, test_name: str, model_arch: str, use_cache: bool, optimization_level: str):
export_name = test_name[:-3]  # remove the `_OX` optimization-level suffix, which is irrelevant to the export
model_args = {"test_name": export_name, "model_arch": model_arch, "use_cache": use_cache}
self._setup(model_args)

ort_model = ORTModelForSeq2SeqLM.from_pretrained(self.onnx_model_dirs[export_name], use_cache=use_cache)

optimizer = ORTOptimizer.from_pretrained(ort_model)

optimization_config = AutoOptimizationConfig.with_optimization_level(optimization_level)
optimization_config.disable_shape_inference = True
model_id = MODEL_NAMES[model_arch]

with tempfile.TemporaryDirectory(suffix="_optimized") as tmp_dir:
optimizer.optimize(save_dir=tmp_dir, optimization_config=optimization_config)

optimized_model = ORTModelForSeq2SeqLM.from_pretrained(tmp_dir, use_cache=use_cache)

expected_ort_config = ORTConfig(optimization=optimization_config)
ort_config = ORTConfig.from_pretrained(tmp_dir)

# Verify the ORTConfig was correctly created and saved
self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict())

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")
model_outputs = ort_model.generate(**tokens)
optimized_model_outputs = optimized_model.generate(**tokens)

self.assertTrue(torch.equal(model_outputs, optimized_model_outputs))
gc.collect()


class ORTOptimizerForCausalLMIntegrationTest(ORTOptimizerTestMixin):
TASK = "causal-lm"
ORTMODEL_CLASS = ORTModelForCausalLM

SUPPORTED_ARCHITECTURES = [
"bloom",
# codegen is not supported until https://github.com/microsoft/onnxruntime/pull/14751 is merged
# "codegen",
"gpt2",
"gpt_neo",
"gpt_neox",
"gptj",
]

FULL_GRID = {
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True, False],
"use_merged": [True, False],
"optimization_level": ["O1", "O2", "O3", "O4"],
}

@parameterized.expand(grid_parameters(FULL_GRID))
def test_optimization_level(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, optimization_level: str
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")

export_name = test_name[:-3]  # remove the `_OX` optimization-level suffix, which is irrelevant to the export
model_args = {
"test_name": export_name,
"model_arch": model_arch,
"use_cache": use_cache,
"use_merged": use_merged,
}
self._setup(model_args)

ort_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[export_name], use_cache=use_cache)

if use_merged:
with self.assertRaises(NotImplementedError) as cm:
optimizer = ORTOptimizer.from_pretrained(ort_model)

self.assertTrue("ORTModelForCausalLM models that use a single ONNX" in str(cm.exception))
self.skipTest("Unsupported optimization case")
else:
optimizer = ORTOptimizer.from_pretrained(ort_model)

optimization_config = AutoOptimizationConfig.with_optimization_level(optimization_level)
optimization_config.disable_shape_inference = True
model_id = MODEL_NAMES[model_arch]

with tempfile.TemporaryDirectory(suffix="_opt") as tmp_dir:
optimizer.optimize(save_dir=tmp_dir, optimization_config=optimization_config)

optimized_model = ORTModelForCausalLM.from_pretrained(tmp_dir, use_cache=use_cache)

expected_ort_config = ORTConfig(optimization=optimization_config)
ort_config = ORTConfig.from_pretrained(tmp_dir)

# Verify the ORTConfig was correctly created and saved
self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict())

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")

model_outputs = ort_model.generate(**tokens)
optimized_model_outputs = optimized_model.generate(**tokens)

self.assertTrue(torch.equal(model_outputs, optimized_model_outputs))
gc.collect()