diff --git a/python/sparknlp/annotator.py b/python/sparknlp/annotator.py
index e47bc38f2d7b1c..3c81b441376dec 100755
--- a/python/sparknlp/annotator.py
+++ b/python/sparknlp/annotator.py
@@ -365,7 +365,6 @@ class NorvigSweetingApproach(JavaEstimator, JavaMLWritable, JavaMLReadable, Anno
     @keyword_only
     def __init__(self,
                  dictPath="/spell/words.txt",
-                 slangPath="/spell/slangs.txt",
                  caseSensitive=False,
                  doubleVariants=False,
                  shortCircuit=False
@@ -375,7 +374,6 @@ def __init__(self,
        kwargs = self._input_kwargs
        self._setDefault(
            dictPath="/spell/words.txt",
-           slangPath="/spell/slangs.txt",
            caseSensitive=False,
            doubleVariants=False,
            shortCircuit=False
@@ -412,7 +410,6 @@ def setShortCircuit(self, value):

    def setParams(self,
                  dictPath="/spell/words.txt",
-                 slangPath="/spell/slangs.txt",
                  caseSensitive=False,
                  doubleVariants=False,
                  shortCircuit=False):
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingApproach.scala
index 53c2e46b2366d5..4df8c963d7e871 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingApproach.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingApproach.scala
@@ -21,7 +21,6 @@ class NorvigSweetingApproach(override val uid: String)
   val slangPath = new Param[String](this, "slangPath", "path to custom dictionaries")

   setDefault(dictPath, "/spell/words.txt")
-  setDefault(slangPath, "/spell/slangs.txt")
   setDefault(corpusFormat, "TXT")
   setDefault(caseSensitive, false)

@@ -43,20 +42,17 @@ class NorvigSweetingApproach(override val uid: String)
   def setSlangPath(value: String): this.type = set(slangPath, value)

   override def train(dataset: Dataset[_]): NorvigSweetingModel = {
-    val loadWords = ResourceHelper.wordCount($(dictPath), TXT)
+    val loadWords = ResourceHelper.wordCount($(dictPath), $(corpusFormat).toUpperCase)
     val corpusWordCount =
       if (get(corpusPath).isDefined) {
-        if ($(corpusFormat).toLowerCase == "txt") {
-          ResourceHelper.wordCount($(corpusPath), TXT)
-        } else if ($(corpusFormat).toLowerCase == "txtds") {
-          ResourceHelper.wordCount($(corpusPath), TXTDS)
-        } else {
-          throw new Exception("Unsupported corpusFormat. Must be txt or txtds")
-        }
+        ResourceHelper.wordCount($(corpusPath), $(corpusFormat).toUpperCase)
       } else {
         Map.empty[String, Int]
       }
-    val loadSlangs = ResourceHelper.parseKeyValueText($(slangPath), "txt", ",")
+    val loadSlangs = if (get(slangPath).isDefined)
+      ResourceHelper.parseKeyValueText($(slangPath), $(corpusFormat).toUpperCase, ",")
+    else
+      Map.empty[String, String]
     new NorvigSweetingModel()
       .setWordCount(loadWords.toMap ++ corpusWordCount)
       .setCustomDict(loadSlangs)
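
Note on the change above: with its default removed, slangPath is only consulted when explicitly set, and corpusFormat (TXT or TXTDS) now also selects how the word and slang dictionaries are parsed. A minimal usage sketch of the resulting API; the corpus path and the trainingData name are placeholders, not part of the patch:

    import com.johnsnowlabs.nlp.annotators.spell.norvig.NorvigSweetingApproach

    val spell = new NorvigSweetingApproach()
      .setInputCols(Array("normalized"))
      .setOutputCol("spell")
      .setDictPath("/spell/words.txt")             // still the default value
      .setCorpusPath("/spell/sherlockholmes.txt")  // placeholder path
      .setCorpusFormat("TXTDS")                    // now also applied to the dictionary and slang files
      .setSlangPath("/spell/slangs.txt")           // optional: leave unset and train() skips slang loading

    val model = spell.fit(trainingData)            // trainingData: a DataFrame with a "normalized" column
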
Must be txt or txtds") - } + ResourceHelper.wordCount($(corpusPath), $(corpusFormat).toUpperCase) } else { Map.empty[String, Int] } - val loadSlangs = ResourceHelper.parseKeyValueText($(slangPath), "txt", ",") + val loadSlangs = if (get(slangPath).isDefined) + ResourceHelper.parseKeyValueText($(slangPath), $(corpusFormat).toUpperCase, ",") + else + Map.empty[String, String] new NorvigSweetingModel() .setWordCount(loadWords.toMap ++ corpusWordCount) .setCustomDict(loadSlangs) diff --git a/src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala b/src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala index 630f6d71b4ae7d..ab17967722075b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala @@ -4,7 +4,6 @@ import java.io.{File, FileNotFoundException, InputStream} import com.johnsnowlabs.nlp.annotators.{Normalizer, RegexTokenizer} import com.johnsnowlabs.nlp.{DocumentAssembler, Finisher} -import com.johnsnowlabs.nlp.annotators.common.{TaggedSentence, TaggedWord} import com.johnsnowlabs.nlp.util.io.ResourceFormat._ import org.apache.spark.ml.Pipeline import org.apache.spark.sql.SparkSession @@ -27,6 +26,22 @@ object ResourceHelper { private val spark: SparkSession = SparkSession.builder().getOrCreate() + /** Structure for a SourceStream coming from compiled content */ + case class SourceStream(resource: String) { + val pipe: Option[InputStream] = { + var stream = getClass.getResourceAsStream(resource) + if (stream == null) + stream = getClass.getClassLoader.getResourceAsStream(resource) + Option(stream) + } + val content: Source = pipe.map(p => { + Source.fromInputStream(p)("UTF-8") + }).getOrElse(Source.fromFile(resource, "UTF-8")) + def close(): Unit = { + content.close() + pipe.foreach(_.close()) + } + } def listDirectory(path: String): Seq[String] = { var dirURL = getClass.getResource(path) @@ -69,57 +84,12 @@ object ResourceHelper { throw new UnsupportedOperationException(s"Cannot list files for URL $dirURL") } - /** Structure for a SourceStream coming from compiled content */ - case class SourceStream(resource: String) { - val pipe: Option[InputStream] = { - var stream = getClass.getResourceAsStream(resource) - if (stream == null) - stream = getClass.getClassLoader.getResourceAsStream(resource) - Option(stream) - } - val content: Source = pipe.map(p => { - Source.fromInputStream(p)("UTF-8") - }).getOrElse(Source.fromFile(resource, "UTF-8")) - def close(): Unit = { - content.close() - pipe.foreach(_.close()) - } - } - /** Checks whether a path points to directory */ def pathIsDirectory(path: String): Boolean = { //ToDo: Improve me??? 
if (path.contains(".txt")) false else true } - /** - * General purpose key values parser from source - * Currently only text files - * @param source File input to streamline - * @param format format, for now only txt - * @param keySep separator character - * @param valueSep values separator in dictionary - * @return Dictionary of all values per key - */ - def parseKeyValuesText( - source: String, - format: Format, - keySep: String, - valueSep: String): Map[String, Array[String]] = { - format match { - case TXT => - val sourceStream = SourceStream(source) - val res = sourceStream.content.getLines.map (line => { - val kv = line.split (keySep).map (_.trim) - val key = kv (0) - val values = kv (1).split (valueSep).map (_.trim) - (key, values) - }).toMap - sourceStream.close() - res - } - } - /** * General purpose key value parser from source * Currently read only text files @@ -142,6 +112,14 @@ object ResourceHelper { }).toMap sourceStream.close() res + case TXTDS => + import spark.implicits._ + val dataset = spark.read.option("delimiter", keySep).csv(source).toDF("key", "value") + val keyValueStore = MMap.empty[String, String] + dataset.as[(String, String)].foreach{kv => keyValueStore(kv._1) = kv._2} + keyValueStore.toMap + case _ => + throw new Exception("Unsupported format. Must be TXT or TXTDS") } } @@ -162,6 +140,16 @@ object ResourceHelper { val res = sourceStream.content.getLines.toArray sourceStream.close() res + case TXTDS => + import spark.implicits._ + val dataset = spark.read.text(source) + val lineStore = spark.sparkContext.collectionAccumulator[String] + dataset.as[String].foreach(l => lineStore.add(l)) + val result = lineStore.value.toArray.map(_.toString) + lineStore.reset() + result + case _ => + throw new Exception("Unsupported format. Must be TXT or TXTDS") } } @@ -187,6 +175,19 @@ object ResourceHelper { }).toArray sourceStream.close() res + case TXTDS => + import spark.implicits._ + val dataset = spark.read.text(source) + val lineStore = spark.sparkContext.collectionAccumulator[String] + dataset.as[String].foreach(l => lineStore.add(l)) + val result = lineStore.value.toArray.map(line => { + val kv = line.toString.split (keySep).map (_.trim) + (kv.head, kv.last) + }) + lineStore.reset() + result + case _ => + throw new Exception("Unsupported format. Must be TXT or TXTDS") } } @@ -215,7 +216,19 @@ object ResourceHelper { }) sourceStream.close() m.toMap - case _ => throw new IllegalArgumentException("Only txt supported as a file format") + case TXTDS => + import spark.implicits._ + val dataset = spark.read.text(source) + val valueAsKeys = MMap.empty[String, String] + dataset.as[String].foreach(line => { + val kv = line.split(keySep).map(_.trim) + val key = kv(0) + val values = kv(1).split(valueSep).map(_.trim) + values.foreach(v => valueAsKeys(v) = key) + }) + valueAsKeys.toMap + case _ => + throw new Exception("Unsupported format. 
diff --git a/src/main/resources/spell/slangs.txt b/src/test/resources/spell/slangs.txt
similarity index 100%
rename from src/main/resources/spell/slangs.txt
rename to src/test/resources/spell/slangs.txt
diff --git a/src/test/scala/com/johnsnowlabs/nlp/AnnotatorBuilder.scala b/src/test/scala/com/johnsnowlabs/nlp/AnnotatorBuilder.scala
index 34c9eabe2b95fd..20b847e5b6ab71 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/AnnotatorBuilder.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/AnnotatorBuilder.scala
@@ -122,6 +122,7 @@ object AnnotatorBuilder extends FlatSpec { this: Suite =>
     val spellChecker = new NorvigSweetingApproach()
       .setInputCols(Array("normalized"))
       .setOutputCol("spell")
+      .setDictPath("./src/main/resources/spell/words.txt")
       .setCorpusPath("./src/test/resources/spell/sherlockholmes.txt")
       .setCorpusFormat(inputFormat)
     spellChecker.fit(withFullNormalizer(dataset)).transform(withFullNormalizer(dataset))
diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingBehaviors.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingBehaviors.scala
index 4d9d69293c0134..ec5d6f01b04f36 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingBehaviors.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/spell/norvig/NorvigSweetingBehaviors.scala
@@ -10,6 +10,7 @@ trait NorvigSweetingBehaviors { this: FlatSpec =>

   val spellChecker = new NorvigSweetingApproach()
     .setCorpusPath("/spell")
+    .setSlangPath("/spell/slangs.txt")
     .fit(DataBuilder.basicDataBuild("dummy"))

   def isolatedNorvigChecker(wordAnswer: Seq[(String, String)]): Unit = {
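
For reference, parseKeyValueText splits each line of the slang file on the "," separator into a key and its replacement, so the relocated slangs.txt is plain comma-separated pairs. A hypothetical two-line file in that format (not the actual test fixture):

    gr8,great
    u,you
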