Skip to content

Commit

Permalink
SPARKNLP-985: Add flexible naming for onnx_data
Browse files Browse the repository at this point in the history
Some annotators might have different naming schemes
for their files. Added a parameter to control this.
  • Loading branch information
DevinTDHa committed Feb 7, 2024
1 parent a2cb06b commit 0787993
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ trait ReadOnnxModel {
modelNames: Seq[String],
suffix: String,
zipped: Boolean = true,
useBundle: Boolean = false): Map[String, OnnxWrapper] = {
useBundle: Boolean = false,
dataFileSuffix: String = "_data"): Map[String, OnnxWrapper] = {

val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
Expand All @@ -137,10 +138,10 @@ trait ReadOnnxModel {
val fsPath = new Path(path, localModelFile).toString

// 3. Copy onnx_data file if exists
val onnxDataFile = Paths.get(fsPath + "_data").toFile
val onnxDataFile = Paths.get(fsPath + dataFileSuffix).toFile

if (onnxDataFile.exists()) {
fs.copyToLocalFile(new Path(path, localModelFile + "_data"), new Path(tmpFolder))
fs.copyToLocalFile(new Path(path, localModelFile + dataFileSuffix), new Path(tmpFolder))
}

// 4. Read ONNX state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel {
path,
spark,
Seq("encoder_model", "decoder_model", "decoder_with_past_model"),
WhisperForCTC.suffix)
WhisperForCTC.suffix,
dataFileSuffix = ".onnx_data")

val onnxWrappers = EncoderDecoderWrappers(
wrappers("encoder_model"),
Expand Down

0 comments on commit 0787993

Please sign in to comment.