From 0787993e550eca86817404b9411f7057f1285a47 Mon Sep 17 00:00:00 2001 From: Devin Ha Date: Wed, 7 Feb 2024 09:01:19 +0100 Subject: [PATCH] SPARKNLP-985: Add flexible naming for onnx_data Some annotators might have different naming schemes for their files. Added a parameter to control this. --- .../com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala | 7 ++++--- .../johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index b482ed733b54a0..a0b152ac333e5b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -116,7 +116,8 @@ trait ReadOnnxModel { modelNames: Seq[String], suffix: String, zipped: Boolean = true, - useBundle: Boolean = false): Map[String, OnnxWrapper] = { + useBundle: Boolean = false, + dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -137,10 +138,10 @@ trait ReadOnnxModel { val fsPath = new Path(path, localModelFile).toString // 3. Copy onnx_data file if exists - val onnxDataFile = Paths.get(fsPath + "_data").toFile + val onnxDataFile = Paths.get(fsPath + dataFileSuffix).toFile if (onnxDataFile.exists()) { - fs.copyToLocalFile(new Path(path, localModelFile + "_data"), new Path(tmpFolder)) + fs.copyToLocalFile(new Path(path, localModelFile + dataFileSuffix), new Path(tmpFolder)) } // 4. Read ONNX state diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala index c3ad11638fd3e0..5bb2741b238d6f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala @@ -448,7 +448,8 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel { path, spark, Seq("encoder_model", "decoder_model", "decoder_with_past_model"), - WhisperForCTC.suffix) + WhisperForCTC.suffix, + dataFileSuffix = ".onnx_data") val onnxWrappers = EncoderDecoderWrappers( wrappers("encoder_model"),