From fe89c1817d668e46adf70d0896c42c22a547c76a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 22 Nov 2015 21:45:46 -0800 Subject: [PATCH 01/18] [SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples We used the name `Dataset` to refer to `SchemaRDD` in 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for dense format to simplify the example code. cc: yinxusen Author: Xiangrui Meng Closes #9873 from mengxr/SPARK-11895. --- .../DataFrameExample.scala} | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala} (51%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af7..424f00158c2f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println From a6fda0bfc16a13b28b1cecc96f1ff91363089144 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH 02/18] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abcb30..4b2b3f8489fd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c3..bfb70963b151d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d9..c44db0ec595ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e0..daaa174a086e0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9b0f..b798aa1fab767 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b9..83a9048374267 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. + @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. + // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba122a..8484b1f801066 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. + */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6dbeb..8c86767456368 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe66777..a535c1218ecfa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b3701..7ee65975d22f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323935..60886bf77d2f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a0337..dd6366050c020 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { + val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] From fc4b792d287095d70379a51f117c225d8d857078 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Sun, 22 Nov 2015 21:51:42 -0800 Subject: [PATCH 03/18] [SPARK-11835] Adds a sidebar menu to MLlib's documentation This PR adds a sidebar menu when browsing the user guide of MLlib. It uses a YAML file to describe the structure of the documentation. It should be trivial to adapt this to the other projects. ![screen shot 2015-11-18 at 4 46 12 pm](https://cloud.githubusercontent.com/assets/7594753/11259591/a55173f4-8e17-11e5-9340-0aed79d66262.png) Author: Timothy Hunter Closes #9826 from thunterdb/spark-11835. --- docs/_data/menu-ml.yaml | 10 ++++ docs/_data/menu-mllib.yaml | 75 +++++++++++++++++++++++++ docs/_includes/nav-left-wrapper-ml.html | 8 +++ docs/_includes/nav-left.html | 17 ++++++ docs/_layouts/global.html | 24 +++++--- docs/css/main.css | 37 ++++++++++++ 6 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 docs/_data/menu-ml.yaml create mode 100644 docs/_data/menu-mllib.yaml create mode 100644 docs/_includes/nav-left-wrapper-ml.html create mode 100644 docs/_includes/nav-left.html diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 0000000000000..dff3d33bf4ed1 --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: Feature extraction, transformation, and selection + url: ml-features.html +- text: Decision trees for classification and regression + url: ml-decision-tree.html +- text: Ensembles + url: ml-ensembles.html +- text: Linear methods with elastic-net regularization + url: ml-linear-methods.html +- text: Multilayer perceptron classifier + url: ml-ann.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 0000000000000..12d22abd52826 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 0000000000000..0103e890cc21a --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
+
+

spark.ml package

+ {% include nav-left.html nav=include.nav-ml %} +

spark.mllib package

+ {% include nav-left.html nav=include.nav-mllib %} +
+
\ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 0000000000000..73176f4132554 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 467ff7a03fb70..1b09e2221e173 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -124,16 +124,24 @@ -
- {% if page.displayTitle %} -

{{ page.displayTitle }}

- {% else %} -

{{ page.title }}

- {% endif %} +
- {{ content }} + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + {% endif %} -
+ +
+ {% if page.displayTitle %} +

{{ page.displayTitle }}

+ {% else %} +

{{ page.title }}

+ {% endif %} + + {{ content }} + +
+
diff --git a/docs/css/main.css b/docs/css/main.css index d770173be1014..356b324d6303b 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,18 @@ margin-left: 10px; } +body .container-wrapper { + position: absolute; + width: 100%; + display: flex; +} + body #content { + position: relative; + line-height: 1.6; /* Inspired by Github's wiki style */ + background-color: white; + padding-left: 15px; } .title { @@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + position: absolute; + height: 100%; + + width: 256px; + margin-top: -20px; + padding-top: 20px; + background-color: #F0F8FC; +} + +.left-menu { + position: fixed; + max-width: 350px; + + padding-right: 10px; + width: 256px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} \ No newline at end of file From d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 22 Nov 2015 21:56:07 -0800 Subject: [PATCH 04/18] [SPARK-11912][ML] ml.feature.PCA minor refactor Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang Closes #9897 from yanboliang/spark-11912. --- .../org/apache/spark/ml/feature/PCA.scala | 23 +++++++------- .../apache/spark/ml/feature/PCASuite.scala | 31 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 32d7afee6e73b..aa88cb03d23c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] { /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) + val pc: DenseMatrix) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ - /** a principal components Matrix. Each column is one principal component. */ - val pc: DenseMatrix = pcaModel.pc - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -124,6 +123,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -139,7 +139,7 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc) copyValues(copied, extra).setParent(parent) } @@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] { private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(k: Int, pc: DenseMatrix) + private case class Data(pc: DenseMatrix) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.getK, instance.pc) + val data = Data(instance.pc) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) - .select("k", "pc") + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("pc") .head() - val oldModel = new feature.PCAModel(k, pc) - val model = new PCAModel(metadata.uid, oldModel) + val model = new PCAModel(metadata.uid, pc) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 5a21cd20ceede..edab21e6c3072 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val model = new PCAModel("pca", mat) ParamsSuite.checkParams(model) } @@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } } - test("read/write") { + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } - def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { - assert(model1.pc === model2.pc) - } - val allParams: Map[String, Any] = Map( - "k" -> 3, - "inputCol" -> "features", - "outputCol" -> "pca_features" - ) - val data = Seq( - (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), - (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - ) - val df = sqlContext.createDataFrame(data).toDF("id", "features") - val pca = new PCA().setK(3) - testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.pc === instance.pc) } } From 4be360d4ee6cdb4d06306feca38ddef5212608cf Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [PATCH 05/18] [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec0549852b..801096fed27bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9bf2c..9c1c00f41ab1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) From 94ce65dfcbba1fe3a1fc9d8002c37d9cd1a11336 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 23 Nov 2015 08:53:40 -0800 Subject: [PATCH 06/18] [SPARK-11628][SQL] support column datatype of char(x) to recognize HiveChar Can someone review my code to make sure I'm not missing anything? Thanks! Author: Xiu Guo Author: Xiu Guo Closes #9612 from xguo27/SPARK-11628. --- .../sql/catalyst/util/DataTypeParser.scala | 6 ++++- .../catalyst/util/DataTypeParserSuite.scala | 8 ++++-- .../spark/sql/sources/TableScanSuite.scala | 5 ++++ .../spark/sql/hive/HiveInspectors.scala | 25 ++++++++++++++++--- .../apache/spark/sql/hive/TableReader.scala | 3 +++ .../spark/sql/hive/client/HiveShim.scala | 3 ++- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 2b83651f9086d..515c071c283b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1e3409a9db6eb..bebf708965474 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. checkDataType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398f..26c1ff520406c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3d..95b57d6ad124a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -61,6 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +76,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +95,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +140,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -377,6 +387,15 @@ private[hive] trait HiveInspectors { null } + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } + case _: JavaHiveDecimalObjectInspector => (o: Any) => if (o != null) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a655..70ee02823eeba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -382,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1de..346840079b853 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -321,7 +321,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { From 1a5baaa6517872b9a4fd6cd41c4b2cf1e390f6d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:13:59 -0800 Subject: [PATCH 07/18] [SPARK-11894][SQL] fix isNull for GetInternalRowField We should use `InternalRow.isNullAt` to check if the field is null before calling `InternalRow.getXXX` Thanks gatorsmile who discovered this bug. Author: Wenchen Fan Closes #9904 from cloud-fan/null. --- .../sql/catalyst/expressions/objects.scala | 23 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 82317d3385167..4a1f419f0ad8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -236,11 +236,6 @@ case class NewInstance( } if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val row = child.gen(ctx) - s""" - ${row.code} - final boolean ${ev.isNull} = ${row.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; - } - """ + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9da02550b39ce..cc8e4325fd2f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } } From f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [PATCH 08/18] [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../catalyst/encoders/ExpressionEncoder.scala | 15 +++++++- .../encoders/ExpressionEncoderSuite.scala | 38 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f3..7bc9aed0b204e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a484f..d6ca138672ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( From 946b406519af58c79041217e6f93854b6cf80acd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:39:33 -0800 Subject: [PATCH 09/18] [SPARK-11913][SQL] support typed aggregate with complex buffer schema Author: Wenchen Fan Closes #9898 from cloud-fan/agg. --- .../aggregate/TypedAggregateExpression.scala | 25 +++++++---- .../spark/sql/DatasetAggregatorSuite.scala | 41 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6ce41aaf01e27..a9719128a626e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -23,9 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -46,14 +45,12 @@ object TypedAggregateExpression { /** * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has * the following limitations: - * - It assumes the aggregator reduces and returns a single column of type `long`. - * - It might only work when there is a single aggregator in the first column. * - It assumes the aggregator has a zero, `0`. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - bEncoder: ExpressionEncoder[Any], // Should be bound. + unresolvedBEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -80,10 +77,14 @@ case class TypedAggregateExpression( override lazy val inputTypes: Seq[DataType] = Nil - override val aggBufferSchema: StructType = bEncoder.schema + override val aggBufferSchema: StructType = unresolvedBEncoder.schema override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -93,12 +94,18 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - // todo: need a more neat way to assign the value. var i = 0 while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i aggBufferSchema(i).dataType match { - case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) - case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) } i += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 9377589790011..19dce5d1e2f37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L } case class AggData(a: Int, b: String) -object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { +object ClassInputAgg extends Aggregator[AggData, Int, Int] { /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { override def merge(b1: Int, b2: Int): Int = b1 + b2 } +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } } From 5fd86e4fc2e06d2403ca538ae417580c93b69e06 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 23 Nov 2015 10:41:17 -0800 Subject: [PATCH 10/18] [SPARK-7173][YARN] Add label expression support for application master Add label expression support for AM to restrict it runs on the specific set of nodes. I tested it locally and works fine. sryza and vanzin please help to review, thanks a lot. Author: jerryshao Closes #9800 from jerryshao/SPARK-7173. --- docs/running-on-yarn.md | 9 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 26 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0fe..925a1e0ba6fcf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. + + spark.yarn.executor.nodeLabelExpression (none) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ba799884f5689..a77a3e2420e24 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -225,7 +225,31 @@ private[spark] class Client( val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } From 5231cd5acaae69d735ba3209531705cc222f3cfb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 10:45:23 -0800 Subject: [PATCH 11/18] [SPARK-11762][NETWORK] Account for active streams when couting outstanding requests. This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin Closes #9747 from vanzin/SPARK-11762. --- .../network/client/StreamInterceptor.java | 12 ++++++++- .../client/TransportResponseHandler.java | 15 +++++++++-- .../TransportResponseHandlerSuite.java | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69fc..88ba3ccebdf20 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,19 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private volatile long bytesRead; - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af58048..cc88991b588c1 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. @@ -177,14 +185,16 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); + streamActive = true; } catch (Exception e) { logger.error("Error installing stream handler.", e); + deactivateStream(); } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a93..30144f4a9fc7a 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -28,12 +29,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test @@ -112,4 +117,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, handler.numOutstandingRequests()); + } } From 98d7ec7df4bb115dbd84cb9acd744b6c8abfebd5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Nov 2015 11:51:29 -0800 Subject: [PATCH 12/18] [SPARK-11920][ML][DOC] ML LinearRegression should use correct dataset in examples and user guide doc ML ```LinearRegression``` use ```data/mllib/sample_libsvm_data.txt``` as dataset in examples and user guide doc, but it's actually classification dataset rather than regression dataset. We should use ```data/mllib/sample_linear_regression_data.txt``` instead. The deeper causes is that ```LinearRegression``` with "normal" solver can not solve this dataset correctly, may be due to the ill condition and unreasonable label. This issue has been reported at [SPARK-11918](https://issues.apache.org/jira/browse/SPARK-11918). It will confuse users if they run the example code but get exception, so we should make this change which can clearly illustrate the usage of ```LinearRegression``` algorithm. Author: Yanbo Liang Closes #9905 from yanboliang/spark-11920. --- .../examples/ml/JavaLinearRegressionWithElasticNetExample.java | 2 +- .../src/main/python/ml/linear_regression_with_elastic_net.py | 3 ++- .../examples/ml/LinearRegressionWithElasticNetExample.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index 593f8fb3e9fe9..4ad7676c8d32b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // $example on$ // Load training data DataFrame training = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); + .load("data/mllib/sample_linear_regression_data.txt"); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index b0278276330c3..a4cd40cf26726 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -29,7 +29,8 @@ # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 5a51ece6f9ba7..22c824cea84d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample { // $example on$ // Load training data - val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() .setMaxIter(10) From f6dcc6e96ad3f88563d717d5b6c45394b44db747 Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Mon, 23 Nov 2015 12:03:15 -0800 Subject: [PATCH 13/18] [SPARK-11837][EC2] python3 compatibility for launching ec2 m3 instances this currently breaks for python3 because `string` module doesn't have `letters` anymore, instead `ascii_letters` should be used Author: Mortada Mehyar Closes #9797 from mortada/python3_fix. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9327e21e43db7..9fd652a3df4c4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves From 1b6e938be836786bac542fa430580248161e5403 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 13:19:10 -0800 Subject: [PATCH 14/18] [SPARK-4424] Remove spark.driver.allowMultipleContexts override in tests This patch removes `spark.driver.allowMultipleContexts=true` from our test configuration. The multiple SparkContexts check was originally disabled because certain tests suites in SQL needed to create multiple contexts. As far as I know, this configuration change is no longer necessary, so we should remove it in order to make it easier to find test cleanup bugs. Author: Josh Rosen Closes #9865 from JoshRosen/SPARK-4424. --- pom.xml | 2 -- project/SparkBuild.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad849112ce76c..234fd5dea1a6e 100644 --- a/pom.xml +++ b/pom.xml @@ -1958,7 +1958,6 @@ false false false - true true src @@ -1997,7 +1996,6 @@ 1 false false - true true __not_used__ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e411..f575f0012d59e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -632,7 +632,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", From 1d9120201012213edb1971a09e0849336dbb9415 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 23 Nov 2015 13:44:30 -0800 Subject: [PATCH 15/18] [SPARK-11836][SQL] udf/cast should not create new SQLContext They should use the existing SQLContext. Author: Davies Liu Closes #9914 from davies/create_udf. --- python/pyspark/sql/column.py | 7 ++++--- python/pyspark/sql/functions.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cfa..81fd4e782628a 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -346,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3da513c13897..a1ca723bbd7ab 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, From 242be7daed9b01d19794bb2cf1ac421fe5ab7262 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 23 Nov 2015 13:46:34 -0800 Subject: [PATCH 16/18] [SPARK-11910][STREAMING][DOCS] Update twitter4j dependency version Author: Luciano Resende Closes #9892 from lresende/SPARK-11910. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96b36b7a73209..ed6b28c282135 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -723,7 +723,7 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. You can either get the public stream, or get the filtered stream based on a From 7cfa4c6bc36d97e459d4adee7b03d537d63c337e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:51:43 -0800 Subject: [PATCH 17/18] [SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory. There's a very narrow race here where it would be possible for the timeout handler to close a channel after the client factory verified that the channel was still active. This change makes sure the client is marked as being recently in use so that the timeout handler does not close it until a new timeout cycle elapses. Author: Marcelo Vanzin Closes #9853 from vanzin/SPARK-11865. --- .../spark/network/client/TransportClient.java | 9 ++++- .../client/TransportClientFactory.java | 15 ++++++-- .../client/TransportResponseHandler.java | 9 +++-- .../server/TransportChannelHandler.java | 36 ++++++++++++------- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a2..876fcd846791c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -73,10 +73,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +86,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -263,6 +265,11 @@ public void onFailure(Throwable e) { } } + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 42a4f664e697c..659c47160c7be 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. + TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index cc88991b588c1..be181e0660826 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -71,7 +71,7 @@ public TransportResponseHandler(Channel channel) { } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -80,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -227,4 +227,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f8fcd1c3d7d76..29d688a67578c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -116,20 +116,32 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc // there are outstanding requests, we also do a secondary consistency check to ensure // there's no race between the idle timeout and incrementing the numOutstandingRequests // (see SPARK-7003). - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - ctx.close(); + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. + synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } } } } } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } From c2467dadae8ce44010a912ee91c429310f8add65 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:54:19 -0800 Subject: [PATCH 18/18] [SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. Author: Marcelo Vanzin Closes #9530 from vanzin/SPARK-11140. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 ++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 138 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 39 ++++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- docs/configuration.md | 2 + docs/security.md | 5 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 6 +- .../server/TransportChannelHandler.java | 1 + 14 files changed, 356 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af4456c05b0a1..b153a7b08e590 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 88df27f733f2a..84230e32a4462 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -367,17 +365,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -422,7 +409,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf76..3d7d281b0dd66 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 059a7e10ec12f..94dbec593c315 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3ce359868039b..68701f609f77a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = error = e + + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error + } + source.read(dst) + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. @@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] @@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 0000000000000..eb1d2604fb235 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. + */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1b3acb8ef7f51..af632349c9cae 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection @@ -535,6 +536,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2f55006420ce1..2b664c6313efa 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b66..ccca795683da3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). - thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/docs/configuration.md b/docs/configuration.md index c496146e3ed63..4de202d7f7631 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend. @@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180b..e1af221d446b0 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. Executor diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47fc..55fe156cf665f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 659c47160c7be..61bafc8380049 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -170,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } /** - * Create a completely new {@link TransportClient} to the given remote host / port - * But this connection is not pooled. + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 29d688a67578c..3164e00679035 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -138,6 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } + ctx.fireUserEventTriggered(evt); } public TransportResponseHandler getResponseHandler() {