SparkNLP 933: Introducing M2M100 : multilingual translation model (#14155)

* introducing LLAMA2
* Added option to read model from model path to onnx wrapper
* Added option to read model from model path to onnx wrapper
* updated text description
* LLAMA2 python API
* added method to save onnx_data
* added position ids
* - updated Generate.scala to accept onnx tensors
  - added beam search support for LLAMA2
* updated max input length
* updated python default params changed test to slow test
* fixed serialization bug
* Added Scala code for M2M100
* Documentation for scala code
* Python API for M2M100
* added more tests for scala
* added tests for python
* added pretrained
* rewording
* fixed serialization bug
* fixed serialization bug

---------

Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>
1 parent 0e01a2c, commit a2cb06b
Showing 8 changed files with 1,688 additions and 0 deletions.
392 changes: 392 additions & 0 deletions
python/sparknlp/annotator/seq2seq/m2m100_transformer.py
Large diffs are not rendered by default.
@@ -0,0 +1,46 @@
+# Copyright 2017-2022 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import pytest
+
+from sparknlp.annotator import *
+from sparknlp.base import *
+from test.util import SparkContextForTest
+
+
+@pytest.mark.slow
+class M2M100TransformerTextTranslationTestSpec(unittest.TestCase):
+    def setUp(self):
+        self.spark = SparkContextForTest.spark
+
+    def runTest(self):
+        data = self.spark.createDataFrame([
+            [1, """生活就像一盒巧克力。""".strip().replace("\n", " ")]]).toDF("id", "text")
+
+        document_assembler = DocumentAssembler() \
+            .setInputCol("text") \
+            .setOutputCol("documents")
+        # The input sentence is Chinese, so the source language is "zh", not "en".
+        m2m100 = M2M100Transformer.pretrained() \
+            .setInputCols(["documents"]) \
+            .setMaxOutputLength(50) \
+            .setOutputCol("generation") \
+            .setSrcLang("zh") \
+            .setTgtLang("fr")
+
+        pipeline = Pipeline().setStages([document_assembler, m2m100])
+        results = pipeline.fit(data).transform(data)
+
+        results.select("generation.result").show(truncate=False)
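
The test above only prints the `generation.result` column and asserts nothing, so it passes whenever the pipeline runs without raising. Below is a minimal sketch of a check that could follow the `show` call. It assumes each collected row exposes `result` as a list of strings, one per input document. `check_translations` is a hypothetical helper name, not part of Spark NLP, and the sample input holds placeholder strings rather than real model output.

```python
from typing import Iterable, List, Sequence


def check_translations(rows: Iterable[Sequence[str]], expected_rows: int) -> List[str]:
    """Flatten collected `generation.result` values and verify they are usable.

    `rows` is assumed to hold one list of strings per input row, for example
    `[r.result for r in results.select("generation.result").collect()]`
    if the column behaves as assumed above.
    """
    rows = list(rows)
    if len(rows) != expected_rows:
        raise AssertionError(f"expected {expected_rows} rows, got {len(rows)}")

    flat: List[str] = []
    for index, row in enumerate(rows):
        # An empty list means the annotator produced no output for this row.
        if not row:
            raise AssertionError(f"row {index} has no generated text")
        for text in row:
            # Whitespace-only output is treated as a failed generation.
            if not text.strip():
                raise AssertionError(f"row {index} contains an empty translation")
            flat.append(text)
    return flat


if __name__ == "__main__":
    # Placeholder strings standing in for collected results; not model output.
    sample = [["<translation of row 1>"]]
    print(check_translations(sample, expected_rows=1))
```

Inside `runTest`, the helper would be fed the collected column in place of `sample`. It deliberately checks only shape and non-emptiness, since the exact translated string depends on the pretrained model and generation parameters.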