diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala index d658658207af40..3a2dc1a9b22740 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/M2M100Transformer.scala @@ -446,12 +446,12 @@ class M2M100Transformer(override val uid: String) writeOnnxModels( path, spark, - Seq((wrappers.encoder, "encoder_model")), + Seq((wrappers.encoder, "encoder_model.onnx")), M2M100Transformer.suffix) writeOnnxModels( path, spark, - Seq((wrappers.decoder, "decoder_model")), + Seq((wrappers.decoder, "decoder_model.onnx")), M2M100Transformer.suffix) writeSentencePieceModel( path, @@ -490,12 +490,14 @@ trait ReadM2M100TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM def readModel(instance: M2M100Transformer, path: String, spark: SparkSession): Unit = { instance.getEngine match { case ONNX.name => - val decoderWrappers = readOnnxModels(path, spark, Seq("decoder_model"), suffix) - val encoderWrappers = readOnnxModels(path, spark, Seq("encoder_model"), suffix) + val decoderWrappers = + readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix) + val encoderWrappers = + readOnnxModels(path, spark, Seq("encoder_model.onnx"), suffix) val onnxWrappers = EncoderDecoderWithoutPastWrappers( - decoder = decoderWrappers("decoder_model"), - encoder = encoderWrappers("encoder_model")) + decoder = decoderWrappers("decoder_model.onnx"), + encoder = encoderWrappers("encoder_model.onnx")) val spp = readSentencePieceModel(path, spark, "_m2m100_spp", sppFile) instance.setModelIfNotSet(spark, onnxWrappers, spp) case _ =>