Skip to content

Commit

Permalink
SPARKNLP-846: BART: Added maxInputLength. (#13863)
Browse files Browse the repository at this point in the history
* Added maxInputLength. fixes #13829

* changed test types
  • Loading branch information
prabod authored Jul 1, 2023
1 parent d8174cc commit 7645ad4
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 7 deletions.
75 changes: 75 additions & 0 deletions python/test/annotator/seq2seq/bart_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,81 @@ def runTest(self):

results.select("documents.result", "answers.result").show(truncate=False)

@pytest.mark.slow
class BartTransformerMaxLengthTestSpec(unittest.TestCase):
def setUp(self):
self.spark = SparkContextForTest.spark

def runTest(self):
data = self.spark.createDataFrame([
[1, """
Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking
parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of
butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out
12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line.
Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl.
Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a
little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just
about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle.
Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to
each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg
(don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully
remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined
baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely
then eat with piccalilli, or your favourite pickle.
Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking
parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of
butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out
12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line.
Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl.
Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a
little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just
about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle.
Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to
each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg
(don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully
remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined
baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely
then eat with piccalilli, or your favourite pickle.
Heat oven to 200C/180C fan/gas 6. Line each hole of a 12-hole muffin tin with a thin strip of baking
parchment across the middle that’s long enough so the ends stick out a centimetre or two – use a dab of
butter to stick in place. Roll out two thirds of the pastry on a lightly floured surface and stamp out
12 x 10cm circles (you may need to re-roll trimmings). Press a circle into each hole to line.
Sprinkle 1 tsp of breadcrumbs into the base of each pie. Tip the rest of the crumbs into a mixing bowl.
Squeeze in the sausage meat, discarding the skins, along with the bacon, mace, pepper, sage and just a
little salt. Get your hands in and mash and squish everything together until the breadcrumbs have just
about disappeared. Divide mixture between the holes, packing in firmly and shaping to a dome in the middle.
Roll out the remaining pastry and stamp out 12 x 7cm circles. Brush with a little egg and add a top to
each pie, egg-side down to stick, carefully pressing pastry edges together to seal. Brush with more egg
(don’t throw away leftovers) and sprinkle with sesame seeds. Bake for 30 mins until golden then carefully
remove the pies from the tin, using the parchment ends to help you lift them out. Sit on a parchment lined
baking tray, brush all round the sides with more egg and put back in the oven for 8 mins. Cool completely
then eat with piccalilli, or your favourite pickle.
""".strip().replace("\n", " ")]]).toDF("id", "text")

document_assembler = DocumentAssembler() \
.setInputCol("text") \
.setOutputCol("documents")

bart = BartTransformer.pretrained("distilbart_xsum_12_6") \
.setTask("summarize:") \
.setMaxOutputLength(30) \
.setInputCols(["documents"]) \
.setOutputCol("summaries")

pipeline = Pipeline().setStages([document_assembler, bart])
results = pipeline.fit(data).transform(data)

results.select("summaries.result").show(truncate=False)


@pytest.mark.slow
class BartTransformerSummaryTestSpec(unittest.TestCase):
Expand Down
15 changes: 10 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ private[johnsnowlabs] class Bart(
task: String,
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int): Seq[Annotation] = {
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch, task)
Expand All @@ -117,7 +118,8 @@ private[johnsnowlabs] class Bart(
noRepeatNgramSize,
randomSeed,
ignoreTokenIds,
beamSize)
beamSize,
maxInputLength)

decode(spIds)

Expand Down Expand Up @@ -176,10 +178,12 @@ private[johnsnowlabs] class Bart(
noRepeatNgramSize: Int,
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int): Array[Array[Int]] = {
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {

val ignoreTokenIdsInt = ignoreTokenIds
val expandedEncoderInputIdsVals = batch.flatMap(x => List.fill(beamSize)(x))
val expandedEncoderInputIdsVals =
batch.flatMap(x => List.fill(beamSize)(x.take(maxInputLength)))
val sequencesLength = expandedEncoderInputIdsVals.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max // - curLen

Expand Down Expand Up @@ -492,6 +496,7 @@ private[johnsnowlabs] class Bart(
noRepeatNgramSize = 0,
randomSeed = Option(0),
ignoreTokenIds = Array(0),
beamSize = 1)
beamSize = 1,
maxInputLength = 512)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ class BartTransformer(override val uid: String)
this
}

/** max length of the input sequence (Default: `0`)
*
* @group param
*/
val maxInputLength =
new IntParam(this, "maxInputLength", "Maximum length of the input sequence")

def setMaxInputLength(value: Int): BartTransformer.this.type = {
set(maxInputLength, value)
this
}

/** @group getParam */
def getMinOutputLength: Int = $(this.minOutputLength)

Expand Down Expand Up @@ -477,6 +489,7 @@ class BartTransformer(override val uid: String)
ignoreTokenIds -> Array(),
batchSize -> 1,
beamSize -> 4,
maxInputLength -> 512,
useCache -> true)

override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
Expand All @@ -503,7 +516,8 @@ class BartTransformer(override val uid: String)
task = $(task),
randomSeed = this.randomSeed,
ignoreTokenIds = $(ignoreTokenIds),
beamSize = $(beamSize))
beamSize = $(beamSize),
maxInputLength = $(maxInputLength))
} else {
Seq()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.johnsnowlabs.nlp.annotators.seq2seq

import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.SlowTest
import com.johnsnowlabs.tags.{SlowTest, FastTest}
import com.johnsnowlabs.util.Benchmark
import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec
Expand Down Expand Up @@ -56,6 +56,77 @@ class BartTestSpec extends AnyFlatSpec {
.show(truncate = false)

}
"distilbart_xsum_12_6" should "handle text inputs longer than 512 and not crash" taggedAs SlowTest in {
// text longer than 512
val testData = ResourceHelper.spark
.createDataFrame(
Seq(
(
1,
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." +
"PG&E stated it scheduled the blackouts in response to forecasts for high winds " +
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " +
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.")))
.toDF("id", "text")

val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("documents")

val bart = BartTransformer
.pretrained("distilbart_xsum_12_6")
.setTask("summarize:")
.setInputCols(Array("documents"))
.setDoSample(true)
.setMaxOutputLength(30)
.setOutputCol("generation")

new Pipeline()
.setStages(Array(documentAssembler, bart))
.fit(testData)
.transform(testData)
.select("generation.result")
.show(truncate = false)
}

"bart-large-cnn" should "run SparkNLP pipeline with maxLength=130 and doSample=true" taggedAs SlowTest in {
val testData = ResourceHelper.spark
.createDataFrame(
Expand Down

0 comments on commit 7645ad4

Please sign in to comment.