Skip to content

Commit

Permalink
Bug fixed #13898
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod committed Aug 1, 2023
1 parent 2b2f93c commit 3fe2379
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 25 deletions.
6 changes: 5 additions & 1 deletion src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private[johnsnowlabs] class Bart(
}

var sentBegin, nextSentEnd = 0
batchDecoder.zip(sentences).map { case (content, sent) =>
val annotations = batchDecoder.zip(sentences).map { case (content, sent) =>
nextSentEnd += content.length - 1
val annots = new Annotation(
annotatorType = AnnotatorType.DOCUMENT,
Expand All @@ -137,6 +137,10 @@ private[johnsnowlabs] class Bart(
sentBegin += nextSentEnd + 1
annots
}
tensorDecoder = new TensorResources()
nextStateTensor1 = None
nextStateTensor2 = None
annotations
}

/** @param batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,37 +367,45 @@ trait Generate {
def multinomialSampling(logitValues: Array[Float], k: Int, seed: Option[Long]): Array[Int] = {
val (distFiltered, indices) =
logitValues.zipWithIndex.filter { case (elem, index) => !elem.isInfinite }.sorted.unzip
if (!distFiltered.isEmpty) {

val maxLogit = distFiltered.max
val expLogitValues = distFiltered.map(logit => math.exp(logit - maxLogit))
val sumExpLogitValues = expLogitValues.sum
val probabilities = expLogitValues.map(_ / sumExpLogitValues)
val maxLogit = distFiltered.max
val expLogitValues = distFiltered.map(logit => math.exp(logit - maxLogit))
val sumExpLogitValues = expLogitValues.sum
val probabilities = expLogitValues.map(_ / sumExpLogitValues)

val selectedIndices = new Array[Int](k)
var seededRandom = new scala.util.Random()
if (seed.isDefined) {
seededRandom = new scala.util.Random(seed.get)
}
for (i <- 0 until k) {
var rand = scala.util.Random.nextDouble()
val selectedIndices = new Array[Int](k)
var seededRandom = new scala.util.Random()
if (seed.isDefined) {
rand = new scala.util.Random(seed.get).nextDouble()
seededRandom = new scala.util.Random(seed.get)
}
var cumProb = 0.0
var j = 0
while (j < probabilities.length - i) {
cumProb += probabilities(j)
if (rand < cumProb) {
selectedIndices(i) = indices(j)
probabilities(j) = 0.0
indices(j) = indices(indices.length - i - 1)
j = probabilities.length
for (i <- 0 until k) {
var rand = scala.util.Random.nextDouble()
if (seed.isDefined) {
rand = new scala.util.Random(seed.get).nextDouble()
}
var cumProb = 0.0
var j = 0
while (j < probabilities.length - i) {
cumProb += probabilities(j)
if (rand < cumProb) {
selectedIndices(i) = indices(j)
probabilities(j) = 0.0
indices(j) = indices(indices.length - i - 1)
j = probabilities.length
}
j += 1
}
j += 1
}
}

selectedIndices
selectedIndices
} else {
val selectedIndices = new Array[Int](k)
for (i <- 0 until k) {
selectedIndices(i) = 0
}
selectedIndices
}
}

def getModelOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,66 @@ class BartTestSpec extends AnyFlatSpec {

assert(dataframe1.equals(dataframe2))
}

"bart-large-cnn" should "run SparkNLP pipeline with doSample=false and later change to true " taggedAs SlowTest in {
val testData = ResourceHelper.spark
.createDataFrame(
Seq(
(
1,
"Preheat the oven to 220°C/ fan200°C/gas 7. Trim the lamb fillet of fat and cut into slices the thickness" +
" of a chop. Cut the kidneys in half and snip out the white core. Melt a knob of dripping or 2 tablespoons " +
"of vegetable oil in a heavy large pan. Fry the lamb fillet in batches for 3-4 minutes, turning once, until " +
"browned. Set aside. Fry the kidneys and cook for 1-2 minutes, turning once, until browned. Set aside." +
"Wipe the pan with kitchen paper, then add the butter. Add the onions and fry for about 10 minutes until " +
"softened. Sprinkle in the flour and stir well for 1 minute. Gradually pour in the stock, stirring all the " +
"time to avoid lumps. Add the herbs. Stir the lamb and kidneys into the onions. Season well. Transfer to a" +
" large 2.5-litre casserole. Slice the peeled potatoes thinly and arrange on top in overlapping rows. Brush " +
"with melted butter and season. Cover and bake for 30 minutes. Reduce the oven temperature to 160°C" +
"/fan140°C/gas 3 and cook for a further 2 hours. Then increase the oven temperature to 200°C/ fan180°C/gas 6," +
" uncover, and brush the potatoes with more butter. Cook uncovered for 15-20 minutes, or until golden.")))
.toDF("id", "text")

val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("documents")

val bart = BartTransformer
.pretrained("distilbart_xsum_12_6")
.setTask("summarize:")
.setInputCols(Array("documents"))
.setDoSample(false)
.setRandomSeed(56)
.setMaxOutputLength(128)
.setTemperature(0.1)
.setOutputCol("summaries")

val pipeline = new Pipeline().setStages(Array(documentAssembler, bart))

val model = pipeline.fit(testData)

var dataframe1 = model
.transform(testData)
.select("summaries.result")
.collect()
.toSeq
.head
.getAs[Seq[String]](0)
.head
println(dataframe1)

bart.setDoSample(true)

dataframe1 = model
.transform(testData)
.select("summaries.result")
.collect()
.toSeq
.head
.getAs[Seq[String]](0)
.head
println(dataframe1)

}

}

0 comments on commit 3fe2379

Please sign in to comment.