Skip to content

Commit

Permalink
SPARKNLP-985: Make Whisper compatible with onnx_data
Browse files: browse the repository at this point in the history
  • Loading branch information
DevinTDHa committed Feb 6, 2024
1 parent a2cb06b commit 6bab822
Showing 1 changed file with 10 additions and 7 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -360,9 +360,9 @@ class WhisperForCTC(override val uid: String)
path,
spark,
Seq(
(wrappers.encoder, "encoder_model"),
(wrappers.decoder, "decoder_model"),
(wrappers.decoderWithPast, "decoder_with_past_model")),
(wrappers.encoder, "encoder_model.onnx"),
(wrappers.decoder, "decoder_model.onnx"),
(wrappers.decoderWithPast, "decoder_with_past_model.onnx")),
WhisperForCTC.suffix)
}

Expand Down Expand Up @@ -425,6 +425,9 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel {
override val tfFile: String = "whisper_ctc_tensorflow"
override val onnxFile: String = "whisper_ctc_onnx"
val suffix: String = "_whisper_ctc"
val encoderModel: String = "encoder_model.onnx"
val decoderModel: String = "decoder_model.onnx"
val decoderWithPastModel: String = "decoder_with_past_model.onnx"

private def checkVersion(spark: SparkSession): Unit = {
val version = Version.parse(spark.version).toFloat
Expand All @@ -447,13 +450,13 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel {
readOnnxModels(
path,
spark,
Seq("encoder_model", "decoder_model", "decoder_with_past_model"),
Seq(encoderModel, decoderModel, decoderWithPastModel),
WhisperForCTC.suffix)

val onnxWrappers = EncoderDecoderWrappers(
wrappers("encoder_model"),
decoder = wrappers("decoder_model"),
decoderWithPast = wrappers("decoder_with_past_model"))
wrappers(encoderModel),
decoder = wrappers(decoderModel),
decoderWithPast = wrappers(decoderWithPastModel))

instance.setModelIfNotSet(spark, None, Some(onnxWrappers))
case _ =>
Expand Down

0 comments on commit 6bab822

Please sign in to comment.