Add ort export in exporters for encoder-decoder models #497
@@ -22,7 +22,12 @@
 from ...utils import logging
 from ..tasks import TasksManager
 from .base import OnnxConfigWithPast
-from .convert import export, validate_model_outputs
+from .convert import (
+    export,
+    export_encoder_decoder_model,
+    validate_encoder_decoder_model_outputs,
+    validate_model_outputs,
+)


 logger = logging.get_logger()  # pylint: disable=invalid-name
@@ -64,6 +69,14 @@ def main():
         ),
     )
     parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
+    parser.add_argument(
+        "--for-ort",
+        action="store_true",
+        help=(
+            "This exports models ready to be run with optimum.onnxruntime ORTModelXXX. Useful for encoder-decoder"
+            " models for conditional generation. If enabled, the encoder and decoder of the model are exported separately."
+        ),
+    )
     parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.")

     # Retrieve CLI arguments
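For illustration, an invocation of the exporter with the new flag might look as follows; the model name, task value, and output path are placeholders, and the `python -m optimum.exporters.onnx` entry point is assumed from the file being modified:

```bash
# Hypothetical CLI usage; --for-ort triggers the separate encoder/decoder export.
python -m optimum.exporters.onnx \
    --model t5-small \
    --task seq2seq-lm \
    --for-ort \
    t5_onnx/model.onnx
```

With `--for-ort`, the diff below writes `encoder_model.onnx`, `decoder_model.onnx`, and (when the task uses past key values) `decoder_with_past_model.onnx` next to the given output path.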
@@ -115,12 +128,20 @@ def main():
             f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
         )

-    onnx_inputs, onnx_outputs = export(
-        model,
-        onnx_config,
-        args.opset,
-        args.output,
-    )
+    use_past = True if "-with-past" in task else False
+    if model.config.is_encoder_decoder and args.for_ort:
+        onnx_inputs, onnx_outputs = export_encoder_decoder_model(
+            model,
+            onnx_config,
+            args.opset,
+            task,
+            use_past,
+            args.output.parent.joinpath("encoder_model.onnx"),
+            args.output.parent.joinpath("decoder_model.onnx"),
+            args.output.parent.joinpath("decoder_with_past_model.onnx"),
+        )
+    else:
+        onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)

     # Saving the model config as this is needed sometimes.
     model.config.save_pretrained(args.output.parent)

Review thread on `use_past = True if "-with-past" in task else False`:

Reviewer: I think you do not need that since it is already in the onnx config, no?
Author: Yes, was probably thinking of how the […]. Currently, updated to use […].
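Once the encoder and decoder are exported separately, the intent stated in the `--for-ort` help text is that the files can be consumed by the `ORTModelXXX` classes in `optimum.onnxruntime`. A rough sketch, assuming `ORTModelForSeq2SeqLM` picks up the three ONNX files written above from the export directory (model name and paths are illustrative):

```python
# Sketch only: assumes ORTModelForSeq2SeqLM loads encoder_model.onnx,
# decoder_model.onnx and decoder_with_past_model.onnx from the directory
# produced by the export above.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = ORTModelForSeq2SeqLM.from_pretrained("t5_onnx/")

inputs = tokenizer("Translate English to French: Hello!", return_tensors="pt")
generated = model.generate(**inputs)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```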
@@ -144,11 +165,24 @@ def main():
         args.atol = args.atol[task.replace("-with-past", "")]

     try:
-        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+        if model.config.is_encoder_decoder and args.for_ort:
+            validate_encoder_decoder_model_outputs(
+                onnx_config,
+                model,
+                onnx_outputs,
+                args.atol,
+                task,
+                use_past,
+                args.output.parent.joinpath("encoder_model.onnx"),
+                args.output.parent.joinpath("decoder_model.onnx"),
+                args.output.parent.joinpath("decoder_with_past_model.onnx"),
+            )
+        else:
+            validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
     except ValueError:
-        logger.error(f"An error occured, but the model was saved at: {args.output.as_posix()}")
+        logger.error(f"An error occured, but the model was saved at: {args.output.parent.as_posix()}")
         return
-    logger.info(f"All good, model saved at: {args.output.as_posix()}")
+    logger.info(f"All good, model saved at: {args.output.parent.as_posix()}")


 if __name__ == "__main__":
@@ -303,6 +303,18 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) ->
         """
         return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))}

+    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
+        """
+        Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with
+        seq2seq models which have the encoder and decoder exported as separate ONNX files.
+
+        Args:
+            reference_model_inputs (`Mapping[str, Tensor]`):
+                Reference inputs for the model.
+
+        Returns:
+            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function.
+        """
+        return reference_model_inputs


 class OnnxConfigWithPast(OnnxConfig, ABC):
     PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True

Review thread on `generate_dummy_inputs_onnxruntime`:

Author: Discussion with lewtun regarding the use of the function: huggingface/transformers#19525 (comment)
Reviewer: It is needed to generate the inputs for the separate encoder and decoder models?
Author: It is needed only for validation using onnxruntime, since the onnx model and torch model will have different input signatures when using […].
Reviewer: What about calling it […]?
Author: Agree. Updated!
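Since the base implementation returns the reference inputs unchanged, a model-specific config is expected to override this hook whenever the standalone decoder's ONNX input names differ from the reference PyTorch model's. A sketch of what such an override could look like; the class and the exact renames below are hypothetical, not the PR's actual implementation:

```python
from typing import Any, Mapping

# Hypothetical decoder config, assumed to live alongside OnnxConfig in this module.
class HypotheticalDecoderOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Dynamic axes for the standalone decoder graph (illustrative).
        return {
            "input_ids": {0: "batch", 1: "decoder_sequence"},
            "encoder_hidden_states": {0: "batch", 1: "encoder_sequence"},
            "encoder_attention_mask": {0: "batch", 1: "encoder_sequence"},
        }

    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        remapped = dict(reference_model_inputs)
        # The standalone decoder graph consumes plain `input_ids`, while the
        # reference seq2seq model is fed `decoder_input_ids` (assumed renames).
        remapped["input_ids"] = remapped.pop("decoder_input_ids")
        remapped["encoder_attention_mask"] = remapped.pop("attention_mask")
        return remapped
```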
@@ -454,3 +466,43 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
         flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
         flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
         flattened_output[f"{name}.{idx}.encoder.value"] = t[3]
+
+    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
+        """
+        Returns ONNX encoder config for `Seq2Seq` models. Implement the method to export the encoder
+        of the model separately.
+
+        Args:
+            config (`PretrainedConfig`):
+                The encoder model's configuration to use when exporting to ONNX.
+
+        Returns:
+            `OnnxConfig`: An instance of the ONNX configuration object.
+        """
+        raise NotImplementedError(
+            f"{config.model_type} encoder export is not supported yet. "
+            f"If you want to support {config.model_type} please propose a PR or open up an issue."
+        )
+
+    def get_decoder_onnx_config(
+        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
+    ) -> OnnxConfig:
+        """
+        Returns ONNX decoder config for `Seq2Seq` models. Implement the method to export the decoder
+        of the model separately.
+
+        Args:
+            config (`PretrainedConfig`):
+                The decoder model's configuration to use when exporting to ONNX.
+            task (`str`, defaults to `"default"`):
+                The task the model should be exported for.
+            use_past (`bool`, defaults to `False`):
+                Whether to export the model with past_key_values.
+
+        Returns:
+            `OnnxConfig`: An instance of the ONNX configuration object.
+        """
+        raise NotImplementedError(
+            f"{config.model_type} decoder export is not supported yet. "
+            f"If you want to support {config.model_type} please propose a PR or open up an issue."
+        )
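To make the two hooks concrete, here is a rough sketch of how a model-specific seq2seq config might implement them. The names `MySeq2SeqOnnxConfig`, `Seq2SeqOnnxConfigBase`, `MyEncoderOnnxConfig`, and `MyDecoderOnnxConfig` are all hypothetical; the PR itself only defines the overridable hooks above:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import PretrainedConfig

# Hypothetical model-specific config. Assumes MyEncoderOnnxConfig and
# MyDecoderOnnxConfig are OnnxConfig subclasses defined elsewhere, and that
# Seq2SeqOnnxConfigBase is the class carrying the hooks introduced above.
class MySeq2SeqOnnxConfig(Seq2SeqOnnxConfigBase):
    def get_encoder_onnx_config(self, config: "PretrainedConfig") -> OnnxConfig:
        # Wrap the encoder sub-config in an encoder-only ONNX config.
        return MyEncoderOnnxConfig(config, task="default")

    def get_decoder_onnx_config(
        self, config: "PretrainedConfig", task: str = "default", use_past: bool = False
    ) -> OnnxConfig:
        # Choose a decoder config that optionally exposes past_key_values.
        return MyDecoderOnnxConfig(config, task=task, use_past=use_past)
```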
Review thread on the `--for-ort` help text:

Author: Could the name `for-ort` be misleading, as the models can have other tasks apart from conditional generation? I have mentioned "Useful for encoder-decoder models for conditional generation" in the help, but I am not sure if this would be enough. Probably updating `ORTModelXXX` -> `ORTModelForConditionalGeneration`?
Reviewer: I would just say to run with `optimum.onnxruntime`. I think `for-ort` is good enough, or at least I do not have a better naming in mind.