From 70f80824601bff8ea8faaf76aa41bd503e4e2d02 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 18 Jan 2024 16:53:44 -0500 Subject: [PATCH] Remove ort from transformers dependency (#1976) raise apt install message when ort is used --- setup.py | 1 - src/sparseml/export/validators.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 59d7da4a546..672bfcfeca4 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,6 @@ "scikit-learn", "seqeval", "einops", - "onnxruntime>=1.0.0", "accelerate>=0.20.3", ] _yolov5_deps = _pytorch_vision_deps + [ diff --git a/src/sparseml/export/validators.py b/src/sparseml/export/validators.py index 74b2bf5c2a5..183eabc9667 100644 --- a/src/sparseml/export/validators.py +++ b/src/sparseml/export/validators.py @@ -133,7 +133,13 @@ def validate_correctness( :param validation_function: The function that will be used to validate the outputs. :return: True if the validation passes, False otherwise. """ - import onnxruntime as ort + try: + import onnxruntime as ort + except ImportError as err: + raise ImportError( + "The onnxruntime package is required for the correctness validation. " + "Please install it using 'pip install sparseml[onnxruntime]'." + ) from err sample_inputs_path = os.path.join(target_path, InputsNames.basename.value) sample_outputs_path = os.path.join(target_path, OutputsNames.basename.value)