Skip to content

Commit

Permalink
Fix ONNX models failing in clusters like Databricks
Browse files Browse the repository at this point in the history
  • Loading branch information
maziyarpanahi committed Mar 2, 2024
1 parent 75d398e commit 8877454
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package com.johnsnowlabs.ml.onnx

import ai.onnxruntime.{OrtEnvironment, OrtLoggingLevel}
import ai.onnxruntime.OrtSession.SessionOptions
import com.johnsnowlabs.util.FileHelper
import org.apache.commons.io.FileUtils
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ object OnnxWrapper {
}
}

// TODO: the parts related to onnx_data should be refactored once we support addFile()
def read(
modelPath: String,
zipped: Boolean = true,
Expand Down Expand Up @@ -152,15 +153,17 @@ object OnnxWrapper {
session = _session
env = _env
} else {
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile))
val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, None)
session = _session
env = _env

}
// 4. Remove tmp folder
FileHelper.delete(tmpFolder)

val onnxWrapper = new OnnxWrapper(modelBytes, Option(onnxFile))
val onnxWrapper =
if (onnxDataFileExist) new OnnxWrapper(modelBytes, Option(onnxFile))
else new OnnxWrapper(modelBytes)
onnxWrapper.ortSession = session
onnxWrapper.ortEnv = env
onnxWrapper
Expand Down

0 comments on commit 8877454

Please sign in to comment.