diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala index ca7f5c15d7705f..27cc0acc8665cd 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitter.scala @@ -3,7 +3,7 @@ package com.johnsnowlabs.nlp.annotators import com.johnsnowlabs.nlp.functions.ExplodeAnnotations import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, AnnotatorType, HasSimpleAnnotate} import org.apache.spark.ml.param.{BooleanParam, IntParam, StringArrayParam} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} import org.apache.spark.sql.DataFrame import scala.util.matching.Regex @@ -270,3 +270,8 @@ class DocumentCharacterTextSplitter(override val uid: String) } } + +/** This is the companion object of [[DocumentCharacterTextSplitter]]. Please refer to that class + * for the documentation. + */ +object DocumentCharacterTextSplitter extends DefaultParamsReadable[DocumentCharacterTextSplitter] diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala index e8179829b63e85..b554a83ac046c7 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/DocumentCharacterTextSplitterTest.scala @@ -1,9 +1,11 @@ package com.johnsnowlabs.nlp.annotators import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.annotator.DocumentCharacterTextSplitter import com.johnsnowlabs.nlp.base.DocumentAssembler import com.johnsnowlabs.nlp.util.io.ResourceHelper -import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.tags.{FastTest, SlowTest} +import org.apache.spark.ml.Pipeline import org.apache.spark.sql.DataFrame import org.scalatest.flatspec.AnyFlatSpec @@ -221,4 +223,42 @@ class DocumentCharacterTextSplitterTest extends AnyFlatSpec { assertResult(sampleText, result, expected) } + it should "be serializable" taggedAs SlowTest in { + val textSplitter = new DocumentCharacterTextSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setChunkSize(1000) + .setChunkOverlap(100) + .setExplodeSplits(true) + + val pipeline = new Pipeline().setStages(Array(documentAssembler, textSplitter)) + val pipelineModel = pipeline.fit(splitTextDF) + + pipelineModel.stages.last + .asInstanceOf[DocumentCharacterTextSplitter] + .write + .overwrite() + .save("./tmp_textSplitter") + + val loadedTextSplitModel = DocumentCharacterTextSplitter.load("tmp_textSplitter") + + loadedTextSplitModel.transform(textDocument).select("splits").show(truncate = false) + } + + it should "be exportable to pipeline" taggedAs SlowTest in { + val textSplitter = new DocumentCharacterTextSplitter() + .setInputCols("document") + .setOutputCol("splits") + .setChunkSize(1000) + .setChunkOverlap(100) + .setExplodeSplits(true) + + val pipeline = new Pipeline().setStages(Array(documentAssembler, textSplitter)) + pipeline.write.overwrite().save("tmp_textsplitter_pipe") + + val loadedPipelineModel = Pipeline.load("tmp_textsplitter_pipe") + + loadedPipelineModel.fit(splitTextDF).transform(splitTextDF).select("splits").show() + } + }