diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index a0b152ac333e5b..cf5802b3912448 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -33,7 +33,8 @@ trait WriteOnnxModel { path: String, spark: SparkSession, onnxWrappersWithNames: Seq[(OnnxWrapper, String)], - suffix: String): Unit = { + suffix: String, + dataFileSuffix: String = "_data"): Unit = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -55,9 +56,9 @@ trait WriteOnnxModel { // 4. check if there is a onnx_data file - val onnxDataFile = Paths.get(onnxWrapper.onnxModelPath.get + "_data").toFile - if (onnxDataFile.exists()) { - fs.copyFromLocalFile(new Path(onnxDataFile.getAbsolutePath), new Path(path)) + val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix) + if (fs.exists(onnxDataFile)) { + fs.copyFromLocalFile(onnxDataFile, new Path(path)) } } @@ -85,7 +86,8 @@ trait ReadOnnxModel { suffix: String, zipped: Boolean = true, useBundle: Boolean = false, - sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + sessionOptions: Option[SessionOptions] = None, + dataFileSuffix: String = "_data"): OnnxWrapper = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -101,10 +103,18 @@ trait ReadOnnxModel { val localPath = new Path(tmpFolder, onnxFile).toString - // 3. Read ONNX state + val fsPath = new Path(path, onnxFile) + + // 3. Copy onnx_data file if exists + val onnxDataFile = new Path(fsPath + dataFileSuffix) + + if (fs.exists(onnxDataFile)) { + fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder)) + } + // 4. Read ONNX state val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle) - // 4. Remove tmp folder + // 5. Remove tmp folder FileHelper.delete(tmpFolder) onnxWrapper @@ -138,10 +148,10 @@ trait ReadOnnxModel { val fsPath = new Path(path, localModelFile).toString // 3. Copy onnx_data file if exists - val onnxDataFile = Paths.get(fsPath + dataFileSuffix).toFile + val onnxDataFile = new Path(fsPath + dataFileSuffix) - if (onnxDataFile.exists()) { - fs.copyToLocalFile(new Path(path, localModelFile + dataFileSuffix), new Path(tmpFolder)) + if (fs.exists(onnxDataFile)) { + fs.copyToLocalFile(onnxDataFile, new Path(tmpFolder)) } // 4. Read ONNX state diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 1396b2897f0f07..fb53c35530ec23 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -105,7 +105,8 @@ object OnnxWrapper { modelPath: String, zipped: Boolean = true, useBundle: Boolean = false, - modelName: String = "model"): OnnxWrapper = { + modelName: String = "model", + dataFileSuffix: String = "_data"): OnnxWrapper = { // 1. Create tmp folder val tmpFolder = Files @@ -132,13 +133,13 @@ object OnnxWrapper { val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath val onnxDataFileExist: Boolean = { - onnxDataFile = Paths.get(parentDir, s"${modelName.replace(".onnx", "")}.onnx_data").toFile + onnxDataFile = Paths.get(parentDir, modelName + dataFileSuffix).toFile onnxDataFile.exists() } if (onnxDataFileExist) { val onnxDataFileTmp = - Paths.get(tmpFolder, s"${modelName.replace(".onnx", "")}.onnx_data").toFile + Paths.get(tmpFolder, modelName + dataFileSuffix).toFile FileUtils.copyFile(onnxDataFile, onnxDataFileTmp) }