SPARKNLP-746: Handle empty validation sets (#13615)
- handle the case where insufficient training data combined with a low
  validation split produces an empty validation set, and print a warning
- resolve some warnings
DevinTDHa authored Mar 14, 2023
1 parent eb92c16 commit bad5435
Showing 4 changed files with 87 additions and 6 deletions.
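
For context on the guard added in the classifier and multi-classifier hunks below: with only a few training rows and a small validation split, the number of validation examples can round down to zero. The following is a minimal standalone sketch (not the Spark NLP implementation; it assumes the split size is truncated with .toInt) that mirrors the nonEmpty check and the warning message introduced by this commit:

object EmptyValidationSplitSketch extends App {
  // Three rows and a 0.1 split: (3 * 0.1).toInt == 0, so no validation examples remain.
  val trainingRows = Seq("This is good.", "This is bad.", "This has no labels")
  val validationSplit = 0.1f

  val validationSize = (trainingRows.length * validationSplit).toInt
  val (validationSet, _) = trainingRows.splitAt(validationSize)

  if (validationSet.nonEmpty && validationSplit > 0.0) {
    println(s"Quality on validation dataset, validation examples = ${validationSet.length}")
  } else if (validationSet.isEmpty) {
    println(
      s"WARNING: Could not create validation set. " +
        s"Number of data points (${trainingRows.length}) not enough for validation split $validationSplit.")
  }
}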
@@ -130,7 +130,7 @@ class TensorflowClassifier(
enableOutputLogs,
outputLogsPath)

if (validationSplit > 0.0) {
if (validationSet.nonEmpty && validationSplit > 0.0) {
println(
s"Quality on validation dataset (${validationSplit * 100}%), validation examples = ${validationSet.length}")
outputLog(
@@ -145,6 +145,9 @@ class TensorflowClassifier(
extended = evaluationLogExtended,
enableOutputLogs,
outputLogsPath)
} else if (validationSet.isEmpty) {
println(f"WARNING: Could not create validation set. " +
f"Number of data points (${inputs._1.length}) not enough for validation split $validationSplit.")
}

if (testSet.nonEmpty) {
@@ -256,7 +259,7 @@ class TensorflowClassifier(
.run()

val tagsId = TensorResources.extractFloats(calculated.get(0)).grouped(numClasses).toArray
val predictedLabels = tagsId.map { case (score) =>
val predictedLabels = tagsId.map { score =>
val labelId = score.zipWithIndex.maxBy(_._1)._2
labelId
}
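
The last hunk above (around line 259) is only a cleanup of the prediction mapping; the pattern it rewrites is a per-row argmax over class scores returned as one flat array. A tiny illustrative sketch with made-up values (numClasses and the scores here are assumptions, not real model output):

// Flat scores for 2 examples over 3 classes, as a TF session might return them.
val numClasses = 3
val flatScores = Array(0.1f, 0.7f, 0.2f, 0.6f, 0.3f, 0.1f)
val tagsId = flatScores.grouped(numClasses).toArray
// Index of the highest score per row: Array(1, 0)
val predictedLabels = tagsId.map(score => score.zipWithIndex.maxBy(_._1)._2)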
@@ -152,7 +152,7 @@ private[johnsnowlabs] class TensorflowMultiClassifier(
enableOutputLogs,
outputLogsPath)

if (validationSplit > 0.0) {
if (validationSet.nonEmpty && validationSplit > 0.0) {
println(
s"Quality on validation dataset (${validationSplit * 100}%), validation examples = ${validationSet.length} ")
outputLog(
@@ -167,6 +167,9 @@ private[johnsnowlabs] class TensorflowMultiClassifier(
extended = evaluationLogExtended,
enableOutputLogs,
outputLogsPath)
} else if (validationSet.isEmpty) {
println(f"WARNING: Could not create validation set. " +
f"Number of data points (${trainInputs._1.length}) not enough for validation split $validationSplit.")
}

if (testInputs.isDefined) {
@@ -25,6 +25,8 @@ import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec

class ClassifierDLTestSpec extends AnyFlatSpec {
import ResourceHelper.spark.implicits._


"ClassifierDL" should "correctly train IMDB train dataset" taggedAs SlowTest in {

@@ -58,7 +60,7 @@ class ClassifierDLTestSpec extends AnyFlatSpec {

val pipelineModel = pipeline.fit(smallCorpus)

pipelineModel.transform(smallCorpus).select("document").show(1, false)
pipelineModel.transform(smallCorpus).select("document").show(1, truncate = false)

}

@@ -107,4 +109,41 @@ class ClassifierDLTestSpec extends AnyFlatSpec {
classifierDL.getClasses.foreach(x => print(x + ", "))
}

"ClassifierDL" should "not fail on empty validation sets" taggedAs SlowTest in {
val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val sentenceEmbeddings = BertSentenceEmbeddings
.pretrained("sent_small_bert_L2_128")
.setInputCols("document")
.setOutputCol("sentence_embeddings")

val docClassifier = new ClassifierDLApproach()
.setInputCols("sentence_embeddings")
.setOutputCol("category")
.setLabelColumn("label")
.setBatchSize(8)
.setMaxEpochs(1)
.setLr(5e-3f)
.setDropout(0.5f)
.setValidationSplit(0.1f)


val pipeline = new Pipeline()
.setStages(Array(documentAssembler, sentenceEmbeddings, docClassifier))

val data = Seq(
("This is good.", "good"),
("This is bad.", "bad"),
("This has no labels", "")
).toDF("text", "label")


val pipelineModel = pipeline.fit(data)

pipelineModel.transform(data).select("document").show(1, truncate = false)
}


}
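
With three rows and a 0.1 validation split, this new test exercises exactly the empty-validation-set path. Assuming all three rows reach the trainer, the training log should contain the warning added above (something like "WARNING: Could not create validation set. Number of data points (3) not enough for validation split 0.1.") instead of the fit failing.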
@@ -28,15 +28,16 @@ import org.scalatest.flatspec.AnyFlatSpec
class MultiClassifierDLTestSpec extends AnyFlatSpec {

val spark: SparkSession = ResourceHelper.getActiveSparkSession
import spark.implicits._

"MultiClassifierDL" should "correctly train E2E Challenge" taggedAs SlowTest in {
def splitAndTrim = udf { labels: String =>
labels.split(", ").map(x => x.trim)
}

val smallCorpus = spark.read
.option("header", true)
.option("inferSchema", true)
.option("header", value = true)
.option("inferSchema", value = true)
.option("mode", "DROPMALFORMED")
.csv("src/test/resources/classifier/e2e.csv")
.withColumn("labels", splitAndTrim(col("mr")))
@@ -74,4 +75,39 @@ class MultiClassifierDLTestSpec extends AnyFlatSpec {

}

"MultiClassifierDLApproach" should "not fail on empty validation sets" taggedAs SlowTest in {
val documentAssembler = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val sentenceEmbeddings = BertSentenceEmbeddings
.pretrained("sent_small_bert_L2_128")
.setInputCols("document")
.setOutputCol("embeddings")

val docClassifier = new MultiClassifierDLApproach()
.setInputCols("embeddings")
.setOutputCol("category")
.setLabelColumn("labels")
.setBatchSize(8)
.setMaxEpochs(1)
.setLr(1e-3f)
.setThreshold(0.5f)
.setEnableOutputLogs(true)
.setRandomSeed(44)
.setValidationSplit(0.1f)

val pipeline = new Pipeline()
.setStages(Array(documentAssembler, sentenceEmbeddings, docClassifier))

val data = Seq(
("This is good.", Array("good")),
("This is bad.", Array("bad")),
("This has no labels", Array.empty[String])
).toDF("text", "labels")

val pipelineModel = pipeline.fit(data)
pipelineModel.transform(data).show(1)
}

}
