Skip to content

Commit

Permalink
Fixing onnx saving path bug (#13959)
Browse files Browse the repository at this point in the history
* fixing onnx write issue on windows

* fixing indentation

* fixing formatting

* fixing formatting

* final formatting fix

* Fix onnx saving bug

---------

Co-authored-by: Devin Ha <t.ha@tu-berlin.de>
Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>
  • Loading branch information
3 people authored Sep 11, 2023
1 parent 1f61caf commit 04f2f7e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable {
.toString

// 2. Save onnx model
val onnxFile = Paths.get(tmpFolder, file).toString
FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel)
val fileName = Paths.get(file).getFileName.toString
val onnxFile = Paths
.get(tmpFolder, fileName)
.toString

FileUtils.writeByteArrayToFile(new File(onnxFile), onnxModel)
// 4. Zip folder
if (zip) ZipArchiveUtil.zip(tmpFolder, file)

Expand Down Expand Up @@ -163,5 +166,4 @@ object OnnxWrapper {
encoder: OnnxWrapper,
decoder: OnnxWrapper,
decoderWithPast: OnnxWrapper)

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class WhisperForCTCTest extends AnyFlatSpec with WhisperForCTCBehaviors {

// Needs to be added manually
lazy val modelTf: WhisperForCTC = WhisperForCTC
.pretrained("asr_whisper_tiny")
.pretrained("asr_whisper_tiny", "xx")
.setInputCols("audio_assembler")
.setOutputCol("document")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.{SlowTest, FastTest}
import com.johnsnowlabs.tags.{SlowTest}
import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec

Expand Down

0 comments on commit 04f2f7e

Please sign in to comment.