diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index b916ea49eb..d728a5ed14 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -68,7 +68,7 @@ def main():
         "--for-ort",
         action="store_true",
         help=(
-            "This generates ONNX models to run inference with ONNX Runtime ORTModelXXX for encoder-decoder models."
+            "This exports models ready to be run with optimum.onnxruntime ORTModelXXX. Useful for encoder-decoder models."
             " If enabled the encoder and decoder of the model are exported separately."
         ),
     )
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index be5eaee99b..d59d6f687e 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -315,44 +315,6 @@ def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str,
         """
         return reference_model_inputs
 
-    def get_encoder_onnx_config(self, config: "PretrainedConfig"):
-        """
-        Returns ONNX encoder config for `Seq2Seq` models. Implement the method to export the encoder
-        of the model separately.
-
-        Args:
-            config (`PretrainedConfig`):
-                The encoder model's configuration to use when exporting to ONNX.
-
-        Returns:
-            `OnnxConfig`: An instance of the ONNX configuration object.
-        """
-        raise NotImplementedError(
-            f"{config.model_type} encoder export is not supported yet. ",
-            f"If you want to support {config.model_type} please propose a PR or open up an issue.",
-        )
-
-    def get_decoder_onnx_config(self, config: "PretrainedConfig", task: str = "default", use_past: bool = False):
-        """
-        Returns ONNX decoder config for `Seq2Seq` models. Implement the method to export the decoder
-        of the model separately.
-
-        Args:
-            config (`PretrainedConfig`):
-                The decoder model's configuration to use when exporting to ONNX.
-            task (`str`, defaults to `"default"`):
-                The task the model should be exported for.
-            use_past (`bool`, defaults to `False`):
-                Whether to export the model with past_key_values.
-
-        Returns:
-            `OnnxConfig`: An instance of the ONNX configuration object.
-        """
-        raise NotImplementedError(
-            f"{config.model_type} decoder export is not supported yet. ",
-            f"If you want to support {config.model_type} please propose a PR or open up an issue.",
-        )
-
 
 class OnnxConfigWithPast(OnnxConfig, ABC):
     PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True
@@ -504,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
         flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
         flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
         flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
+
+    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
+        """
+        Returns ONNX encoder config for `Seq2Seq` models. Implement the method to export the encoder
+        of the model separately.
+
+        Args:
+            config (`PretrainedConfig`):
+                The encoder model's configuration to use when exporting to ONNX.
+
+        Returns:
+            `OnnxConfig`: An instance of the ONNX configuration object.
+        """
+        raise NotImplementedError(
+            f"{config.model_type} encoder export is not supported yet. ",
+            f"If you want to support {config.model_type} please propose a PR or open up an issue.",
+        )
+
+    def get_decoder_onnx_config(
+        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
+    ) -> OnnxConfig:
+        """
+        Returns ONNX decoder config for `Seq2Seq` models. Implement the method to export the decoder
+        of the model separately.
+
+        Args:
+            config (`PretrainedConfig`):
+                The decoder model's configuration to use when exporting to ONNX.
+            task (`str`, defaults to `"default"`):
+                The task the model should be exported for.
+            use_past (`bool`, defaults to `False`):
+                Whether to export the model with past_key_values.
+
+        Returns:
+            `OnnxConfig`: An instance of the ONNX configuration object.
+        """
+        raise NotImplementedError(
+            f"{config.model_type} decoder export is not supported yet. ",
+            f"If you want to support {config.model_type} please propose a PR or open up an issue.",
+        )
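
For illustration, here is a minimal sketch (not part of this diff) of how a model-specific Seq2Seq config could override the two hooks that this change moves onto the past-key-values base class. It assumes the hooks now live on `OnnxSeq2SeqConfigWithPast` in `optimum/exporters/onnx/base.py` (as the `flatten_past_key_values` context around the added hunk suggests); `MySeq2SeqOnnxConfig`, `MyEncoderOnnxConfig`, and `MyDecoderOnnxConfig` are hypothetical names, not classes from this PR.

```python
# Hypothetical sketch: MyEncoderOnnxConfig / MyDecoderOnnxConfig stand in for
# OnnxConfig subclasses that describe the encoder-only and decoder-only graphs.
from transformers import PretrainedConfig

from optimum.exporters.onnx.base import OnnxConfig, OnnxSeq2SeqConfigWithPast

from my_model_configs import MyDecoderOnnxConfig, MyEncoderOnnxConfig  # assumed helpers, defined elsewhere


class MySeq2SeqOnnxConfig(OnnxSeq2SeqConfigWithPast):
    def get_encoder_onnx_config(self, config: PretrainedConfig) -> OnnxConfig:
        # Export the encoder on its own: return a config declaring only the
        # encoder's inputs/outputs.
        return MyEncoderOnnxConfig(config, task="default")

    def get_decoder_onnx_config(
        self, config: PretrainedConfig, task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        # Export the decoder separately, optionally with past_key_values so the
        # decoder graph can reuse the KV cache at generation time.
        return MyDecoderOnnxConfig(config, task=task, use_past=use_past)
```

With `--for-ort`, the export path can then call these hooks to produce separate encoder and decoder ONNX files for the `ORTModelXXX` classes in `optimum.onnxruntime` to consume, as described in the updated help text above.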