Skip to content

Commit

Permalink
SparkNLP 933: Introducing M2M100 : multilingual translation model (#1…
Browse files Browse the repository at this point in the history
…4155)

* introducing LLAMA2

* Added option to read model from model path to onnx wrapper

* Added option to read model from model path to onnx wrapper

* updated text description

* LLAMA2 python API

* added method to save onnx_data

* added position ids

* - updated Generate.scala to accept onnx tensors
- added beam search support for LLAMA2

* updated max input length

* updated python default params
changed test to slow test

* fixed serialization bug

* Added Scala code for M2M100

* Documentation for scala code

* Python API for M2M100

* added more tests for scala

* added tests for python

* added pretrained

* rewording

* fixed serialization bug

* fixed serialization bug

---------

Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>
  • Loading branch information
prabod and maziyarpanahi authored Feb 6, 2024
1 parent 0e01a2c commit a2cb06b
Show file tree
Hide file tree
Showing 8 changed files with 1,688 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/sparknlp/annotator/seq2seq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from sparknlp.annotator.seq2seq.t5_transformer import *
from sparknlp.annotator.seq2seq.bart_transformer import *
from sparknlp.annotator.seq2seq.llama2_transformer import *
from sparknlp.annotator.seq2seq.m2m100_transformer import *
392 changes: 392 additions & 0 deletions python/sparknlp/annotator/seq2seq/m2m100_transformer.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ def __init__(self, path, jspark):
jspark)


class _M2M100Loader(ExtendedJavaWrapper):
def __init__(self, path, jspark):
super(_M2M100Loader, self).__init__(
"com.johnsnowlabs.nlp.annotators.seq2seq.M2M100Transformer.loadSavedModel", path, jspark)


class _MarianLoader(ExtendedJavaWrapper):
def __init__(self, path, jspark):
super(_MarianLoader, self).__init__(
Expand Down
46 changes: 46 additions & 0 deletions python/test/annotator/seq2seq/m2m100_transformer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import pytest

from sparknlp.annotator import *
from sparknlp.base import *
from test.util import SparkContextForTest


@pytest.mark.slow
class M2M100TransformerTextTranslationTestSpec(unittest.TestCase):
def setUp(self):
self.spark = SparkContextForTest.spark

def runTest(self):
data = self.spark.createDataFrame([
[1, """生活就像一盒巧克力。""".strip().replace("\n", " ")]]).toDF("id", "text")

document_assembler = DocumentAssembler() \
.setInputCol("text") \
.setOutputCol("documents")

m2m100 = M2M100Transformer.pretrained() \
.setInputCols(["documents"]) \
.setMaxOutputLength(50) \
.setOutputCol("generation") \
.setSrcLang("en") \
.setTgtLang("fr")

pipeline = Pipeline().setStages([document_assembler, m2m100])
results = pipeline.fit(data).transform(data)

results.select("generation.result").show(truncate=False)
Loading

0 comments on commit a2cb06b

Please sign in to comment.