Skip to content

Commit

Permalink
SPARKNLP-955: Add missing companion object (#14088)
Browse files Browse the repository at this point in the history
  • Loading branch information
DevinTDHa authored Dec 27, 2023
1 parent a14d7b8 commit dc18ef4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.johnsnowlabs.nlp.annotators
import com.johnsnowlabs.nlp.functions.ExplodeAnnotations
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate}
import org.apache.spark.ml.param.{BooleanParam, IntParam, StringArrayParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.sql.DataFrame

import scala.util.matching.Regex
Expand Down Expand Up @@ -270,3 +270,8 @@ class DocumentCharacterTextSplitter(override val uid: String)
}

}

/** Companion object of [[DocumentCharacterTextSplitter]]. Please refer to that class for the
  * documentation of the annotator itself.
  *
  * Extending `DefaultParamsReadable` gives this object a `load` method, so a previously saved
  * splitter can be read back with `DocumentCharacterTextSplitter.load(path)`.
  */
object DocumentCharacterTextSplitter extends DefaultParamsReadable[DocumentCharacterTextSplitter]
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.annotator.DocumentCharacterTextSplitter
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.tags.{FastTest, SlowTest}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.DataFrame
import org.scalatest.flatspec.AnyFlatSpec

Expand Down Expand Up @@ -221,4 +223,42 @@ class DocumentCharacterTextSplitterTest extends AnyFlatSpec {
assertResult(sampleText, result, expected)
}

// Saves the splitter stage of a fitted pipeline to disk, reloads it through the companion
// object's `load`, and runs the reloaded stage. The test only checks that nothing throws;
// `documentAssembler`, `splitTextDF` and `textDocument` are fixtures defined elsewhere in this suite.
it should "be serializable" taggedAs SlowTest in {
  val splitter = new DocumentCharacterTextSplitter()
    .setInputCols("document")
    .setOutputCol("splits")
    .setChunkSize(1000)
    .setChunkOverlap(100)
    .setExplodeSplits(true)

  val fittedPipeline =
    new Pipeline().setStages(Array(documentAssembler, splitter)).fit(splitTextDF)

  // The splitter is the last stage of the fitted pipeline.
  val fittedSplitter =
    fittedPipeline.stages.last.asInstanceOf[DocumentCharacterTextSplitter]
  fittedSplitter.write.overwrite().save("./tmp_textSplitter")

  val restoredSplitter = DocumentCharacterTextSplitter.load("tmp_textSplitter")

  val splitsDF = restoredSplitter.transform(textDocument).select("splits")
  splitsDF.show(truncate = false)
}

// Saves an unfitted pipeline containing the splitter, reloads it with `Pipeline.load`, then
// fits and transforms with the reloaded pipeline. The test only checks that nothing throws;
// `documentAssembler` and `splitTextDF` are fixtures defined elsewhere in this suite.
it should "be exportable to pipeline" taggedAs SlowTest in {
  val splitter = new DocumentCharacterTextSplitter()
    .setInputCols("document")
    .setOutputCol("splits")
    .setChunkSize(1000)
    .setChunkOverlap(100)
    .setExplodeSplits(true)

  val unfittedPipeline = new Pipeline().setStages(Array(documentAssembler, splitter))
  unfittedPipeline.write.overwrite().save("tmp_textsplitter_pipe")

  val restoredPipeline = Pipeline.load("tmp_textsplitter_pipe")

  val restoredModel = restoredPipeline.fit(splitTextDF)
  restoredModel.transform(splitTextDF).select("splits").show()
}

}

0 comments on commit dc18ef4

Please sign in to comment.