ORTOptimizer support for ORTModelForCausalLM #794

Merged
49 changes: 29 additions & 20 deletions optimum/onnxruntime/configuration.py
@@ -29,8 +29,11 @@
from onnxruntime.transformers.fusion_options import FusionOptions

from ..configuration_utils import BaseConfig
from ..utils import logging


logger = logging.get_logger(__name__)

NodeName = NodeType = str

# This value is used to indicate to ORT which axis it should use to quantize an operator "per-channel"
@@ -55,18 +58,18 @@ class CalibrationConfig:
The number of samples composing the calibration dataset.
method (`CalibrationMethod`):
The method chosen to calculate the activations quantization parameters using the calibration dataset.
num_bins (`int`, *optional*):
num_bins (`Optional[int]`, defaults to `None`):
The number of bins to use when creating the histogram when performing the calibration step using the
Percentile or Entropy method.
num_quantized_bins (`int`, *optional*):
num_quantized_bins (`Optional[int]`, defaults to `None`):
The number of quantized bins to use when performing the calibration step using the Entropy method.
percentile (`float`, *optional*):
percentile (`Optional[float]`, defaults to `None`):
The percentile to use when computing the activations quantization ranges when performing the calibration
step using the Percentile method.
moving_average (`bool`, *optional*):
moving_average (`Optional[bool]`, defaults to `None`):
Whether to compute the moving average of the minimum and maximum values when performing the calibration step
using the MinMax method.
averaging_constant (`float`, *optional*):
averaging_constant (`Optional[float]`, defaults to `None`):
The constant smoothing factor to use when computing the moving average of the minimum and maximum values.
Effective only when the MinMax calibration method is selected and `moving_average` is set to True.
"""
@@ -812,10 +815,10 @@ def with_optimization_level(cls, optimization_level: str, for_gpu: bool = False,
- O2: Basic and extended general optimizations, transformers-specific fusions.
- O3: Same as O2 with Fast Gelu approximation.
- O4: Same as O3 with mixed precision.
for_gpu (`bool`, *optional*, defaults to `False`):
for_gpu (`bool`, defaults to `False`):
Whether the model to optimize will run on GPU; some optimizations depend on the hardware the model
will run on. Only needed for optimization_level > 1.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`):
Arguments to provide to the [`~OptimizationConfig`] constructor.

Returns:
@@ -825,6 +828,12 @@
raise ValueError(
f"optimization_level must be in {', '.join(cls._LEVELS.keys())}, got {optimization_level}"
)

if optimization_level == "O4":
if for_gpu is False:
logger.warning("Overridding for_gpu=False to for_gpu=True as half precision is available only on GPU.")
for_gpu = True

return OptimizationConfig(optimize_for_gpu=for_gpu, **cls._LEVELS[optimization_level], **kwargs)

@classmethod
@@ -833,10 +842,10 @@ def O1(cls, for_gpu: bool = False, **kwargs) -> OptimizationConfig:
Creates an O1 [`~OptimizationConfig`].

Args:
for_gpu (`bool`, *optional*, defaults to `False`):
for_gpu (`bool`, defaults to `False`):
Whether the model to optimize will run on GPU; some optimizations depend on the hardware the model
will run on. Only needed for optimization_level > 1.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`):
Arguments to provide to the [`~OptimizationConfig`] constructor.

Returns:
@@ -850,10 +859,10 @@ def O2(cls, for_gpu: bool = False, **kwargs) -> OptimizationConfig:
Creates an O2 [`~OptimizationConfig`].

Args:
for_gpu (`bool`, *optional*, defaults to `False`):
for_gpu (`bool`, defaults to `False`):
Whether the model to optimize will run on GPU; some optimizations depend on the hardware the model
will run on. Only needed for optimization_level > 1.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`):
Arguments to provide to the [`~OptimizationConfig`] constructor.

Returns:
@@ -867,10 +876,10 @@ def O3(cls, for_gpu: bool = False, **kwargs) -> OptimizationConfig:
Creates an O3 [`~OptimizationConfig`].

Args:
for_gpu (`bool`, *optional*, defaults to `False`):
for_gpu (`bool`, defaults to `False`):
Whether the model to optimize will run on GPU; some optimizations depend on the hardware the model
will run on. Only needed for optimization_level > 1.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`):
Arguments to provide to the [`~OptimizationConfig`] constructor.

Returns:
@@ -879,15 +888,15 @@ def O3(cls, for_gpu: bool = False, **kwargs) -> OptimizationConfig:
return cls.with_optimization_level("O3", for_gpu=for_gpu, **kwargs)

@classmethod
def O4(cls, for_gpu: bool = False, **kwargs) -> OptimizationConfig:
def O4(cls, for_gpu: bool = True, **kwargs) -> OptimizationConfig:
"""
Creates an O4 [`~OptimizationConfig`].

Args:
for_gpu (`bool`, *optional*, defaults to `False`):
for_gpu (`bool`, defaults to `True`):
Whether the model to optimize will run on GPU; some optimizations depend on the hardware the model
will run on. Only needed for optimization_level > 1.
kwargs (`Dict[str, Any]`, *optional*):
kwargs (`Dict[str, Any]`):
Arguments to provide to the [`~OptimizationConfig`] constructor.

Returns:
@@ -902,17 +911,17 @@ class ORTConfig(BaseConfig):
optimization and quantization parameters.

Attributes:
opset (`int`, *optional*):
opset (`Optional[int]`, defaults to `None`):
ONNX opset version to export the model with.
use_external_data_format (`bool`, *optional*, defaults to `False`):
use_external_data_format (`bool`, defaults to `False`):
Allow exporting models larger than 2GB.
one_external_file (`bool`, defaults to `True`):
When `use_external_data_format=True`, whether to save all tensors to one external file.
If false, save each tensor to a file named with the tensor name.
(Cannot be set to `False` for quantization.)
optimization (`OptimizationConfig`, *optional*, defaults to None):
optimization (`Optional[OptimizationConfig]`, defaults to `None`):
Specify a configuration to optimize the ONNX Runtime model.
quantization (`QuantizationConfig`, *optional*, defaults to None):
quantization (`Optional[QuantizationConfig]`, defaults to `None`):
Specify a configuration to quantize the ONNX Runtime model.
"""

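With this change, requesting O4 on CPU is overridden rather than silently producing a broken configuration. A minimal usage sketch, assuming these classmethods live on `AutoOptimizationConfig` in `optimum.onnxruntime.configuration` (the enclosing class name sits outside the hunks shown above):

```python
# Sketch only: AutoOptimizationConfig is assumed to be the enclosing class,
# as its name is not visible in the hunks above.
from optimum.onnxruntime.configuration import AutoOptimizationConfig

# Asking for O4 on CPU now logs a warning and flips the flag, since
# mixed precision is only available on GPU.
config = AutoOptimizationConfig.with_optimization_level("O4", for_gpu=False)
assert config.optimize_for_gpu is True

# The O4 shortcut itself now defaults to for_gpu=True.
config = AutoOptimizationConfig.O4()
assert config.optimize_for_gpu is True
```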
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -350,7 +350,7 @@ def _from_pretrained(
try:
decoder_with_past_path = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
[DECODER_WITH_PAST_ONNX_FILE_PATTERN],
"decoder_with_past_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
14 changes: 13 additions & 1 deletion optimum/onnxruntime/optimization.py
@@ -27,6 +27,7 @@
from ..utils import CONFIG_NAME, NormalizedConfigManager
from ..utils.save_utils import maybe_save_preprocessors
from .configuration import OptimizationConfig, ORTConfig
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import ORTModel
from .modeling_seq2seq import ORTModelForSeq2SeqLM
from .utils import ONNX_WEIGHTS_NAME, ORTConfigManager
@@ -83,6 +84,17 @@ def from_pretrained(
# Add the decoder with past key/values if present
if model_or_path.use_cache:
onnx_model_path.append(model_or_path.decoder_with_past_model_path)
elif isinstance(model_or_path, ORTModelForCausalLM):
if model_or_path.use_merged is True:
raise NotImplementedError(
"ORTOptimizer does not support ORTModelForCausalLM models that use a single ONNX for both the without/with past cases."
" Please pass an ORTModelForCausalLM that uses a separate ONNX for each without/with past cases. The can be done"
" by using `ORTModelForCausalLM.from_pretrained(..., from_transformers=True, use_merged=False)`, or by"
" using the option `--no-post-process` in the optimum-cli ONNX export tool."
)
onnx_model_path.append(model_or_path.decoder_model_path)
if model_or_path.use_cache:
onnx_model_path.append(model_or_path.decoder_with_past_model_path)
else:
onnx_model_path.append(model_or_path.model_path)
config = model_or_path.config
@@ -125,7 +137,7 @@ def optimize(
"""
save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
ORTConfigManager.check_optimization_supported_model(self.model_type)
ORTConfigManager.check_optimization_supported_model(self.model_type, optimization_config)

self.config.save_pretrained(save_dir)
maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir)
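End-to-end, the new branch enables the following flow. A hedged sketch: the model ID is illustrative, the loading arguments are the ones named in the error message above, and `AutoOptimizationConfig` is again assumed to be the shortcut class from the configuration module:

```python
from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig

# Export decoder and decoder-with-past as separate ONNX files, as the new
# guard requires (use_merged=False); "gpt2" is an illustrative model ID.
model = ORTModelForCausalLM.from_pretrained(
    "gpt2", from_transformers=True, use_merged=False, use_cache=True
)

# from_pretrained now collects decoder_model_path and, because
# use_cache=True, decoder_with_past_model_path as well.
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    optimization_config=AutoOptimizationConfig.O2(),
    save_dir="gpt2_optimized",
)
```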
16 changes: 12 additions & 4 deletions optimum/onnxruntime/utils.py
@@ -75,12 +75,15 @@ class ORTConfigManager:
"""

# Contribution note: Please add new models in alphabetical order
# TODO: for encoder-decoder models, validate if bert or gpt2 optimization is better
_conf = {
"albert": "bert",
"bart": "bart",
"bert": "bert",
"big_bird": "bert",
# "bigbird_pegasus": None, # bug in `fusion_skiplayernorm.py`
"blenderbot": "bert",
"bloom": "gpt2",
"camembert": "bert",
"codegen": "gpt2",
"deberta": "bert",
@@ -89,15 +92,18 @@
"electra": "bert",
"gpt2": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
"m2m_100": "bart",
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
"t5": "t5",
"whisper": "whisper",
"t5": "bert",
"xlm-roberta": "bert",
}

@@ -116,8 +122,10 @@ def check_supported_model(cls, model_type: str):
)

@classmethod
def check_optimization_supported_model(cls, model_type: str):
supported_model_types_for_optimization = ["bert", "gpt2", "bart"]
def check_optimization_supported_model(cls, model_type: str, optimization_config):
# as of 1.14.0: https://github.com/microsoft/onnxruntime/blob/6ccaeddefa65ccac402a47fa4d9cad8229794bb2/onnxruntime/python/tools/transformers/optimizer.py#L39
supported_model_types_for_optimization = ["bert", "gpt2", "bart", "unet"]

if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
raise KeyError(
f"ONNX Runtime doesn't support the graph optimization of {model_type} yet. Only {supported_model_types_for_optimization} are supported. "
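For reference, the extended check resolves the model type through `_conf` before consulting the supported list. A short sketch of the observable behavior, using a deliberately made-up model type for the failure case:

```python
from optimum.onnxruntime.configuration import AutoOptimizationConfig
from optimum.onnxruntime.utils import ORTConfigManager

config = AutoOptimizationConfig.O2()

# "bloom" maps to "gpt2", which is in the supported list, so this passes.
ORTConfigManager.check_optimization_supported_model("bloom", config)

# A model type absent from _conf ("not-a-real-arch" is made up here)
# raises a KeyError naming the supported types.
try:
    ORTConfigManager.check_optimization_supported_model("not-a-real-arch", config)
except KeyError as err:
    print(err)
```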
76 changes: 1 addition & 75 deletions tests/onnxruntime/test_modeling.py
@@ -54,6 +54,7 @@
from transformers.modeling_utils import no_init_weights
from transformers.onnx.utils import get_preprocessor
from transformers.testing_utils import get_gpu_count, require_torch_gpu
from utils_onnxruntime_tests import MODEL_NAMES, SEED

from optimum.exporters import TasksManager
from optimum.onnx.utils import has_onnx_input
@@ -96,81 +97,6 @@ def __exit__(self, type, value, traceback):
self.elapsed = (time.perf_counter() - self.elapsed) * 1e3


MODEL_NAMES = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
"audio_spectrogram_transformer": "Ericwang/tiny-random-ast",
"beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
"bert": "hf-internal-testing/tiny-random-BertModel",
"bart": "hf-internal-testing/tiny-random-bart",
# "big_bird": "hf-internal-testing/tiny-random-BigBirdModel",
# "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
"blenderbot_small": "hf-internal-testing/tiny-random-BlenderbotModel",
"blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
"data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel",
"data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel",
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"detr": "hf-internal-testing/tiny-random-detr",
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"longt5": "hf-internal-testing/tiny-random-LongT5Model",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
"mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"mt5": "lewtun/tiny-random-mt5",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"t5": "hf-internal-testing/tiny-random-t5",
"vit": "hf-internal-testing/tiny-random-vit",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
"whisper": "openai/whisper-tiny.en", # hf-internal-testing ones are broken
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
"wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer",
"wavlm": "hf-internal-testing/tiny-random-wavlm",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "hf-internal-testing/tiny-random-SEWDModel",
"speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech_sat": "hf-internal-testing/tiny-random-unispeech-sat",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
"vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
"trocr": "microsoft/trocr-small-handwritten",
}

SEED = 42


class ORTModelTestMixin(unittest.TestCase):
ARCH_MODEL_MAP = {}
