diff --git a/python/test/annotator/seq2seq/bart_transformer_test.py b/python/test/annotator/seq2seq/bart_transformer_test.py index d489d488ae8ae8..99f42d52bd9895 100644 --- a/python/test/annotator/seq2seq/bart_transformer_test.py +++ b/python/test/annotator/seq2seq/bart_transformer_test.py @@ -46,6 +46,81 @@ def runTest(self): results.select("documents.result", "answers.result").show(truncate=False) +@pytest.mark.slow +class BartTransformerMaxLengthTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """ + Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking + parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of + butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out + 12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line. + + Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl. + Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a + little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just + about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle. + + Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to + each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg + (don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully + remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined + baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely + then eat with piccalilli, or your favourite pickle. + + Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking + parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of + butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out + 12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line. + + Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl. + Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a + little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just + about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle. + + Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to + each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg + (don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully + remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined + baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely + then eat with piccalilli, or your favourite pickle. + + Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking + parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of + butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out + 12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line. + + Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl. + Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a + little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just + about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle. + + Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to + each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg + (don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully + remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined + baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely + then eat with piccalilli, or your favourite pickle. + """.strip().replace("\n", " ")]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + bart = BartTransformer.pretrained("distilbart_xsum_12_6") \ + .setTask("summarize:") \ + .setMaxOutputLength(30) \ + .setInputCols(["documents"]) \ + .setOutputCol("summaries") + + pipeline = Pipeline().setStages([document_assembler, bart]) + results = pipeline.fit(data).transform(data) + + results.select("summaries.result").show(truncate=False) + @pytest.mark.slow class BartTransformerSummaryTestSpec(unittest.TestCase): diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala index cb46d348450483..3edba90c0f2bd1 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala @@ -101,7 +101,8 @@ private[johnsnowlabs] class Bart( task: String, randomSeed: Option[Long] = None, ignoreTokenIds: Array[Int] = Array(), - beamSize: Int): Seq[Annotation] = { + beamSize: Int, + maxInputLength: Int): Seq[Annotation] = { val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => val batchSP = encode(batch, task) @@ -117,7 +118,8 @@ private[johnsnowlabs] class Bart( noRepeatNgramSize, randomSeed, ignoreTokenIds, - beamSize) + beamSize, + maxInputLength) decode(spIds) @@ -176,10 +178,12 @@ private[johnsnowlabs] class Bart( noRepeatNgramSize: Int, randomSeed: Option[Long], ignoreTokenIds: Array[Int] = Array(), - beamSize: Int): Array[Array[Int]] = { + beamSize: Int, + maxInputLength: Int): Array[Array[Int]] = { val ignoreTokenIdsInt = ignoreTokenIds - val expandedEncoderInputIdsVals = batch.flatMap(x => List.fill(beamSize)(x)) + val expandedEncoderInputIdsVals = + batch.flatMap(x => List.fill(beamSize)(x.take(maxInputLength))) val sequencesLength = expandedEncoderInputIdsVals.map(x => x.length).toArray val maxSentenceLength = sequencesLength.max // - curLen @@ -492,6 +496,7 @@ private[johnsnowlabs] class Bart( noRepeatNgramSize = 0, randomSeed = Option(0), ignoreTokenIds = Array(0), - beamSize = 1) + beamSize = 1, + maxInputLength = 512) } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala index aa72c0e4738fbe..5e8a97a693a2df 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala @@ -199,6 +199,18 @@ class BartTransformer(override val uid: String) this } + /** max length of the input sequence (Default: `0`) + * + * @group param + */ + val maxInputLength = + new IntParam(this, "maxInputLength", "Maximum length of the input sequence") + + def setMaxInputLength(value: Int): BartTransformer.this.type = { + set(maxInputLength, value) + this + } + /** @group getParam */ def getMinOutputLength: Int = $(this.minOutputLength) @@ -477,6 +489,7 @@ class BartTransformer(override val uid: String) ignoreTokenIds -> Array(), batchSize -> 1, beamSize -> 4, + maxInputLength -> 512, useCache -> true) override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { @@ -503,7 +516,8 @@ class BartTransformer(override val uid: String) task = $(task), randomSeed = this.randomSeed, ignoreTokenIds = $(ignoreTokenIds), - beamSize = $(beamSize)) + beamSize = $(beamSize), + maxInputLength = $(maxInputLength)) } else { Seq() } diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTestSpec.scala index 11b5e8d768aa3d..553567d53a4df3 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTestSpec.scala @@ -18,7 +18,7 @@ package com.johnsnowlabs.nlp.annotators.seq2seq import com.johnsnowlabs.nlp.base.DocumentAssembler import com.johnsnowlabs.nlp.util.io.ResourceHelper -import com.johnsnowlabs.tags.SlowTest +import com.johnsnowlabs.tags.{SlowTest, FastTest} import com.johnsnowlabs.util.Benchmark import org.apache.spark.ml.Pipeline import org.scalatest.flatspec.AnyFlatSpec @@ -56,6 +56,77 @@ class BartTestSpec extends AnyFlatSpec { .show(truncate = false) } + "distilbart_xsum_12_6" should "handle text inputs longer than 512 and not crash" taggedAs SlowTest in { + // text longer than 512 + val testData = ResourceHelper.spark + .createDataFrame( + Seq( + ( + 1, + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." + + "PG&E stated it scheduled the blackouts in response to forecasts for high winds " + + "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " + + "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."))) + .toDF("id", "text") + + val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("documents") + + val bart = BartTransformer + .pretrained("distilbart_xsum_12_6") + .setTask("summarize:") + .setInputCols(Array("documents")) + .setDoSample(true) + .setMaxOutputLength(30) + .setOutputCol("generation") + + new Pipeline() + .setStages(Array(documentAssembler, bart)) + .fit(testData) + .transform(testData) + .select("generation.result") + .show(truncate = false) + } + "bart-large-cnn" should "run SparkNLP pipeline with maxLength=130 and doSample=true" taggedAs SlowTest in { val testData = ResourceHelper.spark .createDataFrame(