diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index cf5802b3912448..2735626930016d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -16,7 +16,6 @@ package com.johnsnowlabs.ml.onnx -import ai.onnxruntime.{OrtEnvironment, OrtLoggingLevel} import ai.onnxruntime.OrtSession.SessionOptions import com.johnsnowlabs.util.FileHelper import org.apache.commons.io.FileUtils diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index fb53c35530ec23..5478a52282990d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -101,6 +101,7 @@ object OnnxWrapper { } } + // TODO: the parts related to onnx_data should be refactored once we support addFile() def read( modelPath: String, zipped: Boolean = true, @@ -152,7 +153,7 @@ object OnnxWrapper { session = _session env = _env } else { - val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile)) + val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None) session = _session env = _env @@ -160,7 +161,9 @@ object OnnxWrapper { // 4. Remove tmp folder FileHelper.delete(tmpFolder) - val onnxWrapper = new OnnxWrapper(modelBytes, Option(onnxFile)) + val onnxWrapper = + if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile)) + else new OnnxWrapper(modelBytes) onnxWrapper.ortSession = session onnxWrapper.ortEnv = env onnxWrapper