Serialization features bc #78

Merged
merged 13 commits on Jan 8, 2018
70 changes: 38 additions & 32 deletions python/example/vivekn-sentiment/sentiment.ipynb
@@ -9,32 +9,15 @@
"outputs": [],
"source": [
"#Imports\n",
"import time\n",
"import sys\n",
"sys.path.append('../../')\n",
"\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml import Pipeline, PipelineModel\n",
"from sparknlp.annotator import *\n",
"from sparknlp.base import DocumentAssembler, Finisher\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
"spark = SparkSession.builder \\\n",
" .master(\"local[2]\") \\\n",
" .config(\"spark.jar\", \"lib/sparknlp.jar\") \\\n",
" .config(\"spark.driver.memory\", \"5g\")\\\n",
" .config(\"spark.dirver.maxResultSize\", \"2g\")\\\n",
" .getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -137,9 +120,22 @@
"sentiment_detector = ViveknSentimentApproach() \\\n",
" .setInputCols([\"spell\", \"sentence\"]) \\\n",
" .setOutputCol(\"sentiment\") \\\n",
" .setPruneCorpus(False) \\\n",
" .setPositiveSource(\"../../../src/test/resources/vivekn/positive\") \\\n",
" .setNegativeSource(\"../../../src/test/resources/vivekn/negative\") \\\n",
" .setPruneCorpus(False)\n"
" .setNegativeSource(\"../../../src/test/resources/vivekn/negative\") \\\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"pos = PerceptronApproach() \\\n",
" .setInputCols([\"sentence\", \"spell\"]) \\\n",
" .setOutputCol(\"pos\")"
]
},
{
@@ -168,11 +164,15 @@
" normalizer,\n",
" spell_checker,\n",
" sentiment_detector,\n",
" pos,\n",
" finisher\n",
"])\n",
"\n",
"start = time.time()\n",
"sentiment_data = pipeline.fit(data).transform(data)\n",
"sentiment_data.show()"
"sentiment_data.show()\n",
"end = time.time()\n",
"print(\"Time elapsed pipeline process: \" + str(end - start))"
]
},
{
@@ -188,24 +188,27 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"start = time.time()\n",
"pipeline.write().overwrite().save(\"./ps\")\n",
"pipeline.fit(data).write().overwrite().save(\"./ms\")"
"pipeline.fit(data).write().overwrite().save(\"./ms\")\n",
"end = time.time()\n",
"print(\"Time elapsed in write pipelines: \" + str(end - start))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml import Pipeline,PipelineModel"
"start = time.time()\n",
"p = Pipeline.read().load(\"./ps\")\n",
"pm = PipelineModel.read().load(\"./ms\")\n",
"end = time.time()\n",
"print(\"Time elapsed in read pipelines: \" + str(end - start))"
]
},
{
@@ -214,8 +217,11 @@
"metadata": {},
"outputs": [],
"source": [
"Pipeline.read().load(\"./ps\")\n",
"PipelineModel.read().load(\"./ms\")"
"start = time.time()\n",
"pm.transform(data).where(\"finished_sentiment not like '%negative%'\").show()\n",
"print(pm.transform(data).count())\n",
"end = time.time()\n",
"print(\"Time elapsed in using loaded pipelines: \" + str(end - start))"
]
},
{
3 changes: 1 addition & 2 deletions src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala
@@ -2,7 +2,6 @@ package com.johnsnowlabs.nlp

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types._
@@ -15,7 +14,7 @@ import org.apache.spark.sql.functions.{array, udf}
*/
abstract class AnnotatorModel[M <: Model[M]]
extends Model[M]
with DefaultParamsWritable
with ParamsAndFeaturesWritable
with HasAnnotatorType
with HasInputAnnotationCols
with HasOutputAnnotationCol {
35 changes: 35 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
@@ -0,0 +1,35 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.serialization.{ArrayFeature, Feature, MapFeature, StructFeature}

import scala.collection.mutable.ArrayBuffer

trait HasFeatures {

val features: ArrayBuffer[Feature[_, _, _]] = ArrayBuffer.empty

protected def set[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}

protected def set[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}

protected def set[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}

protected def get[T](feature: ArrayFeature[T]): Option[Array[T]] = feature.get

protected def get[K, V](feature: MapFeature[K, V]): Option[Map[K, V]] = feature.get

protected def get[T](feature: StructFeature[T]): Option[T] = feature.get

protected def $$[T](feature: ArrayFeature[T]): Array[T] = feature.getValue

protected def $$[K, V](feature: MapFeature[K, V]): Map[K, V] = feature.getValue

protected def $$[T](feature: StructFeature[T]): T = feature.getValue

}
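A minimal sketch (not part of this diff) of how a class could exercise the HasFeatures API above. SynonymHolder and its "synonyms" feature are hypothetical, and the sketch assumes the Feature constructors accept any HasFeatures owner, as the MapFeature usage in the Lemmatizer change further down suggests.

import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.serialization.MapFeature

class SynonymHolder extends HasFeatures {

  // declared much like a Param, but persisted by the new FeaturesWriter/FeaturesReader pair
  val synonyms: MapFeature[String, String] = new MapFeature(this, "synonyms")

  def setSynonyms(value: Map[String, String]): this.type = set(synonyms, value)

  // get returns an Option; $$ unwraps the value and fails if the feature was never set
  def hasSynonyms: Boolean = get(synonyms).isDefined

  def lookup(token: String): String = $$(synonyms).getOrElse(token, token)
}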
@@ -0,0 +1,31 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.util.{DefaultParamsReadable, MLReader}
import org.apache.spark.sql.SparkSession

class FeaturesReader[T <: HasFeatures](baseReader: MLReader[T], onRead: (T, String, SparkSession) => Unit) extends MLReader[T] {

override def load(path: String): T = {

val instance = baseReader.load(path)

for (feature <- instance.features) {
val value = feature.deserialize(sparkSession, path, feature.name)
feature.setValue(value)
}

onRead(instance, path, sparkSession)

instance
}
}

trait ParamsAndFeaturesReadable[T <: HasFeatures] extends DefaultParamsReadable[T] {

def onRead(instance: T, path: String, spark: SparkSession): Unit = {}

override def read: MLReader[T] = new FeaturesReader(
super.read,
(instance: T, path: String, spark: SparkSession) => onRead(instance, path, spark)
)
}
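A hedged sketch of the onRead hook this trait exposes. LoggingLemmatizerReader is illustrative and not part of the PR; the PR itself only has the Lemmatizer companion extend the trait without overriding anything.

import com.johnsnowlabs.nlp.ParamsAndFeaturesReadable
import com.johnsnowlabs.nlp.annotators.Lemmatizer
import org.apache.spark.sql.SparkSession

object LoggingLemmatizerReader extends ParamsAndFeaturesReadable[Lemmatizer] {
  // runs after DefaultParamsReader has restored the params and FeaturesReader
  // has deserialized every registered feature (here, lemmaDict)
  override def onRead(instance: Lemmatizer, path: String, spark: SparkSession): Unit =
    println(s"restored ${instance.uid} with ${instance.getLemmaDict.size} lemmas from $path")
}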
@@ -0,0 +1,32 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DefaultParamsWritable, MLWriter}
import org.apache.spark.sql.SparkSession

class FeaturesWriter[T](annotatorWithFeatures: HasFeatures, baseWriter: MLWriter, onWritten: (String, SparkSession) => Unit)
extends MLWriter with HasFeatures {

override protected def saveImpl(path: String): Unit = {
baseWriter.save(path)

for (feature <- annotatorWithFeatures.features) {
feature.serializeInfer(sparkSession, path, feature.name, feature.getValue)
}

onWritten(path, sparkSession)

}
}

trait ParamsAndFeaturesWritable extends DefaultParamsWritable with Params with HasFeatures {

def onWritten(path: String, spark: SparkSession): Unit = {}

override def write: MLWriter = new FeaturesWriter(
this,
super.write,
(path: String, spark: SparkSession) => onWritten(path, spark)
)

}
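Taken together, the two traits above let a pipeline stage persist Feature values alongside its Params. Below is a hedged, self-contained sketch of such a stage; DictStage is hypothetical, and it relies on the implication that a Feature registers itself with its owner's features buffer when constructed (the Lemmatizer change below never appends lemmaDict by hand).

import com.johnsnowlabs.nlp.{ParamsAndFeaturesReadable, ParamsAndFeaturesWritable}
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType

// Pass-through Transformer whose only state is a MapFeature, so the new
// writer/reader pair has something non-trivial to persist.
class DictStage(override val uid: String) extends Transformer with ParamsAndFeaturesWritable {

  def this() = this(Identifiable.randomUID("DICT_STAGE"))

  val dict: MapFeature[String, String] = new MapFeature(this, "dict")

  def setDict(value: Map[String, String]): this.type = set(dict, value)

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): DictStage = defaultCopy(extra)

  // optional hook: called after the params and every feature have been written
  override def onWritten(path: String, spark: SparkSession): Unit =
    println(s"DictStage $uid written to $path")
}

object DictStage extends ParamsAndFeaturesReadable[DictStage]

With this in place, stage.write.overwrite().save(path) goes through FeaturesWriter, and DictStage.read.load(path) restores the dict feature before onRead runs.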
29 changes: 16 additions & 13 deletions src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala
@@ -1,12 +1,12 @@
package com.johnsnowlabs.nlp.annotators

import com.johnsnowlabs.nlp.annotators.common.StringMapParam
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorModel, ParamsAndFeaturesReadable}
import com.typesafe.config.Config
import com.johnsnowlabs.nlp.util.ConfigHelper
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable}
import org.apache.spark.ml.util.Identifiable

import scala.collection.JavaConverters._

@@ -25,7 +25,7 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {

private val config: Config = ConfigHelper.retrieve

val lemmaDict: StringMapParam = new StringMapParam(this, "lemmaDict", "provide a lemma dictionary")
val lemmaDict: MapFeature[String, String] = new MapFeature(this, "lemmaDict")

val lemmaFormat: Param[String] = new Param[String](this, "lemmaFormat", "TXT or TXTDS for reading dictionary as dataset")

@@ -52,15 +52,23 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {

def this() = this(Identifiable.randomUID("LEMMATIZER"))

def getLemmaDict: Map[String, String] = $(lemmaDict)
def getLemmaDict: Map[String, String] = $$(lemmaDict)
protected def getLemmaFormat: String = $(lemmaFormat)
protected def getLemmaKeySep: String = $(lemmaKeySep)
protected def getLemmaValSep: String = $(lemmaValSep)

def setLemmaDict(dictionary: String): this.type = {
set(lemmaDict, Lemmatizer.retrieveLemmaDict(dictionary, $(lemmaFormat), $(lemmaKeySep), $(lemmaValSep)))
}

def setLemmaDictHMap(dictionary: java.util.HashMap[String, String]): this.type = {
set(lemmaDict, dictionary.asScala.toMap)
}
def setLemmaDictMap(dictionary: Map[String, String]): this.type = {
set(lemmaDict, dictionary)
}
def setLemmaFormat(value: String): this.type = set(lemmaFormat, value)
def setLemmaKeySep(value: String): this.type = set(lemmaKeySep, value)
def setLemmaValSep(value: String): this.type = set(lemmaValSep, value)

/**
* @return one to one annotation from token to a lemmatized word, if found on dictionary or leave the word as is
@@ -72,19 +80,14 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {
annotatorType,
tokenAnnotation.begin,
tokenAnnotation.end,
$(lemmaDict).getOrElse(token, token),
$$(lemmaDict).getOrElse(token, token),
tokenAnnotation.metadata
)
}
}
}

object Lemmatizer extends DefaultParamsReadable[Lemmatizer] {

/**
* Retrieves Lemma dictionary from configured compiled source set in configuration
* @return a Dictionary for lemmas
*/
object Lemmatizer extends ParamsAndFeaturesReadable[Lemmatizer] {
protected def retrieveLemmaDict(
lemmaFilePath: String,
lemmaFormat: String,
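The remainder of the retrieveLemmaDict helper is collapsed in this view. For reference, a hedged end-to-end sketch of what the Lemmatizer changes enable from Scala; the dictionary entries and path are made up, and setInputCols/setOutputCol come from the input/output column traits on AnnotatorModel rather than from this diff.

import com.johnsnowlabs.nlp.annotators.Lemmatizer

// build a Lemmatizer from an in-memory dictionary (setLemmaDictMap is added above)
val lemmatizer = new Lemmatizer()
  .setInputCols(Array("token"))
  .setOutputCol("lemma")
  .setLemmaDictMap(Map("feet" -> "foot", "ran" -> "run"))

// AnnotatorModel now mixes in ParamsAndFeaturesWritable, so saving also serializes lemmaDict
lemmatizer.write.overwrite().save("/tmp/lemmatizer_model")

// the companion extends ParamsAndFeaturesReadable, so loading restores the feature as well
val restored = Lemmatizer.read.load("/tmp/lemmatizer_model")
assert(restored.getLemmaDict("ran") == "run")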