Skip to content

Commit

Permalink
used filesystem to check for the onnx_data file (#14169)
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod authored Feb 11, 2024
1 parent 6010244 commit 0e9b54d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
30 changes: 20 additions & 10 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ trait WriteOnnxModel {
path: String,
spark: SparkSession,
onnxWrappersWithNames: Seq[(OnnxWrapper, String)],
suffix: String): Unit = {
suffix: String,
dataFileSuffix: String = "_data"): Unit = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
Expand All @@ -55,9 +56,9 @@ trait WriteOnnxModel {

// 4. check if there is a onnx_data file

val onnxDataFile = Paths.get(onnxWrapper.onnxModelPath.get + "_data").toFile
if (onnxDataFile.exists()) {
fs.copyFromLocalFile(new Path(onnxDataFile.getAbsolutePath), new Path(path))
val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix)
if (fs.exists(onnxDataFile)) {
fs.copyFromLocalFile(onnxDataFile, new Path(path))
}
}

Expand Down Expand Up @@ -85,7 +86,8 @@ trait ReadOnnxModel {
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false,
sessionOptions: Option[SessionOptions] = None): OnnxWrapper = {
sessionOptions: Option[SessionOptions] = None,
dataFileSuffix: String = "_data"): OnnxWrapper = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
Expand All @@ -101,10 +103,18 @@ trait ReadOnnxModel {

val localPath = new Path(tmpFolder, onnxFile).toString

// 3. Read ONNX state
val fsPath = new Path(path, onnxFile)

// 3. Copy onnx_data file if exists
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}
// 4. Read ONNX state
val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle)

// 4. Remove tmp folder
// 5. Remove tmp folder
FileHelper.delete(tmpFolder)

onnxWrapper
Expand Down Expand Up @@ -138,10 +148,10 @@ trait ReadOnnxModel {
val fsPath = new Path(path, localModelFile).toString

// 3. Copy onnx_data file if exists
val onnxDataFile = Paths.get(fsPath + dataFileSuffix).toFile
val onnxDataFile = new Path(fsPath + dataFileSuffix)

if (onnxDataFile.exists()) {
fs.copyToLocalFile(new Path(path, localModelFile + dataFileSuffix), new Path(tmpFolder))
if (fs.exists(onnxDataFile)) {
fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder))
}

// 4. Read ONNX state
Expand Down
7 changes: 4 additions & 3 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ object OnnxWrapper {
modelPath: String,
zipped: Boolean = true,
useBundle: Boolean = false,
modelName: String = "model"): OnnxWrapper = {
modelName: String = "model",
dataFileSuffix: String = "_data"): OnnxWrapper = {

// 1. Create tmp folder
val tmpFolder = Files
Expand All @@ -132,13 +133,13 @@ object OnnxWrapper {
val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath

val onnxDataFileExist: Boolean = {
onnxDataFile = Paths.get(parentDir, s"${modelName.replace(".onnx", "")}.onnx_data").toFile
onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile
onnxDataFile.exists()
}

if (onnxDataFileExist) {
val onnxDataFileTmp =
Paths.get(tmpFolder, s"${modelName.replace(".onnx", "")}.onnx_data").toFile
Paths.get(tmpFolder, modelName + dataFileSuffix).toFile
FileUtils.copyFile(onnxDataFile, onnxDataFileTmp)
}

Expand Down

0 comments on commit 0e9b54d

Please sign in to comment.