Skip to content

Commit

Permalink
fixed serialization bug
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod committed Jan 31, 2024
1 parent 86a079c commit 73310e9
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -446,12 +446,12 @@ class M2M100Transformer(override val uid: String)
writeOnnxModels(
path,
spark,
Seq((wrappers.encoder, "encoder_model")),
Seq((wrappers.encoder, "encoder_model.onnx")),
M2M100Transformer.suffix)
writeOnnxModels(
path,
spark,
Seq((wrappers.decoder, "decoder_model")),
Seq((wrappers.decoder, "decoder_model.onnx")),
M2M100Transformer.suffix)
writeSentencePieceModel(
path,
Expand Down Expand Up @@ -490,12 +490,14 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM
def readModel(instance: M2M100Transformer, path: String, spark: SparkSession): Unit = {
instance.getEngine match {
case ONNX.name =>
val decoderWrappers = readOnnxModels(path, spark, Seq("decoder_model"), suffix)
val encoderWrappers = readOnnxModels(path, spark, Seq("encoder_model"), suffix)
val decoderWrappers =
readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix)
val encoderWrappers =
readOnnxModels(path, spark, Seq("encoder_model.onnx"), suffix)
val onnxWrappers =
EncoderDecoderWithoutPastWrappers(
decoder = decoderWrappers("decoder_model"),
encoder = encoderWrappers("encoder_model"))
decoder = decoderWrappers("decoder_model.onnx"),
encoder = encoderWrappers("encoder_model.onnx"))
val spp = readSentencePieceModel(path, spark, "_m2m100_spp", sppFile)
instance.setModelIfNotSet(spark, onnxWrappers, spp)
case _ =>
Expand Down

0 comments on commit 73310e9

Please sign in to comment.