diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 1842b97d43651..052f68c6c24e2 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -17,19 +17,19 @@ License: Apache License (== 2.0) Collate: 'generics.R' 'jobj.R' - 'SQLTypes.R' 'RDD.R' 'pairRDD.R' + 'SQLTypes.R' 'column.R' 'group.R' 'DataFrame.R' 'SQLContext.R' + 'backend.R' 'broadcast.R' + 'client.R' 'context.R' 'deserialize.R' 'serialize.R' 'sparkR.R' - 'backend.R' - 'client.R' 'utils.R' 'zzz.R' diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index feafd56909a67..044fdb4d01223 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -17,7 +17,7 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 604ad03c407b9..820027ef67e3b 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -85,7 +85,7 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { # This transformation is the first in its stage: - .Object@func <- func + .Object@func <- cleanClosure(func) .Object@prev_jrdd <- getJRDD(prev) .Object@env$prev_serializedMode <- prev@env$serializedMode # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD @@ -94,7 +94,7 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) pipelinedFunc <- function(split, iterator) { func(split, prev@func(split, iterator)) } - .Object@func <- pipelinedFunc + .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline # Get the serialization mode of the parent RDD .Object@env$prev_serializedMode <- prev@env$prev_serializedMode @@ -144,17 +144,13 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), return(rdd@env$jrdd_val) } - computeFunc <- function(split, part) { - rdd@func(split, part) - } - packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), function(name) { get(name, .broadcastNames) }) - serializedFuncArr <- serialize(computeFunc, connection = NULL) + serializedFuncArr <- serialize(rdd@func, connection = NULL) prev_jrdd <- rdd@prev_jrdd @@ -279,7 +275,7 @@ setMethod("unpersist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' setCheckpointDir(sc, "checkpoints") +#' setCheckpointDir(sc, "checkpoint") #' rdd <- parallelize(sc, 1:10, 2L) #' checkpoint(rdd) #'} @@ -551,11 +547,7 @@ setMethod("mapPartitions", setMethod("lapplyPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { - FUN <- cleanClosure(FUN) - closureCapturingFunc <- function(split, part) { - FUN(split, part) - } - PipelinedRDD(X, closureCapturingFunc) + PipelinedRDD(X, FUN) }) #' @rdname lapplyPartitionsWithIndex diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index e196305186b9a..b282001d8b6b5 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -17,7 +17,7 @@ # Column Class -#' @include generics.R jobj.R +#' @include generics.R jobj.R SQLTypes.R NULL setOldClass("jobj") diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 2fc0bb294bcce..ebbb8fba1052d 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -216,7 +216,7 @@ broadcast <- function(sc, object) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' setCheckpointDir(sc, "~/checkpoints") +#' setCheckpointDir(sc, "~/checkpoint") #' rdd <- parallelize(sc, 1:2, 2L) #' checkpoint(rdd) #'} diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 09fc0a7abe48a..855fbdfc7c4ca 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -17,6 +17,9 @@ # group.R - GroupedData class and methods implemented in S4 OO classes +#' @include generics.R jobj.R SQLTypes.R column.R +NULL + setOldClass("jobj") #' @title S4 class that represents a GroupedData diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index 4180f146b7fbc..a8a25230b636d 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -18,6 +18,9 @@ # References to objects that exist on the JVM backend # are maintained using the jobj. +#' @include generics.R +NULL + # Maintain a reference count of Java object references # This allows us to GC the java object when it is safe .validJobjs <- new.env(parent = emptyenv()) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index c2396c32a7548..5d64822859d1f 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -16,6 +16,8 @@ # # Operations supported on RDDs contains pairs (i.e key, value) +#' @include generics.R jobj.R RDD.R +NULL ############ Actions and Transformations ############ @@ -694,10 +696,6 @@ setMethod("cogroup", for (i in 1:rddsLen) { rdds[[i]] <- lapply(rdds[[i]], function(x) { list(x[[1]], list(i, x[[2]])) }) - # TODO(hao): As issue [SparkR-142] mentions, the right value of i - # will not be captured into UDF if getJRDD is not invoked. - # It should be resolved together with that issue. - getJRDD(rdds[[i]]) # Capture the closure. } union.rdd <- Reduce(unionRDD, rdds) group.func <- function(vlist) { diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index f75e0817b9406..b76e4db03e715 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -141,7 +141,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp unpersist(rdd2) expect_false(rdd2@env$isCached) - setCheckpointDir(sc, "checkpoints") + tempDir <- tempfile(pattern = "checkpoint") + setCheckpointDir(sc, tempDir) checkpoint(rdd2) expect_true(rdd2@env$isCheckpointed) @@ -152,7 +153,7 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp # make sure the data is collectable collect(rdd2) - unlink("checkpoints") + unlink(tempDir) }) test_that("reduce on RDD", { diff --git a/bin/pyspark b/bin/pyspark index 776b28dc41099..8acad6113797d 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -89,6 +89,7 @@ export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR + export PYTHONHASHSEED=0 if [[ -n "$PYSPARK_DOC_TEST" ]]; then exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else diff --git a/bin/spark-class b/bin/spark-class index c03946d92e2e4..c49d97ce5cf25 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -82,13 +82,22 @@ if [ $(command -v "$JAR_CMD") ] ; then fi fi +LAUNCH_CLASSPATH="$SPARK_ASSEMBLY_JAR" + +# Add the launcher build dir to the classpath if requested. +if [ -n "$SPARK_PREPEND_CLASSES" ]; then + LAUNCH_CLASSPATH="$SPARK_HOME/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" +fi + +export _SPARK_ASSEMBLY="$SPARK_ASSEMBLY_JAR" + # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$SPARK_ASSEMBLY_JAR" org.apache.spark.launcher.Main "$@") +done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") if [ "${CMD[0]}" = "usage" ]; then "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 4b3401d745f2a..3d068dd3a2739 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -46,13 +46,22 @@ if "%SPARK_ASSEMBLY_JAR%"=="0" ( exit /b 1 ) +set LAUNCH_CLASSPATH=%SPARK_ASSEMBLY_JAR% + +rem Add the launcher build dir to the classpath if requested. +if not "x%SPARK_PREPEND_CLASSES%"=="x" ( + set LAUNCH_CLASSPATH=%SPARK_HOME%\launcher\target\scala-%SPARK_SCALA_VERSION%\classes;%LAUNCH_CLASSPATH% +) + +set _SPARK_ASSEMBLY=%SPARK_ASSEMBLY_JAR% + rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %SPARK_ASSEMBLY_JAR% org.apache.spark.launcher.Main %*"') do ( +for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( set SPARK_CMD=%%i ) %SPARK_CMD% diff --git a/bin/spark-submit b/bin/spark-submit index bcff78edd51ca..0e0afe71a0f05 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -19,6 +19,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +# disable randomized hash for string in Python 3.3+ +export PYTHONHASHSEED=0 + # Only define a usage function if an upstream script hasn't done so. if ! type -t usage >/dev/null 2>&1; then usage() { diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 08ddb185742d2..d3fc4a5cc3f6e 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -20,6 +20,9 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. +rem disable randomized hash for string in Python 3.3+ +set PYTHONHASHSEED=0 + set CLASS=org.apache.spark.deploy.SparkSubmit call %~dp0spark-class2.cmd %CLASS% %* set SPARK_ERROR_LEVEL=%ERRORLEVEL% diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 6c37cc8b98236..4910744d1d790 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -85,17 +85,13 @@ table.sortable td { filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } -span.kill-link { +a.kill-link { margin-right: 2px; margin-left: 20px; color: gray; float: right; } -span.kill-link a { - color: gray; -} - span.expand-details { font-size: 10pt; cursor: pointer; diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 9b05c9623b704..715b259057569 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.util.Utils /** @@ -33,6 +33,7 @@ private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask private case class CleanAccum(accId: Long) extends CleanupTask +private case class CleanCheckpoint(rddId: Int) extends CleanupTask /** * A WeakReference associated with a CleanupTask. @@ -94,12 +95,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { @volatile private var stopped = false /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { + def attachListener(listener: CleanerListener): Unit = { listeners += listener } /** Start the cleaner. */ - def start() { + def start(): Unit = { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Context Cleaner") cleaningThread.start() @@ -108,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Stop the cleaning thread and wait until the thread has finished running its current task. */ - def stop() { + def stop(): Unit = { stopped = true // Interrupt the cleaning thread, but wait until the current task has finished before // doing so. This guards against the race condition where a cleaning thread may @@ -121,7 +122,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a RDD for cleanup when it is garbage collected. */ - def registerRDDForCleanup(rdd: RDD[_]) { + def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } @@ -130,17 +131,22 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } /** Register a Broadcast for cleanup when it is garbage collected. */ - def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = { registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } + /** Register a RDDCheckpointData for cleanup when it is garbage collected. */ + def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = { + registerForCleanup(rdd, CleanCheckpoint(parentId)) + } + /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } @@ -164,6 +170,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) case CleanAccum(accId) => doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) } } } @@ -175,7 +183,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform RDD cleanup. */ - def doCleanupRDD(rddId: Int, blocking: Boolean) { + def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning RDD " + rddId) sc.unpersistRDD(rddId, blocking) @@ -187,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform shuffle cleanup, asynchronously. */ - def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning shuffle " + shuffleId) mapOutputTrackerMaster.unregisterShuffle(shuffleId) @@ -200,7 +208,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform broadcast cleanup. */ - def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = { try { logDebug(s"Cleaning broadcast $broadcastId") broadcastManager.unbroadcast(broadcastId, true, blocking) @@ -212,7 +220,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Perform accumulator cleanup. */ - def doCleanupAccum(accId: Long, blocking: Boolean) { + def doCleanupAccum(accId: Long, blocking: Boolean): Unit = { try { logDebug("Cleaning accumulator " + accId) Accumulators.remove(accId) @@ -223,6 +231,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } + /** Perform checkpoint cleanup. */ + def doCleanCheckpoint(rddId: Int): Unit = { + try { + logDebug("Cleaning rdd checkpoint data " + rddId) + RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + logInfo("Cleaned rdd checkpoint data " + rddId) + } + catch { + case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e) + } + } + private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 390e631647bd6..b0186e9a007b8 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -68,6 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } + logDeprecationWarning(key) settings.put(key, value) this } @@ -134,13 +135,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]): SparkConf = { - this.settings.putAll(settings.toMap.asJava) + settings.foreach { case (k, v) => set(k, v) } this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - settings.putIfAbsent(key, value) + if (settings.putIfAbsent(key, value) == null) { + logDeprecationWarning(key) + } this } @@ -174,8 +177,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } - /** - * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws NoSuchElementException */ @@ -183,36 +186,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsSeconds(get(key)) } - /** - * Get a time parameter as seconds, falling back to a default if not set. If no + /** + * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. - * */ def getTimeAsSeconds(key: String, defaultValue: String): Long = { Utils.timeStringAsSeconds(get(key, defaultValue)) } - /** - * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. * @throws NoSuchElementException */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) } - /** - * Get a time parameter as milliseconds, falling back to a default if not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. */ def getTimeAsMs(key: String, defaultValue: String): Long = { Utils.timeStringAsMs(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } /** Get all parameters as a list of pairs */ @@ -379,13 +381,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } - - // Warn against the use of deprecated configs - deprecatedConfigs.values.foreach { dc => - if (contains(dc.oldName)) { - dc.warn() - } - } } /** @@ -400,19 +395,44 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { private[spark] object SparkConf extends Logging { + /** + * Maps deprecated config keys to information about the deprecation. + * + * The extra information is logged as a warning when the config is present in the user's + * configuration. + */ private val deprecatedConfigs: Map[String, DeprecatedConfig] = { val configs = Seq( - DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst", - "1.3"), - DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3", - "Use spark.{driver,executor}.userClassPathFirst instead."), - DeprecatedConfig("spark.history.fs.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead"), - DeprecatedConfig("spark.history.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead")) - configs.map { x => (x.oldName, x) }.toMap + DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", + "Please use spark.{driver,executor}.userClassPathFirst instead.")) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + } + + /** + * Maps a current config key to alternate keys that were used in previous version of Spark. + * + * The alternates are used in the order defined in this map. If deprecated configs are + * present in the user's configuration, a warning is logged. + */ + private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( + "spark.executor.userClassPathFirst" -> Seq( + AlternateConfig("spark.files.userClassPathFirst", "1.3")), + "spark.history.fs.update.interval" -> Seq( + AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), + AlternateConfig("spark.history.fs.updateInterval", "1.3"), + AlternateConfig("spark.history.updateInterval", "1.3")) + ) + + /** + * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated + * config keys. + * + * Maps the deprecated config name to a 2-tuple (new config name, alternate config info). + */ + private val allAlternatives: Map[String, (String, AlternateConfig)] = { + configsWithAlternatives.keys.flatMap { key => + configsWithAlternatives(key).map { cfg => (cfg.key -> (key -> cfg)) } + }.toMap } /** @@ -443,61 +463,57 @@ private[spark] object SparkConf extends Logging { } /** - * Translate the configuration key if it is deprecated and has a replacement, otherwise just - * returns the provided key. - * - * @param userKey Configuration key from the user / caller. - * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed - * only once for each key. + * Looks for available deprecated keys for the given config option, and return the first + * value available. */ - private def translateConfKey(userKey: String, warn: Boolean = false): String = { - deprecatedConfigs.get(userKey) - .map { deprecatedKey => - if (warn) { - deprecatedKey.warn() - } - deprecatedKey.newName.getOrElse(userKey) - }.getOrElse(userKey) + def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + configsWithAlternatives.get(key).flatMap { alts => + alts.collectFirst { case alt if conf.contains(alt.key) => + val value = conf.get(alt.key) + alt.translation.map(_(value)).getOrElse(value) + } + } } /** - * Holds information about keys that have been deprecated or renamed. + * Logs a warning message if the given config key is deprecated. + */ + def logDeprecationWarning(key: String): Unit = { + deprecatedConfigs.get(key).foreach { cfg => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. ${cfg.deprecationMessage}") + } + + allAlternatives.get(key).foreach { case (newKey, cfg) => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"and may be removed in the future. Please use the new key '$newKey' instead.") + } + } + + /** + * Holds information about keys that have been deprecated and do not have a replacement. * - * @param oldName Old configuration key. - * @param newName New configuration key, or `null` if key has no replacement, in which case the - * deprecated key will be used (but the warning message will still be printed). + * @param key The deprecated key. * @param version Version of Spark where key was deprecated. - * @param deprecationMessage Message to include in the deprecation warning; mandatory when - * `newName` is not provided. + * @param deprecationMessage Message to include in the deprecation warning. */ private case class DeprecatedConfig( - oldName: String, - _newName: String, + key: String, version: String, - deprecationMessage: String = null) { - - private val warned = new AtomicBoolean(false) - val newName = Option(_newName) + deprecationMessage: String) - if (newName == null && (deprecationMessage == null || deprecationMessage.isEmpty())) { - throw new IllegalArgumentException("Need new config name or deprecation message.") - } - - def warn(): Unit = { - if (warned.compareAndSet(false, true)) { - if (newName != null) { - val message = Option(deprecationMessage).getOrElse( - s"Please use the alternative '$newName' instead.") - logWarning( - s"The configuration option '$oldName' has been replaced as of Spark $version and " + - s"may be removed in the future. $message") - } else { - logWarning( - s"The configuration option '$oldName' has been deprecated as of Spark $version and " + - s"may be removed in the future. $deprecationMessage") - } - } - } + /** + * Information about an alternate configuration key that has been deprecated. + * + * @param key The deprecated config key. + * @param version The Spark version in which the key was deprecated. + * @param translation A translation function for converting old config values into new ones. + */ + private case class AlternateConfig( + key: String, + version: String, + translation: Option[String => String] = None) - } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 64847486efa9c..fdc65f9a2f00f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -31,6 +31,7 @@ import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -50,9 +51,10 @@ import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -192,8 +194,43 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - private[spark] val conf = config.clone() - conf.validateSettings() + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. | + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _applicationAttemptId: String = _ + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be @@ -201,65 +238,24 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } - - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") - - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[URI] = { - if (isEventLogEnabled) { - val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) - .stripSuffix("/") - Some(Utils.resolveURI(unresolvedDir)) - } else { - None - } - } - private[spark] val eventLogCodec: Option[String] = { - val compress = conf.getBoolean("spark.eventLog.compress", false) - if (compress && isEventLogEnabled) { - Some(CompressionCodec.getCodecName(conf)).map(CompressionCodec.getShortName) - } else { - None - } - } + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = _eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val isLocal = (master == "local" || master.startsWith("local[")) - - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - - // Create the Spark execution environment (cache, map output tracker, etc) - // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( conf: SparkConf, @@ -268,8 +264,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.createDriverEnv(conf, isLocal, listenerBus) } - private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -277,35 +272,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) - + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + def statusTracker: SparkStatusTracker = _statusTracker - val statusTracker = new SparkStatusTracker(this) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } - - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } - - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + private[spark] def ui: Option[SparkUI] = _ui /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. @@ -313,135 +287,251 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + def hadoopConfiguration: Configuration = _hadoopConfiguration + + private[spark] def executorMemory: Int = _executorMemory + + // Environment variables to pass to our executors. + private[spark] val executorEnvs = HashMap[String, String]() + + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb } - if (files != null) { - files.foreach(addFile) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts } + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds + } + + def applicationId: String = _applicationId + def applicationAttemptId: String = _applicationAttemptId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } + + /* ------------------------------------------------------------------------------------- * + | Initialization. This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) + try { + _conf = config.clone() + _conf.validateSettings() - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v - } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. - executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } - // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Utils.getCurrentUserName() - executorEnvs("SPARK_USER") = sparkUser + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") - // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will - // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) - private val heartbeatReceiver = env.rpcEnv.setupEndpoint( - HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) + _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten - heartbeatReceiver.send(TaskSchedulerIsSet) + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } - @volatile private[spark] var dagScheduler: DAGScheduler = _ - try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val applicationId: String = taskScheduler.applicationId() - val applicationAttemptId: String = taskScheduler.applicationAttemptId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - val metricsSystem = env.metricsSystem + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = new EventLoggingListener( - applicationId, applicationAttemptId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. - private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager,appName)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + if (files != null) { + files.foreach(addFile) } - } - cleaner.foreach(_.start()) - setupAndStartListenerBus() - postEnvironmentUpdate() - postApplicationStart() + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(512) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value + } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.send(TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _applicationAttemptId = taskScheduler.applicationAttemptId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, + _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - private[spark] var checkpointDir: Option[String] = None + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + _executorAllocationManager = + if (dynamicAllocationEnabled) { + assert(supportDynamicAllocation, + "Dynamic allocation of executors is currently only supported in YARN mode") + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -545,19 +635,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. @@ -1147,7 +1224,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * this application is supported. This is currently only available for YARN. */ private[spark] def supportDynamicAllocation = - master.contains("yarn") || dynamicAllocationTesting + master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) /** * :: DeveloperApi :: @@ -1164,7 +1241,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This is currently only supported in YARN mode. Return whether the request is received. */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1404,28 +1481,40 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def stop() { // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. - if (!stopped.compareAndSet(false, true)) { logInfo("SparkContext already stopped.") return } postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - executorAllocationManager.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.rpcEnv.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null + _ui.foreach(_.stop()) + if (env != null) { + env.metricsSystem.report() + } + if (metadataCleaner != null) { + metadataCleaner.cancel() + } + _cleaner.foreach(_.stop()) + _executorAllocationManager.foreach(_.stop()) + if (_dagScheduler != null) { + _dagScheduler.stop() + _dagScheduler = null + } + if (_listenerBusStarted) { + listenerBus.stop() + _listenerBusStarted = false + } + _eventLogger.foreach(_.stop()) + if (env != null && _heartbeatReceiver != null) { + env.rpcEnv.stop(_heartbeatReceiver) + } + _progressBar.foreach(_.stop()) + _taskScheduler = null // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) + if (_env != null) { + _env.stop() + SparkEnv.set(null) + } SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } @@ -1750,6 +1839,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } listenerBus.start(this) + _listenerBusStarted = true } /** Post the application start event */ @@ -2153,7 +2243,7 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) @@ -2165,7 +2255,7 @@ object SparkContext extends Logging { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -2175,7 +2265,7 @@ object SparkContext extends Logging { // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b1ffba4c546bf..7409dc2d866f6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -604,7 +604,7 @@ private[spark] object PythonRDD extends Logging { * The thread will terminate after all the data are sent or any exceptions happen. */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index b7ae9c1fc0a23..ae99432f5ce86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -22,12 +22,13 @@ import java.net.URI private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], - val memoryPerSlave: Int, + val memoryPerExecutorMB: Int, val command: Command, var appUiUrl: String, val eventLogDir: Option[URI] = None, // short name of compression codec used when writing event logs, if any (e.g. lzf) - val eventLogCodec: Option[String] = None) + val eventLogCodec: Option[String] = None, + val coresPerExecutor: Option[Int] = None) extends Serializable { val user = System.getProperty("user.name", "") @@ -35,13 +36,13 @@ private[spark] class ApplicationDescription( def copy( name: String = name, maxCores: Option[Int] = maxCores, - memoryPerSlave: Int = memoryPerSlave, + memoryPerExecutorMB: Int = memoryPerExecutorMB, command: Command = command, appUiUrl: String = appUiUrl, eventLogDir: Option[URI] = eventLogDir, eventLogCodec: Option[String] = eventLogCodec): ApplicationDescription = new ApplicationDescription( - name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir, eventLogCodec) + name, maxCores, memoryPerExecutorMB, command, appUiUrl, eventLogDir, eventLogCodec) override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index dfc5b97e6a6c8..2954f932b4f41 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -46,7 +46,7 @@ private[deploy] object JsonProtocol { ("name" -> obj.desc.name) ~ ("cores" -> obj.desc.maxCores) ~ ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ + ("memoryperslave" -> obj.desc.memoryPerExecutorMB) ~ ("submitdate" -> obj.submitDate.toString) ~ ("state" -> obj.state.toString) ~ ("duration" -> obj.duration) @@ -55,7 +55,7 @@ private[deploy] object JsonProtocol { def writeApplicationDescription(obj: ApplicationDescription): JObject = { ("name" -> obj.name) ~ ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ + ("memoryperslave" -> obj.memoryPerExecutorMB) ~ ("user" -> obj.user) ~ ("command" -> obj.command.toString) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 60bc243ebf40a..296a0764b8baf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -406,6 +406,8 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), // Other options + OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, + sysProp = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 03ecf3fd99ec5..faa8780288ea3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -482,10 +482,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. | + | Spark standalone and YARN only: + | --executor-cores NUM Number of cores per executor. (Default: 1 in YARN mode, + | or all available cores on the worker in standalone mode) + | | YARN-only: | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). - | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | --archives ARCHIVES Comma separated list of archives to be extracted into the diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 783e144a8c7b4..ad4f6f3032c74 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -49,11 +49,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val NOT_STARTED = "" // Interval between each check for event log updates - private val UPDATE_INTERVAL_MS = conf.getOption("spark.history.fs.update.interval.seconds") - .orElse(conf.getOption("spark.history.fs.updateInterval")) - .orElse(conf.getOption("spark.history.updateInterval")) - .map(_.toInt) - .getOrElse(10) * 1000 + private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds", @@ -130,8 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_MS, - TimeUnit.MILLISECONDS) + pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0a41108a36b92..f60babaf9af57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -103,6 +103,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") ++ appTable + } else if (requestedIncomplete) { +

No incomplete applications found!

} else {

No completed applications found!

++

Did you specify the correct logging directory? diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index bc5b293379f2b..f59d550d4f3b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -75,9 +75,11 @@ private[deploy] class ApplicationInfo( } } - private[master] def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): - ExecutorDesc = { - val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + private[master] def addExecutor( + worker: WorkerInfo, + cores: Int, + useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerExecutorMB) executors(exec.id) = exec coresGranted += cores exec diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4994004180942..80098cb9ff43f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -524,52 +524,28 @@ private[master] class Master( } /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - private def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. This method will be called - * every time a new app joins or resource availability changes. + * Schedule executors to be launched on the workers. + * + * There are two modes of launching executors. The first attempts to spread out an application's + * executors on as many workers as possible, while the second does the opposite (i.e. launch them + * on as few workers as possible). The former is usually better for data locality purposes and is + * the default. + * + * The number of cores assigned to each executor is configurable. When this is explicitly set, + * multiple executors from the same application may be launched on the same worker if the worker + * has enough cores and memory. Otherwise, each executor grabs all the cores available on the + * worker by default, in which case only one executor may be launched on each worker. */ - private def schedule() { - if (state != RecoveryState.ALIVE) { return } - - // First schedule drivers, they take strict precedence over applications - // Randomization helps balance drivers - val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE)) - val numWorkersAlive = shuffledAliveWorkers.size - var curPos = 0 - - for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers - // We assign workers to each waiting driver in a round-robin fashion. For each driver, we - // start from the last worker that was assigned a driver, and continue onwards until we have - // explored all alive workers. - var launched = false - var numWorkersVisited = 0 - while (numWorkersVisited < numWorkersAlive && !launched) { - val worker = shuffledAliveWorkers(curPos) - numWorkersVisited += 1 - if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { - launchDriver(worker, driver) - waitingDrivers -= driver - launched = true - } - curPos = (curPos + 1) % numWorkersAlive - } - } - + private def startExecutorsOnWorkers(): Unit = { // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores + // Try to spread out each app among all the workers, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -582,32 +558,61 @@ private[master] class Master( pos = (pos + 1) % numUsable } // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) - app.state = ApplicationState.RUNNING - } + for (pos <- 0 until numUsable if assigned(pos) > 0) { + allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } } else { - // Pack each app into as few nodes as possible until we've assigned all its cores + // Pack each app into as few workers as possible until we've assigned all its cores for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) - app.state = ApplicationState.RUNNING - } - } + allocateWorkerResourceToExecutors(app, app.coresLeft, worker) + } + } + } + } + + /** + * Allocate a worker's resources to one or more executors. + * @param app the info of the application which the executors belong to + * @param coresToAllocate cores on this worker to be allocated to this application + * @param worker the worker info + */ + private def allocateWorkerResourceToExecutors( + app: ApplicationInfo, + coresToAllocate: Int, + worker: WorkerInfo): Unit = { + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) + var coresLeft = coresToAllocate + while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { + val exec = app.addExecutor(worker, coresPerExecutor) + coresLeft -= coresPerExecutor + launchExecutor(worker, exec) + app.state = ApplicationState.RUNNING + } + } + + /** + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. + */ + private def schedule(): Unit = { + if (state != RecoveryState.ALIVE) { return } + // Drivers take strict precedence over executors + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver } } } + startExecutorsOnWorkers() } - private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { + private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(masterUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 761aa8f7b1ef6..273f077bd8f57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -94,7 +94,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

  • Executor Memory: - {Utils.megabytesToString(app.desc.memoryPerSlave)} + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
  • Submit Date: {app.submitDate}
  • State: {app.state}
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 45412a35e9a7d..1f2c3fdbfb2bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -190,12 +190,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { - val killLinkUri = s"app/kill?id=${app.id}&terminate=true" - val confirm = "return window.confirm(" + - s"'Are you sure you want to kill application ${app.id} ?');" - - (kill) - + val confirm = + s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
    + + + (kill) +
    } @@ -208,8 +210,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.coresGranted} - - {Utils.megabytesToString(app.desc.memoryPerSlave)} + + {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} {UIUtils.formatDate(app.submitDate)} {app.desc.user} @@ -223,12 +225,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { (driver.state == DriverState.RUNNING || driver.state == DriverState.SUBMITTED || driver.state == DriverState.RELAUNCHING)) { - val killLinkUri = s"driver/kill?id=${driver.id}&terminate=true" - val confirm = "return window.confirm(" + - s"'Are you sure you want to kill driver ${driver.id} ?');" - - (kill) - + val confirm = + s"if (window.confirm('Are you sure you want to kill driver ${driver.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
    + + + (kill) +
    } {driver.id} {killLink} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 1b670418ab1ff..bb11e0642ddc6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,10 +43,10 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler( - createRedirectHandler("/app/kill", "/", masterPage.handleAppKillRequest)) - attachHandler( - createRedirectHandler("/driver/kill", "/", masterPage.handleDriverKillRequest)) + attachHandler(createRedirectHandler( + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + attachHandler(createRedirectHandler( + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index e0948e16ef354..ef7a703bffe67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -24,14 +24,14 @@ import scala.collection.JavaConversions._ import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.hadoop.fs.{FileUtil, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.util.{Utils, Clock, SystemClock} /** * Manages the execution of one driver, including automatically restarting the driver on failure. @@ -44,7 +44,8 @@ private[deploy] class DriverRunner( val sparkHome: File, val driverDesc: DriverDescription, val worker: ActorRef, - val workerUrl: String) + val workerUrl: String, + val securityManager: SecurityManager) extends Logging { @volatile private var process: Option[Process] = None @@ -136,12 +137,9 @@ private[deploy] class DriverRunner( * Will throw an exception if there are errors downloading the jar. */ private def downloadUserJar(driverDir: File): String = { - val jarPath = new Path(driverDesc.jarUrl) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - val jarFileSystem = jarPath.getFileSystem(hadoopConf) - val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) val jarFileName = jarPath.getName val localJarFile = new File(driverDir, jarFileName) @@ -149,7 +147,14 @@ private[deploy] class DriverRunner( if (!localJarFile.exists()) { // May already exist if running multiple workers on one node logInfo(s"Copying user jar $jarPath to $destPath") - FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf) + Utils.fetchFile( + driverDesc.jarUrl, + driverDir, + conf, + securityManager, + hadoopConf, + System.currentTimeMillis(), + useCache = false) } if (!localJarFile.exists()) { // Verify copy succeeded diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c4c24a7866aa3..3ee2eb69e8a4e 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -436,7 +436,8 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl) + akkaUrl, + securityMgr) drivers(driverId) = driver driver.start() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 516f619529c48..327d155b38c22 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -60,8 +60,6 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -91,10 +89,7 @@ private[spark] class Executor( ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars - private val userClassPathFirst: Boolean = { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) - } + private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -114,6 +109,10 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("driver-heartbeater")) + startDriverHeartbeater() def launchTask( @@ -138,7 +137,8 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() env.rpcEnv.stop(executorEndpoint) - isStopped = true + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() @@ -432,23 +432,17 @@ private[spark] class Executor( } /** - * Starts a thread to report heartbeat and partial metrics for active tasks to driver. - * This thread stops running when the executor is stopped. + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. */ private def startDriverHeartbeater(): Unit = { val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") - val thread = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(intervalMs + (math.random * intervalMs).asInstanceOf[Int]) - while (!isStopped) { - reportHeartBeat() - Thread.sleep(intervalMs) - } - } + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) } - thread.setDaemon(true) - thread.setName("driver-heartbeater") - thread.start() + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8e3c30fc3d781..5a74c13b38bf7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -86,11 +86,11 @@ private[nio] class ConnectionManager( conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -989,6 +989,7 @@ private[nio] class ConnectionManager( def stop() { ackTimeoutMonitor.stop() + selector.wakeup() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 29ca3e9c4bd04..843a893235e56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.StatCounter class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Add up the elements in this RDD. */ def sum(): Double = { - self.reduce(_ + _) + self.fold(0.0)(_ + _) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 6afd63d537d75..1722c27e55003 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException} +import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} /** @@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get val fs = path.getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { throw new SparkException("Failed to create checkpoint path " + path) @@ -92,8 +92,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( new SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) + if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { + rdd.context.cleaner.foreach { cleaner => + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + } + } + rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + @@ -130,5 +135,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -// Used for synchronization -private[spark] object RDDCheckpointData +private[spark] object RDDCheckpointData { + def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = { + sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) } + } + + def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = { + rddCheckpointDataPath(sc, rddId).foreach { path => + val fs = path.getFileSystem(sc.hadoopConfiguration) + if (fs.exists(path)) { + fs.delete(path, true) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e259867c14040..f2c1c86af767e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -284,7 +284,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds /** * return the address for the [[RpcEndpointRef]] @@ -304,7 +304,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + def sendWithReply[T: ClassTag](message: Any): Future[T] = + sendWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to @@ -327,7 +328,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 508fe7b3303ca..4a32f8936fb0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -818,12 +818,7 @@ class DAGScheduler( } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } + val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6dd51596cbe47..b6a84d3542a5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, SPECULATION_INTERVAL_MS milliseconds) { Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - } + }(sc.env.actorSystem.dispatcher) } } @@ -394,7 +393,7 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { @@ -407,8 +406,7 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7eb3fdc19b5b8..ccf1dc5af6120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -82,12 +82,11 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec) - + val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) + val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, + command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() - waitForRegistration() } @@ -119,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. + sc.stop() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 70a477a6895cc..50ba0b9d5a612 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -20,12 +20,12 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer import java.util.concurrent.{Executors, TimeUnit} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.util.Utils private case class ReviveOffers() @@ -71,11 +71,15 @@ private[spark] class LocalEndpoint( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } + def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) val tasks = scheduler.resourceOffers(offers).flatten @@ -104,8 +108,11 @@ private[spark] class LocalEndpoint( * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis var localEndpoint: RpcEndpointRef = null @@ -116,7 +123,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { - localEndpoint.send(StopExecutor) + localEndpoint.sendWithReply(StopExecutor) } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 95f254a9ef22a..a091ca650c60c 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -114,10 +114,23 @@ private[spark] object JettyUtils extends Logging { srcPath: String, destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), - basePath: String = ""): ServletContextHandler = { + basePath: String = "", + httpMethod: String = "GET"): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { - override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "GET" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "POST" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index adfa6bbada256..580ab8b1325f8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,8 +55,8 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + attachHandler(createRedirectHandler( + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5865850fa09b5..cb72890a0fd20 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -73,20 +73,21 @@ private[ui] class StageTableBase( } private def makeDescription(s: StageInfo): Seq[Node] = { - // scalastyle:off + val basePathUri = UIUtils.prependBaseUri(basePath) + val killLink = if (killEnabled) { - val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(basePath), s.stageId) - val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" - .format(s.stageId) - - (kill) - + val killLinkUri = s"$basePathUri/stages/stage/kill/" + val confirm = + s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
    + + + (kill) +
    } - // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) + val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -98,11 +99,9 @@ private[ui] class StageTableBase( diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0c..81a7cbde01ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + override def isDefinedAt(o: Any): Boolean = { + val handled = _receiveWithLogging.isDefinedAt(o) + if (!handled) { + log.debug(s"Received unexpected actor system event: $o") + } + handled + } override def apply(o: Any): Unit = { if (log.isDebugEnabled) { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index d60b8b9a31a9b..a725767d08cc2 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,9 +19,12 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList +import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging +import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -64,4 +67,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ def onPostEvent(listener: L, event: E): Unit + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { + val c = implicitly[ClassTag[T]].runtimeClass + listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + } + } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 1de169d964d23..097e7076e5391 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{RDDCheckpointData, RDD} import org.apache.spark.storage._ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager @@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { postGCTester.assertCleanup() } + test("automatically cleanup checkpoint") { + val checkpointDir = java.io.File.createTempFile("temp", "") + checkpointDir.deleteOnExit() + checkpointDir.delete() + var rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + var rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined) + val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get + val fs = path.getFileSystem(sc.hadoopConfiguration) + assert(fs.exists(path)) + + // the checkpoint is not cleaned by default (without the configuration set) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + sc.stop() + val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint"). + set("spark.cleaner.referenceTracking.cleanCheckpoints", "true") + sc = new SparkContext(conf) + rdd = newPairRDD + sc.setCheckpointDir(checkpointDir.toString) + rdd.checkpoint() + rdd.cache() + rdd.collect() + rddId = rdd.id + + // Confirm the checkpoint directory exists + assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + + // Test that GC causes checkpoint data cleanup after dereferencing the RDD + postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) + } + test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 6b3049b28cd5e..22acc270b983e 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -56,19 +56,13 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf2) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index e08210ae60d17..7d87ba5fd2610 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -197,6 +197,28 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro serializer.newInstance().serialize(new StringBuffer()) } + test("deprecated configs") { + val conf = new SparkConf() + val newName = "spark.history.fs.update.interval" + + assert(!conf.contains(newName)) + + conf.set("spark.history.updateInterval", "1") + assert(conf.get(newName) === "1") + + conf.set("spark.history.fs.updateInterval", "2") + assert(conf.get(newName) === "2") + + conf.set("spark.history.fs.update.interval.seconds", "3") + assert(conf.get(newName) === "3") + + conf.set(newName, "4") + assert(conf.get(newName) === "4") + + val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size + assert(count === 4) + } + } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 2071701b313db..b58d62567afe1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} class JsonProtocolSuite extends FunSuite { @@ -124,8 +124,9 @@ class JsonProtocolSuite extends FunSuite { } def createDriverRunner(): DriverRunner = { - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 9cdb42814ca32..c93d16f8a1586 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.net.URL +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source @@ -65,16 +66,17 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { new MySparkConf().setAll(getAll) } } - val conf = new MySparkConf() + val conf = new MySparkConf().set( + "spark.extraListeners", classOf[SaveExecutorInfo].getName) sc = new SparkContext("local-cluster[2,1,512]", "test", conf) - val listener = new SaveExecutorInfo - sc.addSparkListener(listener) - // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) listener.addedExecutorInfos.values.foreach { info => assert(info.logUrlMap.nonEmpty) info.logUrlMap.values.foreach { logUrl => @@ -82,12 +84,12 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext { } } } +} - private class SaveExecutorInfo extends SparkListener { - val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() +private[spark] class SaveExecutorInfo extends SparkListener { + val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() - override def onExecutorAdded(executor: SparkListenerExecutorAdded) { - addedExecutorInfos(executor.executorId) = executor.executorInfo - } + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfos(executor.executorId) = executor.executorInfo } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index aa6e4874cecde..2159fd8c16c6f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -25,7 +25,7 @@ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.util.Clock @@ -33,8 +33,9 @@ class DriverRunnerTest extends FunSuite { private def createDriverRunner() = { val command = new Command("mainClass", Seq(), Map(), Seq(), Seq(), Seq()) val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) - new DriverRunner(new SparkConf(), "driverId", new File("workDir"), new File("sparkHome"), - driverDescription, null, "akka://1.2.3.4/worker/") + val conf = new SparkConf() + new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), + driverDescription, null, "akka://1.2.3.4/worker/", new SecurityManager(conf)) } private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 97079382c716f..01039b9449daf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -22,6 +22,12 @@ import org.scalatest.FunSuite import org.apache.spark._ class DoubleRDDSuite extends FunSuite with SharedSparkContext { + test("sum") { + assert(sc.parallelize(Seq.empty[Double]).sum() === 0.0) + assert(sc.parallelize(Seq(1.0)).sum() === 1.0) + assert(sc.parallelize(Seq(1.0, 2.0)).sum() === 3.0) + } + // Verify tests on the histogram functionality. We test with both evenly // and non-evenly spaced buckets as the bucket lookup function changes. test("WorksOnEmpty") { diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 1cb594633f331..eb9db550fd74c 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.net.{HttpURLConnection, URL} import javax.servlet.http.HttpServletRequest import scala.collection.JavaConversions._ @@ -56,12 +57,13 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before * Create a test SparkContext with the SparkUI enabled. * It is safe to `get` the SparkUI directly from the SparkContext returned here. */ - private def newSparkContext(): SparkContext = { + private def newSparkContext(killEnabled: Boolean = true): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") + .set("spark.ui.killEnabled", killEnabled.toString) val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -128,21 +130,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } test("spark.ui.killEnabled should properly control kill button display") { - def getSparkContext(killEnabled: Boolean): SparkContext = { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - .set("spark.ui.killEnabled", killEnabled.toString) - new SparkContext(conf) - } - def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } - withSpark(getSparkContext(killEnabled = true)) { sc => + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -150,7 +143,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } - withSpark(getSparkContext(killEnabled = false)) { sc => + withSpark(newSparkContext(killEnabled = false)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -233,7 +226,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before // because someone could change the error message and cause this test to pass by accident. // Instead, it's safer to check that each row contains a link to a stage details page. findAll(cssSelector("tbody tr")).foreach { row => - val link = row.underlying.findElement(By.xpath(".//a")) + val link = row.underlying.findElement(By.xpath("./td/div/a")) link.getAttribute("href") should include ("stage") } } @@ -356,4 +349,25 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } } + + test("kill stage is POST only") { + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + connection.connect() + val code = connection.getResponseCode() + connection.disconnect() + code + } + + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") + getResponseCode(url, "GET") should be (405) + getResponseCode(url, "POST") should be (200) + } + } + } } diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index d888de929fdda..cc86ef45858c9 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -36,8 +36,10 @@ object SparkSqlExample { val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ import sqlContext._ - val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) + + val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)).toDF() people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() diff --git a/dev/run-tests b/dev/run-tests index bb21ab6c9aa04..861d1671182c2 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -235,6 +235,8 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS +# add path for python 3 in jenkins +export PATH="${PATH}:/home/anaonda/envs/py3k/bin" ./python/run-tests echo "" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f6372835a6dbf..030f2cdddb350 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout +TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. @@ -161,6 +161,10 @@ pr_message="" # Ensure we save off the current HEAD to revert to current_pr_head="`git rev-parse HEAD`" +echo "HEAD: `git rev-parse HEAD`" +echo "GHPRB: $ghprbActualCommit" +echo "SHA1: $sha1" + # Run pull request tests for t in "${PR_TESTS[@]}"; do this_test="${FWDIR}/dev/tests/${t}.sh" diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh index 370c7cc737bbd..fdfb3c62aff58 100755 --- a/dev/tests/pr_new_dependencies.sh +++ b/dev/tests/pr_new_dependencies.sh @@ -39,12 +39,12 @@ CURR_CP_FILE="my-classpath.txt" MASTER_CP_FILE="master-classpath.txt" # First switch over to the master branch -git checkout master &>/dev/null +git checkout -f master # Find and copy all pom.xml files into a *.gate file that we can check # against through various `git` changes find -name "pom.xml" -exec cp {} {}.gate \; # Switch back to the current PR -git checkout "${current_pr_head}" &>/dev/null +git checkout -f "${current_pr_head}" # Check if any *.pom files from the current branch are different from the master difference_q="" @@ -71,7 +71,7 @@ else sort > ${CURR_CP_FILE} # Checkout the master branch to compare against - git checkout master &>/dev/null + git checkout -f master ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ sed -n -e '/Building Spark Project Assembly/,$p' | \ @@ -84,7 +84,7 @@ else rev | \ sort > ${MASTER_CP_FILE} - DIFF_RESULTS="`diff my-classpath.txt master-classpath.txt`" + DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" if [ -z "${DIFF_RESULTS}" ]; then echo " * This patch does not change any dependencies." diff --git a/docs/configuration.md b/docs/configuration.md index 7169ec295ef7f..d9e9e67026cbb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -723,6 +723,17 @@ Apart from these, the following properties are also available, and may be useful this duration will be cleared as well. + + spark.executor.cores + 1 in YARN mode, all the available cores on the worker in standalone mode. + + The number of cores to use on each executor. For YARN and standalone mode only. + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one + executor per application will run on each worker. + + spark.default.parallelism diff --git a/docs/monitoring.md b/docs/monitoring.md index 6816671ffbf46..2a130224591ca 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -86,10 +86,10 @@ follows: - spark.history.fs.update.interval.seconds - 10 + spark.history.fs.update.interval + 10s - The period, in seconds, at which information displayed by this history server is updated. + The period at which information displayed by this history server is updated. Each update checks for any changes made to the event logs in persisted storage. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 332618edf0c55..03500867df70f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1371,7 +1371,10 @@ the Data Sources API. The following options are supported: These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0c1f24761d0de..87c0818279713 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,7 +19,7 @@ # limitations under the License. # -from __future__ import with_statement +from __future__ import with_statement, print_function import hashlib import itertools @@ -37,12 +37,17 @@ import tempfile import textwrap import time -import urllib2 import warnings from datetime import datetime from optparse import OptionParser from sys import stderr +if sys.version < "3": + from urllib2 import urlopen, Request, HTTPError +else: + from urllib.request import urlopen, Request + from urllib.error import HTTPError + SPARK_EC2_VERSION = "1.2.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -88,10 +93,10 @@ def setup_external_libs(libs): SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") if not os.path.exists(SPARK_EC2_LIB_DIR): - print "Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( + print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( path=SPARK_EC2_LIB_DIR - ) - print "This should be a one-time operation." + )) + print("This should be a one-time operation.") os.mkdir(SPARK_EC2_LIB_DIR) for lib in libs: @@ -100,8 +105,8 @@ def setup_external_libs(libs): if not os.path.isdir(lib_dir): tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print " - Downloading {lib}...".format(lib=lib["name"]) - download_stream = urllib2.urlopen( + print(" - Downloading {lib}...".format(lib=lib["name"])) + download_stream = urlopen( "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( prefix=PYPI_URL_PREFIX, first_letter=lib["name"][:1], @@ -113,13 +118,13 @@ def setup_external_libs(libs): tgz_file.write(download_stream.read()) with open(tgz_file_path) as tar: if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print >> stderr, "ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]) + print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) sys.exit(1) tar = tarfile.open(tgz_file_path) tar.extractall(path=SPARK_EC2_LIB_DIR) tar.close() os.remove(tgz_file_path) - print " - Finished downloading {lib}.".format(lib=lib["name"]) + print(" - Finished downloading {lib}.".format(lib=lib["name"])) sys.path.insert(1, lib_dir) @@ -299,12 +304,12 @@ def parse_args(): if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) sys.exit(1) if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) sys.exit(1) return (opts, action, cluster_name) @@ -316,7 +321,7 @@ def get_or_make_group(conn, name, vpc_id): if len(group) > 0: return group[0] else: - print "Creating security group " + name + print("Creating security group " + name) return conn.create_security_group(name, "Spark EC2 group", vpc_id) @@ -324,18 +329,19 @@ def get_validate_spark_version(version, repo): if "." in version: version = version.replace("v", "") if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) + print("Don't know about Spark version: {v}".format(v=version), file=stderr) sys.exit(1) return version else: github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = urllib2.Request(github_commit_url) + request = Request(github_commit_url) request.get_method = lambda: 'HEAD' try: - response = urllib2.urlopen(request) - except urllib2.HTTPError, e: - print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url) - print >> stderr, "Received HTTP response code of {code}.".format(code=e.code) + response = urlopen(request) + except HTTPError as e: + print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), + file=stderr) + print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) sys.exit(1) return version @@ -394,8 +400,7 @@ def get_spark_ami(opts): instance_type = EC2_INSTANCE_TYPES[opts.instance_type] else: instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % opts.instance_type + print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) # URL prefix from which to fetch AMI information ami_prefix = "{r}/{b}/ami-list".format( @@ -404,10 +409,10 @@ def get_spark_ami(opts): ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI: " + ami + ami = urlopen(ami_path).read().strip() + print("Spark AMI: " + ami) except: - print >> stderr, "Could not resolve AMI at: " + ami_path + print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) return ami @@ -419,11 +424,11 @@ def get_spark_ami(opts): # Fails if there already instances running in the cluster's groups. def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." + print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) sys.exit(1) if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." + print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) sys.exit(1) user_data_content = None @@ -431,7 +436,7 @@ def launch_cluster(conn, opts, cluster_name): with open(opts.user_data) as user_data_file: user_data_content = user_data_file.read() - print "Setting up security groups..." + print("Setting up security groups...") master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) authorized_address = opts.authorized_address @@ -497,8 +502,8 @@ def launch_cluster(conn, opts, cluster_name): existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print("ERROR: There are already instances running in group %s or %s" % + (master_group.name, slave_group.name), file=stderr) sys.exit(1) # Figure out Spark AMI @@ -511,12 +516,12 @@ def launch_cluster(conn, opts, cluster_name): additional_group_ids = [sg.id for sg in conn.get_all_security_groups() if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." + print("Launching instances...") try: image = conn.get_all_images(image_ids=[opts.ami])[0] except: - print >> stderr, "Could not find AMI " + opts.ami + print("Could not find AMI " + opts.ami, file=stderr) sys.exit(1) # Create block device mapping so that we can add EBS volumes if asked to. @@ -542,8 +547,8 @@ def launch_cluster(conn, opts, cluster_name): # Launch slaves if opts.spot_price is not None: # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) + print("Requesting %d slaves as spot instances with price $%.3f" % + (opts.slaves, opts.spot_price)) zones = get_zones(conn, opts) num_zones = len(zones) i = 0 @@ -566,7 +571,7 @@ def launch_cluster(conn, opts, cluster_name): my_req_ids += [req.id for req in slave_reqs] i += 1 - print "Waiting for spot instances to be granted..." + print("Waiting for spot instances to be granted...") try: while True: time.sleep(10) @@ -579,24 +584,24 @@ def launch_cluster(conn, opts, cluster_name): if i in id_to_req and id_to_req[i].state == "active": active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves + print("All %d slaves granted" % opts.slaves) reservations = conn.get_all_reservations(active_instance_ids) slave_nodes = [] for r in reservations: slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print("%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves)) except: - print "Canceling spot instance requests" + print("Canceling spot instance requests") conn.cancel_spot_instance_requests(my_req_ids) # Log a warning if any of these requests actually launched instances: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) running = len(master_nodes) + len(slave_nodes) if running: - print >> stderr, ("WARNING: %d instances are still running" % running) + print(("WARNING: %d instances are still running" % running), file=stderr) sys.exit(0) else: # Launch non-spot instances @@ -618,16 +623,16 @@ def launch_cluster(conn, opts, cluster_name): placement_group=opts.placement_group, user_data=user_data_content) slave_nodes += slave_res.instances - print "Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id) + print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( + s=num_slaves_this_zone, + plural_s=('' if num_slaves_this_zone == 1 else 's'), + z=zone, + r=slave_res.id)) i += 1 # Launch or resume masters if existing_masters: - print "Starting master..." + print("Starting master...") for inst in existing_masters: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -650,10 +655,10 @@ def launch_cluster(conn, opts, cluster_name): user_data=user_data_content) master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + print("Launched master in %s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 - print "Waiting for AWS to propagate instance metadata..." + print("Waiting for AWS to propagate instance metadata...") time.sleep(5) # Give the instances descriptive names for master in master_nodes: @@ -674,8 +679,8 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): Get the EC2 instances in an existing cluster if available. Returns a tuple of lists of EC2 instance objects for the masters and slaves. """ - print "Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region) + print("Searching for existing cluster {c} in region {r}...".format( + c=cluster_name, r=opts.region)) def get_instances(group_names): """ @@ -693,16 +698,15 @@ def get_instances(group_names): slave_instances = get_instances([cluster_name + "-slaves"]) if any((master_instances, slave_instances)): - print "Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's')) + print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( + m=len(master_instances), + plural_m=('' if len(master_instances) == 1 else 's'), + s=len(slave_instances), + plural_s=('' if len(slave_instances) == 1 else 's'))) if not master_instances and die_on_error: - print >> sys.stderr, \ - "ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region) + print("ERROR: Could not find a master for cluster {c} in region {r}.".format( + c=cluster_name, r=opts.region), file=sys.stderr) sys.exit(1) return (master_instances, slave_instances) @@ -713,7 +717,7 @@ def get_instances(group_names): def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: - print "Generating cluster's SSH key on master..." + print("Generating cluster's SSH key on master...") key_setup = """ [ -f ~/.ssh/id_rsa ] || (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && @@ -721,10 +725,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): """ ssh(master, opts, key_setup) dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." + print("Transferring cluster's SSH key to slaves...") for slave in slave_nodes: slave_address = get_dns_name(slave, opts.private_ips) - print slave_address + print(slave_address) ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', @@ -738,8 +742,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) + print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) ssh( host=master, opts=opts, @@ -749,7 +753,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): b=opts.spark_ec2_git_branch) ) - print "Deploying files to master..." + print("Deploying files to master...") deploy_files( conn=conn, root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", @@ -760,25 +764,25 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ) if opts.deploy_root_dir is not None: - print "Deploying {s} to master...".format(s=opts.deploy_root_dir) + print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) deploy_user_files( root_dir=opts.deploy_root_dir, opts=opts, master_nodes=master_nodes ) - print "Running setup on master..." + print("Running setup on master...") setup_spark_cluster(master, opts) - print "Done!" + print("Done!") def setup_spark_cluster(master, opts): ssh(master, opts, "chmod u+x spark-ec2/setup.sh") ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master + print("Spark standalone cluster started at http://%s:8080" % master) if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master + print("Ganglia started at http://%s:5080/ganglia" % master) def is_ssh_available(host, opts, print_ssh_output=True): @@ -795,7 +799,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): if s.returncode != 0 and print_ssh_output: # extra leading newline is for spacing in wait_for_cluster_state() - print textwrap.dedent("""\n + print(textwrap.dedent("""\n Warning: SSH connection error. (This could be temporary.) Host: {h} SSH return code: {r} @@ -804,7 +808,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): h=host, r=s.returncode, o=cmd_output.strip() - ) + )) return s.returncode == 0 @@ -865,10 +869,10 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): sys.stdout.write("\n") end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + print("Cluster is now in '{s}' state. Waited {t} seconds.".format( s=cluster_state, t=(end_time - start_time).seconds - ) + )) # Get number of local disks available for a given EC2 instance type. @@ -916,8 +920,8 @@ def get_num_disks(instance_type): if instance_type in disks_by_instance: return disks_by_instance[instance_type] else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) + print("WARNING: Don't know number of disks on instance type %s; assuming 1" + % instance_type, file=stderr) return 1 @@ -951,7 +955,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): # Spark-only custom deploy spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) tachyon_v = "" - print "Deploying Spark via git hash; Tachyon won't be set up" + print("Deploying Spark via git hash; Tachyon won't be set up") modules = filter(lambda x: x != "tachyon", modules) master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] @@ -1067,8 +1071,8 @@ def ssh(host, opts, command): "--key-pair parameters and try again.".format(host)) else: raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) + print("Error executing remote command, retrying after 30 seconds: {0}".format(e), + file=stderr) time.sleep(30) tries = tries + 1 @@ -1107,8 +1111,8 @@ def ssh_write(host, opts, command, arguments): elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) + print("Error {0} while executing remote command, retrying after 30 seconds". + format(status), file=stderr) time.sleep(30) tries = tries + 1 @@ -1162,42 +1166,41 @@ def real_main(): if opts.identity_file is not None: if not os.path.exists(opts.identity_file): - print >> stderr,\ - "ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file) + print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), + file=stderr) sys.exit(1) file_mode = os.stat(opts.identity_file).st_mode if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print >> stderr, "ERROR: The identity file must be accessible only by you." - print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file) + print("ERROR: The identity file must be accessible only by you.", file=stderr) + print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), + file=stderr) sys.exit(1) if opts.instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type) + print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type), file=stderr) if opts.master_instance_type != "": if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, \ - "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type) + print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type), file=stderr) # Since we try instance types even if we can't resolve them, we check if they resolve first # and, if they do, see if they resolve to the same virtualization type. if opts.instance_type in EC2_INSTANCE_TYPES and \ opts.master_instance_type in EC2_INSTANCE_TYPES: if EC2_INSTANCE_TYPES[opts.instance_type] != \ EC2_INSTANCE_TYPES[opts.master_instance_type]: - print >> stderr, \ - "Error: spark-ec2 currently does not support having a master and slaves " + \ - "with different AMI virtualization types." - print >> stderr, "master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]) - print >> stderr, "slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]) + print("Error: spark-ec2 currently does not support having a master and slaves " + "with different AMI virtualization types.", file=stderr) + print("master instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) + print("slave instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) sys.exit(1) if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" + print("ebs-vol-num cannot be greater than 8", file=stderr) sys.exit(1) # Prevent breaking ami_prefix (/, .git and startswith checks) @@ -1206,23 +1209,22 @@ def real_main(): opts.spark_ec2_git_repo.endswith(".git") or \ not opts.spark_ec2_git_repo.startswith("https://github.com") or \ not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ - "trailing / or .git. " \ - "Furthermore, we currently only support forks named spark-ec2." + print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " + "Furthermore, we currently only support forks named spark-ec2.", file=stderr) sys.exit(1) if not (opts.deploy_root_dir is None or (os.path.isabs(opts.deploy_root_dir) and os.path.isdir(opts.deploy_root_dir) and os.path.exists(opts.deploy_root_dir))): - print >> stderr, "--deploy-root-dir must be an absolute path to a directory that exists " \ - "on the local file system" + print("--deploy-root-dir must be an absolute path to a directory that exists " + "on the local file system", file=stderr) sys.exit(1) try: conn = ec2.connect_to_region(opts.region) except Exception as e: - print >> stderr, (e) + print((e), file=stderr) sys.exit(1) # Select an AZ at random if it was not specified. @@ -1231,7 +1233,7 @@ def real_main(): if action == "launch": if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" + print("ERROR: You have to start at least 1 slave", file=sys.stderr) sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -1250,18 +1252,18 @@ def real_main(): conn, opts, cluster_name, die_on_error=False) if any(master_nodes + slave_nodes): - print "The following instances will be terminated:" + print("The following instances will be terminated:") for inst in master_nodes + slave_nodes: - print "> %s" % get_dns_name(inst, opts.private_ips) - print "ALL DATA ON ALL NODES WILL BE LOST!!" + print("> %s" % get_dns_name(inst, opts.private_ips)) + print("ALL DATA ON ALL NODES WILL BE LOST!!") msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) response = raw_input(msg) if response == "y": - print "Terminating master..." + print("Terminating master...") for inst in master_nodes: inst.terminate() - print "Terminating slaves..." + print("Terminating slaves...") for inst in slave_nodes: inst.terminate() @@ -1274,16 +1276,16 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated' ) - print "Deleting security groups (this will take some time)..." + print("Deleting security groups (this will take some time)...") attempt = 1 while attempt <= 3: - print "Attempt %d" % attempt + print("Attempt %d" % attempt) groups = [g for g in conn.get_all_security_groups() if g.name in group_names] success = True # Delete individual rules in all groups before deleting groups to # remove dependencies between them for group in groups: - print "Deleting rules in security group " + group.name + print("Deleting rules in security group " + group.name) for rule in group.rules: for grant in rule.grants: success &= group.revoke(ip_protocol=rule.ip_protocol, @@ -1298,10 +1300,10 @@ def real_main(): try: # It is needed to use group_id to make it work with VPC conn.delete_security_group(group_id=group.id) - print "Deleted security group %s" % group.name + print("Deleted security group %s" % group.name) except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group %s" % group.name + print("Failed to delete security group %s" % group.name) # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails @@ -1311,17 +1313,16 @@ def real_main(): attempt += 1 if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." + print("Failed to delete all security groups after 3 tries.") + print("Try re-running in a few minutes.") elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) if not master_nodes[0].public_dns_name and not opts.private_ips: - print "Master has no public DNS name. Maybe you meant to specify " \ - "--private-ips?" + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") else: master = get_dns_name(master_nodes[0], opts.private_ips) - print "Logging into master " + master + "..." + print("Logging into master " + master + "...") proxy_opt = [] if opts.proxy_port is not None: proxy_opt = ['-D', opts.proxy_port] @@ -1336,19 +1337,18 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." + print("Rebooting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id + print("Rebooting " + inst.id) inst.reboot() elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) if not master_nodes[0].public_dns_name and not opts.private_ips: - print "Master has no public DNS name. Maybe you meant to specify " \ - "--private-ips?" + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") else: - print get_dns_name(master_nodes[0], opts.private_ips) + print(get_dns_name(master_nodes[0], opts.private_ips)) elif action == "stop": response = raw_input( @@ -1361,11 +1361,11 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." + print("Stopping master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.stop() - print "Stopping slaves..." + print("Stopping slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: if inst.spot_instance_request_id: @@ -1375,11 +1375,11 @@ def real_main(): elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." + print("Starting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - print "Starting master..." + print("Starting master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -1403,15 +1403,15 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: - print >> stderr, "Invalid action: %s" % action + print("Invalid action: %s" % action, file=stderr) sys.exit(1) def main(): try: real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e + except UsageError as e: + print("\nError:\n", e, file=stderr) sys.exit(1) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 1c8a20bf8f1ae..11a8cf09533ce 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -41,7 +41,7 @@ object DirectKafkaWordCount { | is a list of one or more Kafka brokers | is a list of one or more kafka topics to consume from | - """".stripMargin) + """.stripMargin) System.exit(1) } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 19d0eb216848e..eaf00d09f550d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -116,7 +116,7 @@ class MyJavaLogisticRegression */ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); - int getMaxIter() { return (Integer) get(maxIter); } + int getMaxIter() { return (Integer) getOrDefault(maxIter); } public MyJavaLogisticRegression() { setMaxIter(100); @@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) { public MyJavaLogisticRegressionModel copy() { MyJavaLogisticRegressionModel m = new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); - Params$.MODULE$.inheritValues(this.paramMap(), this, m); + Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); return m; } } diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 70b6146e39a87..1c3a787bd0e94 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -21,7 +21,8 @@ This example requires numpy (http://www.numpy.org/) """ -from os.path import realpath +from __future__ import print_function + import sys import numpy as np @@ -57,9 +58,9 @@ def update(i, vec, mat, ratings): Usage: als [M] [U] [F] [iterations] [partitions]" """ - print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an + print("""WARN: This is a naive implementation of ALS and is given as an example. Please use the ALS method found in pyspark.mllib.recommendation for more - conventional use.""" + conventional use.""", file=sys.stderr) sc = SparkContext(appName="PythonALS") M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 @@ -68,8 +69,8 @@ def update(i, vec, mat, ratings): ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5 partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2 - print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \ - (M, U, F, ITERATIONS, partitions) + print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % + (M, U, F, ITERATIONS, partitions)) R = matrix(rand(M, F)) * matrix(rand(U, F).T) ms = matrix(rand(M, F)) @@ -95,7 +96,7 @@ def update(i, vec, mat, ratings): usb = sc.broadcast(us) error = rmse(R, ms, us) - print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error + print("Iteration %d:" % i) + print("\nRMSE: %5.4f\n" % error) sc.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4626bbb7e3b02..da368ac628a49 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,9 +15,12 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext +from functools import reduce """ Read data file users.avro in local Spark distro: @@ -49,7 +52,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: avro_inputformat [reader_schema_file] Run with example jar: @@ -57,7 +60,7 @@ /path/to/examples/avro_inputformat.py [reader_schema_file] Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -77,6 +80,6 @@ conf=conf) output = avro_rdd.map(lambda x: x[0]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 05f34b74df45a..93ca0cfcc9302 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, """ + print(""" Usage: cassandra_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -77,6 +79,6 @@ conf=conf) output = cass_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index d144539e58b8f..5d643eac92f94 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -46,7 +48,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: cassandra_outputformat Run with example jar: @@ -60,7 +62,7 @@ ... fname text, ... lname text ... ); - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3b16010f1cb97..e17819d5feb76 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: hbase_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/hbase_inputformat.py
    Assumes you have some data in HBase already, running on , in
    - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -74,6 +76,6 @@ conf=conf) output = hbase_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index abb425b1f886a..9e5641789a976 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -40,7 +42,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: hbase_outputformat
    Run with example jar: @@ -48,7 +50,7 @@ /path/to/examples/hbase_outputformat.py Assumes you have created
    with column family in HBase running on already - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 86ef6f32c84e8..19391506463f0 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -22,6 +22,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -47,12 +48,12 @@ def closestPoint(p, centers): if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given + print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on - how to use MLlib's KMeans implementation.""" + how to use MLlib's KMeans implementation.""", file=sys.stderr) sc = SparkContext(appName="PythonKMeans") lines = sc.textFile(sys.argv[1]) @@ -69,13 +70,13 @@ def closestPoint(p, centers): pointStats = closest.reduceByKey( lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() + lambda xy: (xy[0], xy[1][0] / xy[1][1])).collect() tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y - print "Final centers: " + str(kPoints) + print("Final centers: " + str(kPoints)) sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 3aa56b0528168..b318b7d87bfdc 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -22,10 +22,8 @@ In practice, one may prefer to use the LogisticRegression algorithm in MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py. """ +from __future__ import print_function -from collections import namedtuple -from math import exp -from os.path import realpath import sys import numpy as np @@ -42,19 +40,19 @@ def readPointBatch(iterator): strs = list(iterator) matrix = np.zeros((len(strs), D + 1)) - for i in xrange(len(strs)): - matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ') + for i, s in enumerate(strs): + matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ') return [matrix] if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is + print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py - to see how MLlib's implementation is used.""" + to see how MLlib's implementation is used.""", file=sys.stderr) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() @@ -62,7 +60,7 @@ def readPointBatch(iterator): # Initialize w to a random value w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) + print("Initial w: " + str(w)) # Compute logistic regression gradient for a matrix of data points def gradient(matrix, w): @@ -76,9 +74,9 @@ def add(x, y): return x for i in range(iterations): - print "On iteration %i" % (i + 1) + print("On iteration %i" % (i + 1)) w -= points.map(lambda m: gradient(m, w)).reduce(add) - print "Final w: " + str(w) + print("Final w: " + str(w)) sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index c73edb7fd6b20..fab21f003b233 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression @@ -37,10 +39,10 @@ # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") - training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0)]) \ .map(lambda x: LabeledDocument(*x)).toDF() # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. @@ -54,16 +56,16 @@ # Prepare test documents, which are unlabeled. Document = Row("id", "text") - test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ + test = sc.parallelize([(4, "spark i j k"), + (5, "l m n"), + (6, "mapreduce spark"), + (7, "apache hadoop")]) \ .map(lambda x: Document(*x)).toDF() # Make predictions on test documents and print columns of interest. prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 4218eca822a99..0e13546b88e67 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -18,6 +18,7 @@ """ Correlations using MLlib. """ +from __future__ import print_function import sys @@ -29,7 +30,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: correlations ()" + print("Usage: correlations ()", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: @@ -41,20 +42,20 @@ points = MLUtils.loadLibSVMFile(sc, filepath)\ .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) - print - print 'Summary of data file: ' + filepath - print '%d data points' % points.count() + print() + print('Summary of data file: ' + filepath) + print('%d data points' % points.count()) # Statistics (correlations) - print - print 'Correlation (%s) between label and each feature' % corrType - print 'Feature\tCorrelation' + print() + print('Correlation (%s) between label and each feature' % corrType) + print('Feature\tCorrelation') numFeatures = points.take(1)[0].features.size labelRDD = points.map(lambda lp: lp.label) for i in range(numFeatures): featureRDD = points.map(lambda lp: lp.features[i]) corr = Statistics.corr(labelRDD, featureRDD, corrType) - print '%d\t%g' % (i, corr) - print + print('%d\t%g' % (i, corr)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index fcbf56cbf0c52..e23ecc0c5d302 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -19,6 +19,7 @@ An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ +from __future__ import print_function import os import sys @@ -32,16 +33,16 @@ def summarize(dataset): - print "schema: %s" % dataset.schema().json() + print("schema: %s" % dataset.schema().json()) labels = dataset.map(lambda r: r.label) - print "label average: %f" % labels.mean() + print("label average: %f" % labels.mean()) features = dataset.map(lambda r: r.features) summary = Statistics.colStats(features) - print "features average: %r" % summary.mean() + print("features average: %r" % summary.mean()) if __name__ == "__main__": if len(sys.argv) > 2: - print >> sys.stderr, "Usage: dataset_example.py " + print("Usage: dataset_example.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="DatasetExample") sqlContext = SQLContext(sc) @@ -54,9 +55,9 @@ def summarize(dataset): summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print "Save dataset as a Parquet file to %s." % tempdir + print("Save dataset as a Parquet file to %s." % tempdir) dataset0.saveAsParquetFile(tempdir) - print "Load it back and summarize it again." + print("Load it back and summarize it again.") dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index fccabd841b139..513ed8fd51450 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import numpy import os @@ -83,18 +84,17 @@ def reindexClassLabels(data): numClasses = len(classCounts) # origToNewLabels: class --> index in 0,...,numClasses-1 if (numClasses < 2): - print >> sys.stderr, \ - "Dataset for classification should have at least 2 classes." + \ - " The given dataset had only %d classes." % numClasses + print("Dataset for classification should have at least 2 classes." + " The given dataset had only %d classes." % numClasses, file=sys.stderr) exit(1) origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - print "numClasses = %d" % numClasses - print "Per-class example fractions, counts:" - print "Class\tFrac\tCount" + print("numClasses = %d" % numClasses) + print("Per-class example fractions, counts:") + print("Class\tFrac\tCount") for c in sortedClasses: frac = classCounts[c] / (numExamples + 0.0) - print "%g\t%g\t%d" % (c, frac, classCounts[c]) + print("%g\t%g\t%d" % (c, frac, classCounts[c])) if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): return (data, origToNewLabels) @@ -105,8 +105,7 @@ def reindexClassLabels(data): def usage(): - print >> sys.stderr, \ - "Usage: decision_tree_runner [libsvm format data filepath]" + print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) exit(1) @@ -133,13 +132,13 @@ def usage(): model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. - print "Trained DecisionTree for classification:" - print " Model numNodes: %d" % model.numNodes() - print " Model depth: %d" % model.depth() - print " Training accuracy: %g" % getAccuracy(model, reindexedData) + print("Trained DecisionTree for classification:") + print(" Model numNodes: %d" % model.numNodes()) + print(" Model depth: %d" % model.depth()) + print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) if model.numNodes() < 20: - print model.toDebugString() + print(model.toDebugString()) else: - print model + print(model) sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index a2cd626c9f19d..2cb8010cdc07f 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -18,7 +18,8 @@ """ A Gaussian Mixture Model clustering program using MLlib. """ -import sys +from __future__ import print_function + import random import argparse import numpy as np @@ -59,7 +60,7 @@ def parseVector(line): model = GaussianMixture.train(data, args.k, args.convergenceTol, args.maxIterations, args.seed) for i in range(args.k): - print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, - "sigma = ", model.gaussians[i].sigma.toArray()) - print ("Cluster labels (first 100): ", model.predict(data).take(100)) + print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray())) + print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py index e647773ad9060..781bd61c9d2b5 100644 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -18,6 +18,7 @@ """ Gradient boosted Trees classification and regression using MLlib. """ +from __future__ import print_function import sys @@ -34,7 +35,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification ensemble model:') @@ -49,7 +50,7 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression ensemble model:') @@ -58,7 +59,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: gradient_boosted_trees" + print("Usage: gradient_boosted_trees", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonGradientBoostedTrees") diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 2eeb1abeeb12b..f901a87fa63ac 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -34,12 +35,12 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) k = int(sys.argv[2]) model = KMeans.train(data, k) - print "Final centers: " + str(model.clusterCenters) + print("Final centers: " + str(model.clusterCenters)) sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 8cae27fc4a52d..d4f1d34e2d8cf 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -20,11 +20,10 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function -from math import exp import sys -import numpy as np from pyspark import SparkContext from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import LogisticRegressionWithSGD @@ -42,12 +41,12 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) model = LogisticRegressionWithSGD.train(points, iterations) - print "Final weights: " + str(model.weights) - print "Final intercept: " + str(model.intercept) + print("Final weights: " + str(model.weights)) + print("Final intercept: " + str(model.intercept)) sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py index d3c24f7664329..4cfdad868c66e 100755 --- a/examples/src/main/python/mllib/random_forest_example.py +++ b/examples/src/main/python/mllib/random_forest_example.py @@ -22,6 +22,7 @@ For information on multiclass classification, please refer to the decision_tree_runner.py example. """ +from __future__ import print_function import sys @@ -43,7 +44,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification forest model:') @@ -62,8 +63,8 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\ - / float(testData.count()) + testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ + .sum() / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) @@ -71,7 +72,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: random_forest_example" + print("Usage: random_forest_example", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonRandomForestExample") diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 1e8892741e714..729bae30b152c 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -18,6 +18,7 @@ """ Randomly generated RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: random_rdd_generation" + print("Usage: random_rdd_generation", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") @@ -37,19 +38,19 @@ # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) - print 'Generated RDD of %d examples sampled from the standard normal distribution'\ - % normalRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples sampled from the standard normal distribution' + % normalRDD.count()) + print(' First 5 samples:') for sample in normalRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() # Example: RandomRDDs.normalVectorRDD normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) - print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()) + print(' First 5 samples:') for sample in normalVectorRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index 92af3af5ebd1e..b7033ab7daeb3 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -18,6 +18,7 @@ """ Randomly sampled RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: sampled_rdds " + print("Usage: sampled_rdds ", file=sys.stderr) exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] @@ -41,24 +42,24 @@ examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() if numExamples == 0: - print >> sys.stderr, "Error: Data file had no samples to load." + print("Error: Data file had no samples to load.", file=sys.stderr) exit(1) - print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() expectedSampleSize = int(numExamples * fraction) - print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ - % (fraction, expectedSampleSize) + print('Sampling RDD using fraction %g. Expected sample size = %d.' + % (fraction, expectedSampleSize)) sampledRDD = examples.sample(withReplacement=True, fraction=fraction) - print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + print(' RDD.sample(): sample has %d examples' % sampledRDD.count()) sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) - print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + print(' RDD.takeSample(): sample has %d examples' % len(sampledArray)) - print + print() # Example: RDD.sampleByKey() keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) - print ' Keyed data using label (Int) as key ==> Orig' + print(' Keyed data using label (Int) as key ==> Orig') # Count examples per label in original data. keyCountsA = keyedRDD.countByKey() @@ -69,18 +70,18 @@ sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) - print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ - % sizeB + print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' + % sizeB) # Compare samples - print ' \tFractions of examples with key' - print 'Key\tOrig\tSample' + print(' \tFractions of examples with key') + print('Key\tOrig\tSample') for k in sorted(keyCountsA.keys()): fracA = keyCountsA[k] / float(numExamples) if sizeB != 0: fracB = keyCountsB.get(k, 0) / float(sizeB) else: fracB = 0 - print '%d\t%g\t%g' % (k, fracA, fracB) + print('%d\t%g\t%g' % (k, fracA, fracB)) sc.stop() diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 99fef4276a369..40d1b887927e0 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -23,6 +23,7 @@ # grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines # This was done so that the example can be run in local mode +from __future__ import print_function import sys @@ -34,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) < 2: - print USAGE + print(USAGE) sys.exit("Argument for file not provided") file_path = sys.argv[1] sc = SparkContext(appName='Word2Vec') @@ -46,5 +47,5 @@ synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index a5f25d78c1146..2fdc9773d4eb1 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -19,6 +19,7 @@ This is an example implementation of PageRank. For more conventional use, Please refer to PageRank implementation provided by graphx """ +from __future__ import print_function import re import sys @@ -42,11 +43,12 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: pagerank " + print("Usage: pagerank ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is - given as an example! Please refer to PageRank implementation provided by graphx""" + print("""WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""", + file=sys.stderr) # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") @@ -62,19 +64,19 @@ def parseNeighbors(urls): links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() # Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - ranks = links.map(lambda (url, neighbors): (url, 1.0)) + ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0)) # Calculates and updates URL ranks continuously using PageRank algorithm. - for iteration in xrange(int(sys.argv[2])): + for iteration in range(int(sys.argv[2])): # Calculates URL contributions to the rank of other URLs. contribs = links.join(ranks).flatMap( - lambda (url, (urls, rank)): computeContribs(urls, rank)) + lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1])) # Re-calculates URL ranks based on neighbor contributions. ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15) # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): - print "%s has rank: %s." % (link, rank) + print("%s has rank: %s." % (link, rank)) sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index fa4c20ab20281..96ddac761d698 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,14 +36,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, """ + print(""" Usage: parquet_inputformat.py Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \\ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -56,6 +57,6 @@ valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') output = parquet_rdd.map(lambda x: x[1]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index a7c74e969cdb9..92e5cf45abc8b 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,7 +36,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) + count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) + print("Pi is roughly %f" % (4.0 * count / n)) sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index bb686f17518a0..f6b0ecb02c100 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -22,7 +24,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: sort " + print("Usage: sort ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonSort") lines = sc.textFile(sys.argv[1], 1) @@ -33,6 +35,6 @@ # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() for (num, unitcount) in output: - print num + print(num) sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d89361f324917..87d7b088f077b 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import os from pyspark import SparkContext @@ -68,6 +70,6 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") for each in teenagers.collect(): - print each[0] + print(each[0]) sc.stop() diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py index a33bdc475a06d..49b7902185aaa 100644 --- a/examples/src/main/python/status_api_demo.py +++ b/examples/src/main/python/status_api_demo.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import time import threading import Queue @@ -52,15 +54,15 @@ def run(): ids = status.getJobIdsForGroup() for id in ids: job = status.getJobInfo(id) - print "Job", id, "status: ", job.status + print("Job", id, "status: ", job.status) for sid in job.stageIds: info = status.getStageInfo(sid) if info: - print "Stage %d: %d tasks total (%d active, %d complete)" % \ - (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks) + print("Stage %d: %d tasks total (%d active, %d complete)" % + (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)) time.sleep(1) - print "Job results are:", result.get() + print("Job results are:", result.get()) sc.stop() if __name__ == "__main__": diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f7ffb5379681e..f815dd26823d1 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -25,6 +25,7 @@ Then create a text file in `localdir` and the words in the file will get counted. """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: hdfs_wordcount.py " + print("Usage: hdfs_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 51e1ff822fc55..b178e7899b5e1 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -27,6 +27,7 @@ spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ +from __future__ import print_function import sys @@ -36,7 +37,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kafka_wordcount.py " + print("Usage: kafka_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index cfa9c1ff5bfbc..2b48bcfd55db0 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -25,6 +25,7 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` """ +from __future__ import print_function import sys @@ -33,7 +34,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: network_wordcount.py " + print("Usage: network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index fc6827c82bf9b..ac91f0a06b172 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -35,6 +35,7 @@ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from the checkpoint data. """ +from __future__ import print_function import os import sys @@ -46,7 +47,7 @@ def createContext(host, port, outputPath): # If you do not see this printed, that means the StreamingContext has been loaded # from the new checkpoint - print "Creating new context" + print("Creating new context") if os.path.exists(outputPath): os.remove(outputPath) sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") @@ -60,8 +61,8 @@ def createContext(host, port, outputPath): def echo(time, rdd): counts = "Counts at time %s %s" % (time, rdd.collect()) - print counts - print "Appending to " + os.path.abspath(outputPath) + print(counts) + print("Appending to " + os.path.abspath(outputPath)) with open(outputPath, 'a') as f: f.write(counts + "\n") @@ -70,8 +71,8 @@ def echo(time, rdd): if __name__ == "__main__": if len(sys.argv) != 5: - print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ - " " + print("Usage: recoverable_network_wordcount.py " + " ", file=sys.stderr) exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index f89bc562d856b..da90c07dbd82f 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -27,6 +27,7 @@ and then run the example `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999` """ +from __future__ import print_function import os import sys @@ -44,7 +45,7 @@ def getSqlContextInstance(sparkContext): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: sql_network_wordcount.py " + print("Usage: sql_network_wordcount.py ", file=sys.stderr) exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") @@ -57,7 +58,7 @@ def getSqlContextInstance(sparkContext): # Convert RDDs of the words DStream to DataFrame and run SQL query def process(time, rdd): - print "========= %s =========" % str(time) + print("========= %s =========" % str(time)) try: # Get the singleton instance of SQLContext diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index 18a9a5a452ffb..16ef646b7c42e 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -29,6 +29,7 @@ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ localhost 9999` """ +from __future__ import print_function import sys @@ -37,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: stateful_network_wordcount.py " + print("Usage: stateful_network_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 00a281bfb6506..7bf5fb6ddfe29 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from random import Random @@ -49,20 +51,20 @@ def generateGraph(): # the graph to obtain the path (x, z). # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.map(lambda (x, y): (y, x)) + edges = tc.map(lambda x_y: (x_y[1], x_y[0])) - oldCount = 0L + oldCount = 0 nextCount = tc.count() while True: oldCount = nextCount # Perform the join, obtaining an RDD of (y, (z, x)) pairs, # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) + new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0])) tc = tc.union(new_edges).distinct().cache() nextCount = tc.count() if nextCount == oldCount: break - print "TC has %i edges" % tc.count() + print("TC has %i edges" % tc.count()) sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index ae6cd13b83d92..7c0143607b61d 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from operator import add @@ -23,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: wordcount " + print("Usage: wordcount ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonWordCount") lines = sc.textFile(sys.argv[1], 1) @@ -32,6 +34,6 @@ .reduceByKey(add) output = counts.collect() for (word, count) in output: - print "%s: %i" % (word, count) + print("%s: %i" % (word, count)) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index df26798e41b7b..2245fa429fda3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * class since the maxIter parameter is only used during training (not in the Model). */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) } /** @@ -174,11 +174,11 @@ private class MyLogisticRegressionModel( * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * - * This is used for the defaul implementation of [[transform()]]. + * This is used for the default implementation of [[transform()]]. */ override protected def copy(): MyLogisticRegressionModel = { val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 9f22d40c15f3f..6d8b806569dfd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -65,7 +65,7 @@ object PowerIterationClusteringExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("PIC Circles") { + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") .text(s"number of circles (/clusters), default: ${defaultParams.k}") @@ -76,9 +76,9 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Int]('r', "r") + opt[Double]('r', "r") .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(numPoints = x)) + .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -154,3 +154,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } + diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a0b8a0c565210..a1b4a12e5d6a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -23,10 +23,9 @@ import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskC import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator -import java.util.Properties import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} -import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.consumer.SimpleConsumer import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index d8279145d8e90..b8f02b961113d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -186,12 +186,24 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, String.format("%s/core/target/jars/*", sparkHome)); } - final String assembly = AbstractCommandBuilder.class.getProtectionDomain().getCodeSource(). - getLocation().getPath(); + // We can't rely on the ENV_SPARK_ASSEMBLY variable to be set. Certain situations, such as + // when running unit tests, or user code that embeds Spark and creates a SparkContext + // with a local or local-cluster master, will cause this code to be called from an + // environment where that env variable is not guaranteed to exist. + // + // For the testing case, we rely on the test code to set and propagate the test classpath + // appropriately. + // + // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. + // That duplicates some of the code in the shell scripts that look for the assembly, though. + String assembly = getenv(ENV_SPARK_ASSEMBLY); + if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + assembly = findAssembly(); + } addToClassPath(cp, assembly); - // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only - // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate + // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only + // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive File libdir; if (new File(sparkHome, "RELEASE").isFile()) { @@ -299,6 +311,30 @@ String getenv(String key) { return firstNonEmpty(childEnv.get(key), System.getenv(key)); } + private String findAssembly() { + String sparkHome = getSparkHome(); + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); + } + + final Pattern re = Pattern.compile("spark-assembly.*hadoop.*\\.jar"); + FileFilter filter = new FileFilter() { + @Override + public boolean accept(File file) { + return file.isFile() && re.matcher(file.getName()).matches(); + } + }; + File[] assemblies = libdir.listFiles(filter); + checkState(assemblies != null && assemblies.length > 0, "No assemblies found in '%s'.", libdir); + checkState(assemblies.length == 1, "Multiple assemblies found in '%s'.", libdir); + return assemblies[0].getAbsolutePath(); + } + private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index f4ebc25bdd32b..8028e42ffb483 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -30,6 +30,7 @@ class CommandBuilderUtils { static final String DEFAULT_MEM = "512m"; static final String DEFAULT_PROPERTIES_FILE = "spark-defaults.conf"; static final String ENV_SPARK_HOME = "SPARK_HOME"; + static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index b566507ee6061..d4cfeacb6ef18 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -52,7 +52,7 @@ public class SparkLauncher { /** Configuration key for the executor VM options. */ public static final String EXECUTOR_EXTRA_JAVA_OPTIONS = "spark.executor.extraJavaOptions"; /** Configuration key for the executor native library path. */ - public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryOptions"; + public static final String EXECUTOR_EXTRA_LIBRARY_PATH = "spark.executor.extraLibraryPath"; /** Configuration key for the number of executor CPU cores. */ public static final String EXECUTOR_CORES = "spark.executor.cores"; diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 626116a9e7477..97043a76cc612 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -98,7 +98,7 @@ public void testShellCliParser() throws Exception { parser.NAME, "appName"); - List args = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List args = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); List expected = Arrays.asList("spark-shell", "--app-arg", "bar", "--app-switch"); assertEquals(expected, args.subList(args.size() - expected.size(), args.size())); } @@ -110,7 +110,7 @@ public void testAlternateSyntaxParsing() throws Exception { parser.MASTER + "=foo", parser.DEPLOY_MODE + "=bar"); - List cmd = new SparkSubmitCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); + List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); assertEquals("foo", findArgValue(cmd, parser.MASTER)); assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); @@ -153,7 +153,7 @@ private void testCmdBuilder(boolean isDriver) throws Exception { String deployMode = isDriver ? "client" : "cluster"; SparkSubmitCommandBuilder launcher = - new SparkSubmitCommandBuilder(Collections.emptyList()); + newCommandBuilder(Collections.emptyList()); launcher.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); launcher.master = "yarn"; @@ -273,10 +273,15 @@ private boolean findInStringList(String list, String sep, String needle) { return contains(needle, list.split(sep)); } - private List buildCommand(List args, Map env) throws Exception { + private SparkSubmitCommandBuilder newCommandBuilder(List args) { SparkSubmitCommandBuilder builder = new SparkSubmitCommandBuilder(args); builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, System.getProperty("spark.test.home")); - return builder.buildCommand(env); + builder.childEnv.put(CommandBuilderUtils.ENV_SPARK_ASSEMBLY, "dummy"); + return builder; + } + + private List buildCommand(List args, Map env) throws Exception { + return newCommandBuilder(args).buildCommand(env); } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index eff7ef925dfbd..d6b3503ebdd9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { */ @varargs def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + val map = ParamMap(paramPairs: _*) fit(dataset, map) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index a50090671ae48..a1d49095c24ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,7 +25,7 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. */ private[ml] val uid: String = diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a455341a1f723..8eddf79cdfe28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -138,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -177,14 +177,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9a5848684b179..7fb87fe452ee6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,6 +22,7 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index c5fc89f935432..29339c98f51cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -67,8 +69,7 @@ private[spark] abstract class Classifier[ with ClassifierParams { /** @group setParam */ - def setRawPredictionCol(value: String): E = - set(rawPredictionCol, value).asInstanceOf[E] + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) } @@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 34625745dd0a8..cc8b0721cf2b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -31,8 +31,10 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold { + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) +} /** * :: AlphaComponent :: @@ -45,10 +47,6 @@ class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) - /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -100,8 +98,6 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - setThreshold(0.5) - /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -123,7 +119,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -184,7 +180,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > getThreshold) 1 else 0 } override protected def predictProbabilities(features: Vector): Vector = { @@ -199,7 +195,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(this.extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index bd8caac855981..10404548ccfde 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * Params for probabilistic classification. */ @@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap - addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2360f4479f1c2..c865eb9fe092d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType - /** * :: AlphaComponent :: * @@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params * @group param */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = get(metricName) + def getMetricName: String = getOrDefault(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema - checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index fc4e12773c46d..b20f2fc49a8f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { * number of features * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ - def getNumFeatures: Int = get(numFeatures) + def getNumFeatures: Int = getOrDefault(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + setDefault(numFeatures -> (1 << 18)) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 05f91dc9105fe..decaeb0da6246 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -35,14 +35,16 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { * Normalization in L^p^ space, p = 2 by default. * @group param */ - val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + val p = new DoubleParam(this, "p", "the p norm value") /** @group getParam */ - def getP: Double = get(p) + def getP: Double = getOrDefault(p) /** @group setParam */ def setP(value: Double): this.type = set(p, value) + setDefault(p -> 2.0) + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { val normalizer = new feature.Normalizer(paramMap(p)) normalizer.transform @@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { override protected def outputDataType: DataType = new VectorUDT() } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1142aa4f8e73d..1b102619b3524 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -47,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -56,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -86,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 61e6742e880d8..4d960df357fe9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -22,6 +22,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, StructType} @@ -34,8 +36,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - checkInputColumn(schema, map(inputCol), StringType) + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -64,7 +66,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase // TODO: handle unseen labels override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) @@ -105,7 +107,7 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 68401e36950bd..376a004858b4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -56,39 +56,39 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param for minimum token length, default is one to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = get(minTokenLength) + def getMinTokenLength: Int = getOrDefault(minTokenLength) /** * param sets regex as splitting on gaps (true) or matching tokens (false) * @group param */ - val gaps: BooleanParam = new BooleanParam( - this, "gaps", "Set regex to match gaps or tokens", Some(false)) + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = get(gaps) + def getGaps: Boolean = getOrDefault(gaps) /** * param sets regex pattern used by tokenizer * @group param */ - val pattern: Param[String] = new Param( - this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = get(pattern) + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => val re = paramMap(pattern).r diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index d1b8f7e6e9295..e567e069e7c0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -22,7 +22,8 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -44,7 +45,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } @@ -61,7 +62,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputColNames = map(inputCols) val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 8760960e19272..452faa06e2021 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, Attribute, AttributeGroup} -import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions.callUDF @@ -40,11 +42,12 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ val maxCategories = new IntParam(this, "maxCategories", "Threshold for the number of values a categorical feature can take." + - " If a feature is found to have > maxCategories values, then it is declared continuous.", - Some(20)) + " If a feature is found to have > maxCategories values, then it is declared continuous.") /** @group getParam */ - def getMaxCategories: Int = get(maxCategories) + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) } /** @@ -101,7 +104,7 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val firstRow = dataset.select(map(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size @@ -120,12 +123,12 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) - addOutputColumn(schema, map(outputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) } } @@ -320,7 +323,7 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val newField = prepOutputField(dataset.schema, map) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) // For now, just check the first row of inputCol for vector length. @@ -334,13 +337,13 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexerModel requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index dfb89cc8d4af3..195333a5cc47f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -53,14 +55,14 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - checkInputColumn(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) } - addOutputColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) } } @@ -98,7 +100,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -141,7 +143,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -201,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 7d5178d0abb2d..849c60433c777 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable -import org.apache.spark.sql.types.{DataType, StructField, StructType} - /** * :: AlphaComponent :: @@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -55,58 +49,55 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. + */ override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { - s"$name: $doc" + "(undefined)" } + s"$name: $doc $valueStr" } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) - extends Param[Double](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) - extends Param[Int](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) - extends Param[Float](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) - extends Param[Long](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) - extends Param[Boolean](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -124,8 +115,11 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ - def params: Array[Param[_]] = { + /** + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. + */ + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -153,25 +147,29 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) + final def isSet(param: Param[_]): Boolean = { + shouldOwn(param) paramMap.contains(param) } + /** Checks whether a param is explicitly set or has a default value. */ + final def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) + } + /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + def getParam(paramName: String): Param[Any] = { + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ - protected def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) + protected final def set[T](param: Param[T], value: T): this.type = { + shouldOwn(param) paramMap.put(param.asInstanceOf[Param[Any]], value) this } @@ -179,52 +177,102 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** - * Gets the value of a parameter in the embedded param map. + * Optionally returns the user-supplied value of a param. + */ + final def get[T](param: Param[T]): Option[T] = { + shouldOwn(param) + paramMap.get(param) + } + + /** + * Clears the user-supplied value for the input param. + */ + protected final def clear(param: Param[_]): this.type = { + shouldOwn(param) + paramMap.remove(param) + this + } + + /** + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. + */ + final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get + } + + /** + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value */ - protected def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - paramMap(param) + protected final def setDefault[T](param: Param[T], value: T): this.type = { + shouldOwn(param) + defaultParamMap.put(param, value) + this } /** - * Internal param map. + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. */ - protected val paramMap: ParamMap = ParamMap.empty + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } /** - * Check whether the given schema contains an input column. - * @param colName Input column name - * @param dataType Input column DataType + * Gets the default value of a parameter. */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + final def getDefault[T](param: Param[T]): Option[T] = { + shouldOwn(param) + defaultParamMap.get(param) } /** - * Add an output column to the given schema. - * This fails if the given output column already exists. - * @param schema Initial schema (not modified) - * @param colName Output column name. If this column name is an empy String "", this method - * returns the initial schema, unchanged. This allows users to disable output - * columns. - * @param dataType Output column DataType - */ - protected def addOutputColumn( - schema: StructType, - colName: String, - dataType: DataType): StructType = { - if (colName.length == 0) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Output column $colName already exists.") - val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) - StructType(outputFields) + * Tests whether the input param has a default value set. + */ + final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) + } + + /** + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + */ + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap + } + + /** + * [[extractParamMap]] with no extra values. + */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. */ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") } } @@ -261,12 +309,13 @@ private[spark] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -288,12 +337,17 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) } /** @@ -301,10 +355,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -316,6 +367,13 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns its value associated previously as an option. + */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ @@ -325,7 +383,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) @@ -337,7 +395,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. @@ -363,7 +421,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Number of param pairs in this set. + * Number of param pairs in this map. */ def size: Int = map.size } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala new file mode 100644 index 0000000000000..95d7e64790c79 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" + * }}} + */ +private[shared] object SharedParamsCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("inputCols", "input column names"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval"), + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. */ + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + |""".stripMargin + }.getOrElse("") + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name$defaultValueDoc. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + |$setDefault + | /** @group getParam */ + | final def get$Name: $T = getOrDefault($name) + |} + |""".stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + | + |// scalastyle:off + |""".stripMargin + + val footer = "// scalastyle:on\n" + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 0000000000000..72b08bf276483 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = getOrDefault(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = getOrDefault(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol (default: "features"). + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + setDefault(featuresCol, "features") + + /** @group getParam */ + final def getFeaturesCol: String = getOrDefault(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol (default: "label"). + */ +@DeveloperApi +trait HasLabelCol extends Params { + + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + setDefault(labelCol, "label") + + /** @group getParam */ + final def getLabelCol: String = getOrDefault(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol (default: "prediction"). + */ +@DeveloperApi +trait HasPredictionCol extends Params { + + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + setDefault(predictionCol, "prediction") + + /** @group getParam */ + final def getPredictionCol: String = getOrDefault(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol (default: "rawPrediction"). + */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + setDefault(rawPredictionCol, "rawPrediction") + + /** @group getParam */ + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol (default: "probability"). + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + setDefault(probabilityCol, "probability") + + /** @group getParam */ + final def getProbabilityCol: String = getOrDefault(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + + /** + * Param for threshold in binary classification prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") + + /** @group getParam */ + final def getThreshold: Double = getOrDefault(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = getOrDefault(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCols. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = getOrDefault(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) +} + +/** + * :: DeveloperApi :: + * Trait for shared param fitIntercept (default: true). + */ +@DeveloperApi +trait HasFitIntercept extends Params { + + /** + * Param for whether to fit an intercept term. + * @group param + */ + final val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "whether to fit an intercept term") + + setDefault(fitIntercept, true) + + /** @group getParam */ + final def getFitIntercept: Boolean = getOrDefault(fitIntercept) +} +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index 07e6eb417763d..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.param - -/* NOTE TO DEVELOPERS: - * If you mix these parameter traits into your algorithm, please add a setter method as well - * so that users may use a builder pattern: - * val myLearner = new MyLearner().setParam1(x).setParam2(y)... - */ - -private[ml] trait HasRegParam extends Params { - /** - * param for regularization parameter - * @group param - */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - - /** @group getParam */ - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** - * param for max number of iterations - * @group param - */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - - /** @group getParam */ - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** - * param for features column name - * @group param - */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - - /** @group getParam */ - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** - * param for label column name - * @group param - */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - - /** @group getParam */ - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** - * param for prediction column name - * @group param - */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - - /** @group getParam */ - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasRawPredictionCol extends Params { - /** - * param for raw prediction column name - * @group param - */ - val rawPredictionCol: Param[String] = - new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", - Some("rawPrediction")) - - /** @group getParam */ - def getRawPredictionCol: String = get(rawPredictionCol) -} - -private[ml] trait HasProbabilityCol extends Params { - /** - * param for predicted class conditional probabilities column name - * @group param - */ - val probabilityCol: Param[String] = - new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", - Some("probability")) - - /** @group getParam */ - def getProbabilityCol: String = get(probabilityCol) -} - -private[ml] trait HasFitIntercept extends Params { - /** - * param for fitting the intercept term, defaults to true - * @group param - */ - val fitIntercept: BooleanParam = - new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) - - /** @group getParam */ - def getFitIntercept: Boolean = get(fitIntercept) -} - -private[ml] trait HasThreshold extends Params { - /** - * param for threshold in (binary) prediction - * @group param - */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - - /** @group getParam */ - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** - * param for input column name - * @group param - */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - - /** @group getParam */ - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasInputCols extends Params { - /** - * Param for input column names. - */ - val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names") - - /** @group getParam */ - def getInputCols: Array[String] = get(inputCols) -} - -private[ml] trait HasOutputCol extends Params { - /** - * param for output column name - * @group param - */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - - /** @group getParam */ - def getOutputCol: String = get(outputCol) -} - -private[ml] trait HasCheckpointInterval extends Params { - /** - * param for checkpoint interval - * @group param - */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") - - /** @group getParam */ - def getCheckpointInterval: Int = get(checkpointInterval) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 52c9e95d6012f..bd793beba35b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -54,86 +55,88 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for rank of the matrix factorization. * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ - def getRank: Int = get(rank) + def getRank: Int = getOrDefault(rank) /** * Param for number of user blocks. * @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ - def getNumUserBlocks: Int = get(numUserBlocks) + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** * Param for number of item blocks. * @group param */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ - def getNumItemBlocks: Int = get(numItemBlocks) + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. * @group param */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = get(implicitPrefs) + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation. * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ - def getAlpha: Double = get(alpha) + def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = get(userCol) + def getUserCol: String = getOrDefault(userCol) /** * Param for the column name for item ids. * @group param */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) + val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = get(itemCol) + def getItemCol: String = getOrDefault(itemCol) /** * Param for the column name for ratings. * @group param */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = get(ratingCol) + def getRatingCol: String = getOrDefault(ratingCol) /** * Param for whether to apply nonnegativity constraints. * @group param */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - val getNonnegative: Boolean = get(nonnegative) + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. @@ -142,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -171,7 +174,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -283,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 65f6627a0c351..26ca7459c4fdf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql.DataFrame @@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setRegParam(0.1) - setMaxIter(100) + setDefault(regParam -> 0.1, maxIter -> 100) /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2eb1dac56f1e9..4bb4ed813c006 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ private[ml] trait CrossValidatorParams extends Params { + /** * param for the estimator to be cross-validated * @group param @@ -38,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = get(estimator) + def getEstimator: Estimator[_] = getOrDefault(estimator) /** * param for estimator param maps @@ -48,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) /** * param for the evaluator for selection @@ -57,17 +58,18 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = get(evaluator) + def getEvaluator: Evaluator = getOrDefault(evaluator) /** * param for number of folds for cross validation * @group param */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") /** @group getParam */ - def getNumFolds: Int = get(numFolds) + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -92,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -130,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 0000000000000..0383bf0b382b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. This fails if the given output column already exists. + * @param schema input schema + * @param colName new column name. If this column name is an empty string "", this method returns + * the input schema unchanged. This allows users to disable output columns. + * @param dataType new column data type + * @return new schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.isEmpty) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala index ecd3b16598438..534edac56bc5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} import org.apache.spark.rdd.RDD @@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization predict(SerDe.asTupleRDD(userAndProducts.rdd)) def getUserFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + SerDe.fromTuple2RDD(userFeatures.map { + case (user, feature) => (user, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) } def getProductFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + SerDe.fromTuple2RDD(productFeatures.map { + case (product, feature) => (product, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ab15f0f36a14b..f976d2f97b043 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,7 +28,6 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)). map(_.asInstanceOf[Object]).asJava } @@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable { mu += model.gaussians(i).mu sigma += model.gaussians(i).sigma } - List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava } finally { data.rdd.unpersist(blocking = false) } @@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable { */ def predictSoftGMM( data: JavaRDD[Vector], - wt: Object, + wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Array[Double]] = { + si: Array[Object]): RDD[Vector] = { - val weight = wt.asInstanceOf[Array[Double]] + val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) } val model = new GaussianMixtureModel(weight, gaussians) - model.predictSoft(data) + model.predictSoft(data).map(Vectors.dense) } - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable { out.write(code) } + protected def getBytes(obj: Object): Array[Byte] = { + if (obj.getClass.isArray) { + obj.asInstanceOf[Array[Byte]] + } else { + obj.asInstanceOf[String].getBytes(LATIN1) + } + } + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } @@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(0)) val bb = ByteBuffer.wrap(bytes, 0, bytes.length) bb.order(ByteOrder.nativeOrder()) val db = bb.asDoubleBuffer() @@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() @@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable { throw new PickleException("should be 3") } val size = args(0).asInstanceOf[Int] - val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) - val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) val n = indiceBytes.length / 4 val indices = new Array[Int](n) val values = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 328dbe2ce11fa..4ef171f4f0419 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -227,7 +227,7 @@ object Vectors { * @param elements vector elements in (index, value) pairs. */ def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -235,7 +235,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -309,7 +310,8 @@ object Vectors { * @return norm in L^p^ space. */ def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs @@ -360,7 +362,8 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { - require(v1.size == v2.size, "vector dimension mismatch") + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => @@ -518,7 +521,9 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. You provided ${indices.size} indices and " + + s" ${values.size} values.") override def toString: String = "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce2987612378..88ea679eeaad5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,25 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) @@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { @@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clearMaxIter() + assert(!solver.isSet(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index ce52f2f230085..641b64b42a5e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,20 +17,21 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) + + setDefault(maxIter -> 10) override def validate(paramMap: ParamMap): Unit = { - val m = this.paramMap ++ paramMap + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } diff --git a/pom.xml b/pom.xml index d8881c213bf07..bcc2f57f1af5d 100644 --- a/pom.xml +++ b/pom.xml @@ -147,7 +147,7 @@ 1.8.3 1.1.0 4.2.6 - 3.1.1 + 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 2.10 @@ -156,7 +156,7 @@ 3.6.3 1.8.8 2.4.4 - 1.1.1.6 + 1.1.1.7 1.1.2 ${java.home} @@ -1447,7 +1447,7 @@ org.scalastyle scalastyle-maven-plugin - 0.4.0 + 0.7.0 false true @@ -1456,13 +1456,12 @@ ${basedir}/src/main/scala ${basedir}/src/test/scala scalastyle-config.xml - scalastyle-output.xml + ${basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} ${project.reporting.outputEncoding} - package check diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f51f4b58f97a..09b4976d10c26 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -import java.io.File +import java.io._ import scala.util.Properties import scala.collection.JavaConversions._ @@ -166,6 +166,9 @@ object SparkBuild extends PomBuild { /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) + /* Package pyspark artifacts in the main assembly. */ + enable(PySparkAssembly.settings)(assembly) + /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) @@ -316,6 +319,7 @@ object Hive { } object Assembly { + import sbtassembly.AssemblyUtils._ import sbtassembly.Plugin._ import AssemblyKeys._ @@ -347,6 +351,60 @@ object Assembly { ) } +object PySparkAssembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ + + lazy val settings = Seq( + unmanagedJars in Compile += { BuildCommons.sparkHome / "python/lib/py4j-0.8.2.1-src.zip" }, + // Use a resource generator to copy all .py files from python/pyspark into a managed directory + // to be included in the assembly. We can't just add "python/" to the assembly's resource dir + // list since that will copy unneeded / unwanted files. + resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + val dst = new File(outDir, "pyspark") + if (!dst.isDirectory()) { + require(dst.mkdirs()) + } + + val src = new File(BuildCommons.sparkHome, "python/pyspark") + copy(src, dst) + } + ) + + private def copy(src: File, dst: File): Seq[File] = { + src.listFiles().flatMap { f => + val child = new File(dst, f.getName()) + if (f.isDirectory()) { + child.mkdir() + copy(f, child) + } else if (f.getName().endsWith(".py")) { + var in: Option[FileInputStream] = None + var out: Option[FileOutputStream] = None + try { + in = Some(new FileInputStream(f)) + out = Some(new FileOutputStream(child)) + + val bytes = new Array[Byte](1024) + var read = 0 + while (read >= 0) { + read = in.get.read(bytes) + if (read > 0) { + out.get.write(bytes, 0, read) + } + } + + Some(child) + } finally { + in.foreach(_.close()) + out.foreach(_.close()) + } + } else { + None + } + } + } +} + object Unidoc { import BuildCommons._ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..7271809e43880 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -54,7 +54,7 @@ ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): -... for i in xrange(len(val1)): +... for i in range(len(val1)): ... val1[i] += val2[i] ... return val1 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) @@ -86,9 +86,13 @@ Exception:... """ +import sys import select import struct -import SocketServer +if sys.version < '3': + import SocketServer +else: + import socketserver as SocketServer import threading from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer @@ -247,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer): def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + self.server_close() def _start_update_server(): diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 6b8a8b256a891..3de4615428bb6 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -16,10 +16,15 @@ # import os -import cPickle +import sys import gc from tempfile import NamedTemporaryFile +if sys.version < '3': + import cPickle as pickle +else: + import pickle + unicode = str __all__ = ['Broadcast'] @@ -70,33 +75,19 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None): self._path = path def dump(self, value, f): - if isinstance(value, basestring): - if isinstance(value, unicode): - f.write('U') - value = value.encode('utf8') - else: - f.write('S') - f.write(value) - else: - f.write('P') - cPickle.dump(value, f, 2) + pickle.dump(value, f, 2) f.close() return f.name def load(self, path): with open(path, 'rb', 1 << 20) as f: - flag = f.read(1) - data = f.read() - if flag == 'P': - # cPickle.loads() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return cPickle.loads(data) - finally: - gc.enable() - else: - return data.decode('utf8') if flag == 'U' else data + # pickle.load() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return pickle.load(f) + finally: + gc.enable() @property def value(self): diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index bb0783555aa77..9ef93071d2e77 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -40,164 +40,126 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ - +from __future__ import print_function import operator import os +import io import pickle import struct import sys import types from functools import partial import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new import dis import traceback -import platform - -PyImp = platform.python_implementation() - -import logging -cloudLog = logging.getLogger("Cloud.Transport") +if sys.version < '3': + from pickle import Pickler + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + PY3 = False +else: + types.ClassType = type + from pickle import _Pickler as Pickler + from io import BytesIO as StringIO + PY3 = True #relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +STORE_GLOBAL = dis.opname.index('STORE_GLOBAL') +DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL') +LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL') GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +HAVE_ARGUMENT = dis.HAVE_ARGUMENT +EXTENDED_ARG = dis.EXTENDED_ARG -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) - -if PyImp == "PyPy": - # register builtin type in `new` - new.method = types.MethodType - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -# These helper functions were copied from PiCloud's util module. def islambda(func): - return getattr(func,'func_name') == '' + return getattr(func,'__name__') == '' -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ - - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) - -#debug variables intended for developer use: -printSerialization = False -printMemoization = False +_BUILTIN_TYPE_NAMES = {} +for k, v in types.__dict__.items(): + if type(v) is type: + _BUILTIN_TYPE_NAMES[v] = k -useForcedImports = True #Should I use forced imports for tracking? +def _builtin_type(name): + return getattr(types, name) -class CloudPickler(pickle.Pickler): +class CloudPickler(Pickler): - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment + dispatch = Pickler.dispatch.copy() - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + def __init__(self, file, protocol=None): + Pickler.__init__(self, file, protocol) + # set of modules to unpickle + self.modules = set() + # map ids to dictionary. used to ensure that functions can share global env + self.globals_ref = {} def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) self.inject_addons() try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: + return Pickler.dump(self, obj) + except RuntimeError as e: if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" + msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) + + def save_memoryview(self, obj): + """Fallback to save_string""" + Pickler.save_string(self, str(obj)) def save_buffer(self, obj): """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer + Pickler.save_string(self,str(obj)) + if PY3: + dispatch[memoryview] = save_memoryview + else: + dispatch[buffer] = save_buffer - #block broken objects - def save_unsupported(self, obj, pack=None): + def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) dispatch[types.GeneratorType] = save_unsupported - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! + # itertools objects do not pickle! for v in itertools.__dict__.values(): if type(v) is type: dispatch[v] = save_unsupported - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): + def save_module(self, obj): """ Save a module as an import """ - #print 'try save import', obj.__name__ self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type + self.save_reduce(subimport, (obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module - def save_codeobject(self, obj, pack=struct.pack): + def save_codeobject(self, obj): """ Save a code object """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) + if PY3: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, + obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type + dispatch[types.CodeType] = save_codeobject - def save_function(self, obj, name=None, pack=struct.pack): + def save_function(self, obj, name=None): """ Registered with the dispatch to handle all function types. Determines what kind of function obj is (e.g. lambda, defined at @@ -205,12 +167,14 @@ def save_function(self, obj, name=None, pack=struct.pack): """ write = self.write - name = obj.__name__ + if name is None: + name = obj.__name__ modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) + # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + except KeyError: + # eval'd items such as namedtuple give invalid items for their function __module__ modname = '__main__' if modname == '__main__': @@ -221,37 +185,18 @@ def save_function(self, obj, name=None, pack=struct.pack): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: - #Force server to import modules that have been imported in main - modList = None - if themodule is None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) + if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: + #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + self.save_function_tuple(obj) return - else: # func is nested + else: + # func is nested klass = getattr(themodule, name, None) if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) + self.save_function_tuple(obj) return if obj.__dict__: @@ -266,7 +211,7 @@ def save_function(self, obj, name=None, pack=struct.pack): self.memoize(obj) dispatch[types.FunctionType] = save_function - def save_function_tuple(self, func, forced_imports): + def save_function_tuple(self, func): """ Pickles an actual func object. A func comprises: code, globals, defaults, closure, and dict. We @@ -281,19 +226,6 @@ def save_function_tuple(self, func, forced_imports): save = self.save write = self.write - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater @@ -318,6 +250,8 @@ def extract_code_globals(co): Find all globals names read or written to by codeblock co """ code = co.co_code + if not PY3: + code = [ord(c) for c in code] names = co.co_names out_names = set() @@ -327,18 +261,18 @@ def extract_code_globals(co): while i < n: op = code[i] - i = i+1 + i += 1 if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + oparg = code[i] + code[i+1] * 256 + extended_arg extended_arg = 0 - i = i+2 + i += 2 if op == EXTENDED_ARG: - extended_arg = oparg*65536L + extended_arg = oparg*65536 if op in GLOBAL_OPS: out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - if co.co_consts: # see if nested function have any global refs + # see if nested function have any global refs + if co.co_consts: for const in co.co_consts: if type(const) is types.CodeType: out_names |= CloudPickler.extract_code_globals(const) @@ -350,46 +284,28 @@ def extract_func_data(self, func): Turn the function into a tuple of data necessary to recreate it: code, globals, defaults, closure, dict """ - code = func.func_code + code = func.__code__ # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) + func_global_refs = self.extract_code_globals(code) # process all variables referenced by global environment f_globals = {} for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] + if var in func.__globals__: + f_globals[var] = func.__globals__[var] # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - + defaults = func.__defaults__ # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] + closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) + dct = func.__dict__ - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals + base_globals = self.globals_ref.get(id(func.__globals__), {}) + self.globals_ref[id(func.__globals__)] = base_globals return (code, f_globals, defaults, closure, dct, base_globals) @@ -400,8 +316,9 @@ def save_builtin_function(self, obj): dispatch[types.BuiltinFunctionType] = save_builtin_function def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) if name is None: name = obj.__name__ @@ -410,98 +327,57 @@ def save_global(self, obj, name=None, pack=struct.pack): if modname is None: modname = pickle.whichmodule(obj, name) - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - if modname == '__main__': themodule = None - - if themodule: + else: + __import__(modname) + themodule = sys.modules[modname] self.modules.add(themodule) - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - # print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise + if hasattr(themodule, name) and getattr(themodule, name) is obj: + return Pickler.save_global(self, obj, name) - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + d = dict(obj.__dict__) # copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): + # don't extract dict that are properties + d.pop('__dict__', None) + d.pop('__weakref__', None) # hack as __new__ is stored differently in the __dict__ new_override = d.get('__new__', None) if new_override: d['__new__'] = obj.__new__ - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return - - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) # restores function attributes def _restore_attr(obj, attr): @@ -851,13 +636,16 @@ def _restore_attr(obj, attr): setattr(obj, key, val) return obj + def _get_module_builtins(): return pickle.__builtins__ + def print_exec(stream): ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + def _modules_to_main(modList): """Force every module in modList to be placed into main""" if not modList: @@ -868,22 +656,16 @@ def _modules_to_main(modList): if type(modname) is str: try: mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) + except Exception as e: + sys.stderr.write('warning: could not import %s\n. ' + 'Your function may unexpectedly error due to this import failing;' + 'A version mismatch is likely. Specific error was:\n' % modname) print_exec(sys.stderr) else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) + setattr(main, mod.__name__, mod) -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) +#object generators: def _genpartial(func, args, kwds): if not args: args = () @@ -891,22 +673,26 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) + def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict + func.__globals__.update(globals) + func.__defaults__ = defaults + func.__dict__ = dict return func + def _make_cell(value): - return (lambda: value).func_closure[0] + return (lambda: value).__closure__[0] + def _reconstruct_closure(values): return tuple([_make_cell(v) for v in values]) + def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other @@ -928,40 +714,3 @@ def _make_skel_func(code, closures, base_globals = None): def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index dc7cd0bce56f3..924da3eecf214 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -44,7 +44,7 @@ >>> conf.get("spark.executorEnv.VAR1") u'value1' ->>> print conf.toDebugString() +>>> print(conf.toDebugString()) spark.executorEnv.VAR1=value1 spark.executorEnv.VAR3=value3 spark.executorEnv.VAR4=value4 @@ -56,6 +56,13 @@ __all__ = ['SparkConf'] +import sys +import re + +if sys.version > '3': + unicode = str + __doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__) + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 78dccc40470e3..1dc2fec0ae5c8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import os import shutil import sys @@ -32,11 +34,14 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD, _load_from_socket +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler +if sys.version > '3': + xrange = range + __all__ = ['SparkContext'] @@ -133,7 +138,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if sparkHome: self._conf.setSparkHome(sparkHome) if environment: - for key, value in environment.iteritems(): + for key, value in environment.items(): self._conf.setExecutorEnv(key, value) for key, value in DEFAULT_CONFIGS.items(): self._conf.setIfMissing(key, value) @@ -153,6 +158,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + # disable randomness of hash of string in worker, if this is not + # launched by spark-submit + self.environment["PYTHONHASHSEED"] = "0" # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) @@ -323,7 +332,7 @@ def parallelize(self, c, numSlices=None): start0 = c[0] def getStart(split): - return start0 + (split * size / numSlices) * step + return start0 + int((split * size / numSlices)) * step def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) @@ -357,6 +366,7 @@ def pickleFile(self, name, minPartitions=None): minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.objectFile(name, minPartitions), self) + @ignore_unicode_prefix def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all @@ -369,7 +379,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello world!") + ... _ = testFile.write("Hello world!") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello world!'] @@ -378,6 +388,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode)) + @ignore_unicode_prefix def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system @@ -411,9 +422,9 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: - ... file1.write("1") + ... _ = file1.write("1") >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: - ... file2.write("2") + ... _ = file2.write("2") >>> textFiles = sc.wholeTextFiles(dirPath) >>> sorted(textFiles.collect()) [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] @@ -456,7 +467,7 @@ def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: d = {} - for k, v in d.iteritems(): + for k, v in d.items(): jm[k] = v return jm @@ -608,6 +619,7 @@ def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) + @ignore_unicode_prefix def union(self, rdds): """ Build the union of a list of RDDs. @@ -618,7 +630,7 @@ def union(self, rdds): >>> path = os.path.join(tempdir, "union-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello") + ... _ = testFile.write("Hello") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello'] @@ -677,7 +689,7 @@ def addFile(self, path): >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: - ... testFile.write("100") + ... _ = testFile.write("100") >>> sc.addFile(path) >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: @@ -705,11 +717,13 @@ def addPyFile(self, path): """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() def setCheckpointDir(self, dirName): """ @@ -744,7 +758,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): The application can use L{SparkContext.cancelJobGroup} to cancel all running jobs in this group. - >>> import thread, threading + >>> import threading >>> from time import sleep >>> result = "Not Set" >>> lock = threading.Lock() @@ -763,10 +777,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") >>> supress = lock.acquire() - >>> supress = thread.start_new_thread(start_job, (10,)) - >>> supress = thread.start_new_thread(stop_job, tuple()) + >>> supress = threading.Thread(target=start_job, args=(10,)).start() + >>> supress = threading.Thread(target=stop_job).start() >>> supress = lock.acquire() - >>> print result + >>> print(result) Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 93885985fe377..7f06d4288c872 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -24,9 +24,10 @@ import traceback import time import gc -from errno import EINTR, ECHILD, EAGAIN +from errno import EINTR, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT + from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -53,8 +54,8 @@ def worker(sock): # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. - infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) exit_code = 0 try: worker_main(infile, outfile) @@ -68,17 +69,6 @@ def worker(sock): return exit_code -# Cleanup zombie children -def cleanup_dead_children(): - try: - while True: - pid, _ = os.waitpid(0, os.WNOHANG) - if not pid: - break - except: - pass - - def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -88,8 +78,12 @@ def manager(): listen_sock.bind(('127.0.0.1', 0)) listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() - write_int(listen_port, sys.stdout) - sys.stdout.flush() + + # re-open stdin/stdout in 'wb' mode + stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) + stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) + write_int(listen_port, stdout_bin) + stdout_bin.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -101,6 +95,7 @@ def handle_sigterm(*args): shutdown(1) signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + signal.signal(SIGCHLD, SIG_IGN) reuse = os.environ.get("SPARK_REUSE_WORKER") @@ -115,12 +110,9 @@ def handle_sigterm(*args): else: raise - # cleanup in signal handler will cause deadlock - cleanup_dead_children() - if 0 in ready_fds: try: - worker_pid = read_int(sys.stdin) + worker_pid = read_int(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) @@ -145,7 +137,7 @@ def handle_sigterm(*args): time.sleep(1) pid = os.fork() # error here will shutdown daemon else: - outfile = sock.makefile('w') + outfile = sock.makefile(mode='wb') write_int(e.errno, outfile) # Signal that the fork failed outfile.flush() outfile.close() @@ -157,7 +149,7 @@ def handle_sigterm(*args): listen_sock.close() try: # Acknowledge that the fork was successful - outfile = sock.makefile("w") + outfile = sock.makefile(mode="wb") write_int(os.getpid(), outfile) outfile.flush() outfile.close() diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index bc441f138f7fc..4ef2afe03544f 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False): if key is None: for order, it in enumerate(map(iter, iterables)): try: - next = it.next - h_append([next(), order * direction, next]) + h_append([next(it), order * direction, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + value, order, it = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted + s[0] = next(it) # raises StopIteration when exhausted _heapreplace(h, s) # restore heap condition except StopIteration: _heappop(h) # remove empty iterator if h: # fast case when only a single iterator remains - value, order, next = h[0] + value, order, it = h[0] yield value - for value in next.__self__: + for value in it: yield value return for order, it in enumerate(map(iter, iterables)): try: - next = it.next - value = next() - h_append([key(value), order * direction, value, next]) + value = next(it) + h_append([key(value), order * direction, value, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - key_value, order, value, next = s = h[0] + key_value, order, value, it = s = h[0] yield value - value = next() + value = next(it) s[0] = key(value) s[2] = value _heapreplace(h, s) except StopIteration: _heappop(h) if h: - key_value, order, value, next = h[0] + key_value, order, value, it = h[0] yield value - for value in next.__self__: + for value in it: yield value diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2a5e84a7dfdb4..45bc38f7e61f8 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -69,7 +69,7 @@ def preexec_func(): if callback_socket in readable: gateway_connection = callback_socket.accept()[0] # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile()) + gateway_port = read_int(gateway_connection.makefile(mode="rb")) gateway_connection.close() callback_socket.close() if gateway_port is None: diff --git a/python/pyspark/join.py b/python/pyspark/join.py index c3491defb2b29..94df3990164d6 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -32,6 +32,7 @@ """ from pyspark.resultiterable import ResultIterable +from functools import reduce def _do_python_join(rdd, other, numPartitions, dispatch): diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7f42de531f3b4..45754bc9d4b10 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -39,10 +39,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr = LogisticRegression(maxIter=5, regParam=0.01) >>> model = lr.fit(df) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() - >>> print model.transform(test0).head().prediction + >>> model.transform(test0).head().prediction 0.0 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() - >>> print model.transform(test1).head().prediction + >>> model.transform(test1).head().prediction 1.0 >>> lr.setParams("vector") Traceback (most recent call last): @@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.1) """ super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def _create_model(self, java_model): return LogisticRegressionModel(java_model) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1cfcd019dfb18..4e4614b859ac6 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaTransformer @@ -24,6 +25,7 @@ @inherit_doc +@ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A tokenizer that converts the input string to lowercase and then @@ -32,15 +34,15 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(text="a b c")]).toDF() >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") - >>> print tokenizer.transform(df).head() + >>> tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> # Change a parameter. - >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() + >>> tokenizer.setParams(outputCol="tokens").transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. - >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(df).head() + >>> tokenizer.transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. >>> tokenizer.setParams("text") @@ -52,22 +54,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): _java_class = "org.apache.spark.ml.feature.Tokenizer" @keyword_only - def __init__(self, inputCol="input", outputCol="output"): + def __init__(self, inputCol=None, outputCol=None): """ - __init__(self, inputCol="input", outputCol="output") + __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol="input", outputCol="output"): + def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") Sets params for this Tokenizer. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) @inherit_doc @@ -79,34 +81,35 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - >>> print hashingTF.transform(df).head().features - (10,[7,8,9],[1.0,1.0,1.0]) - >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs - (10,[7,8,9],[1.0,1.0,1.0]) + >>> hashingTF.transform(df).head().features + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(df, params).head().vector - (5,[2,3,4],[1.0,1.0,1.0]) + >>> hashingTF.transform(df, params).head().vector + SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) """ _java_class = "org.apache.spark.ml.feature.HashingTF" @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) Sets params for this HashingTF. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index e3a53dd780c4c..9fccb65675185 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -25,23 +25,21 @@ class Param(object): """ - A param with self-contained documentation and optionally default value. + A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): - if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) self.parent = parent self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) class Params(Identifiable): @@ -52,26 +50,128 @@ class Params(Identifiable): __metaclass__ = ABCMeta - def __init__(self): - super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ - return filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"]) + return list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) + + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. + """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) return paramMap + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ @@ -81,10 +181,18 @@ def _dummy(): dummy.uid = "undefined" return dummy - def _set_params(self, **kwargs): + def _set(self, **kwargs): """ - Sets params. + Sets user-supplied params. """ - for param, value in kwargs.iteritems(): + for param, value in kwargs.items(): self.paramMap[getattr(self, param)] = value return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.items(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py similarity index 67% rename from python/pyspark/ml/param/_gen_shared_params.py rename to python/pyspark/ml/param/_shared_params_code_gen.py index 5eb81106f116c..6a3192465d66d 100644 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + header = """# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -32,29 +34,34 @@ # limitations under the License. #""" +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py + -def _gen_param_code(name, doc, defaultValue): +def _gen_param_code(name, doc, defaultValueStr): """ Generates Python code for a shared param class. :param name: param name :param doc: param doc - :param defaultValue: string representation of the param + :param defaultValueStr: string representation of the default value :return: code string """ # TODO: How to correctly inherit instance attributes? template = '''class Has$Name(Params): """ - Params with $name. + Mixin for param $name: $doc. """ # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + $name = Param(Params._dummy(), "$name", "$doc") def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) def set$Name(self, value): """ @@ -67,32 +74,29 @@ def get$Name(self): """ Gets the value of $name or its default value. """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' + return self.getOrDefault(self.$name)''' - upperCamelName = name[0].upper() + name[1:] + Name = name[0].upper() + name[1:] return template \ .replace("$name", name) \ - .replace("$Name", upperCamelName) \ + .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) + .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": - print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" - print "from pyspark.ml.param import Param, Params\n\n" + print(header) + print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n") + print("from pyspark.ml.param import Param, Params\n\n") shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) - print "\n\n\n".join(code) + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) + print("\n\n\n".join(code)) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de423..13b6749998ad0 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,23 +15,25 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations. """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) def setMaxIter(self, value): """ @@ -44,24 +46,23 @@ def getMaxIter(self): """ Gets the value of maxIter or its default value. """ - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization constant. """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization constant") def __init__(self): super(HasRegParam, self).__init__() #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) def setRegParam(self, value): """ @@ -74,24 +75,23 @@ def getRegParam(self): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ @@ -104,24 +104,23 @@ def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') def setLabelCol(self, value): """ @@ -134,24 +133,23 @@ def getLabelCol(self): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ @@ -164,24 +162,23 @@ def getPredictionCol(self): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) def setInputCol(self, value): """ @@ -194,24 +191,23 @@ def getInputCol(self): """ Gets the value of inputCol or its default value. """ - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) def setOutputCol(self, value): """ @@ -224,24 +220,23 @@ def getOutputCol(self): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ @@ -254,7 +249,4 @@ def getNumFeatures(self): """ Gets the value of numFeatures or its default value. """ - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 83880a5afcd1d..d94ecfff09f66 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -124,10 +124,10 @@ def setParams(self, stages=[]): Sets params for Pipeline. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def fit(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -164,7 +164,7 @@ def __init__(self, transformers): self.transformers = transformers def transform(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e930b..3a42bcf723894 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -33,6 +33,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol from pyspark.ml.pipeline import Transformer, Estimator, Pipeline @@ -46,7 +47,7 @@ class MockTransformer(Transformer): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None @@ -62,7 +63,7 @@ class MockEstimator(Estimator): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None self.model = None @@ -111,5 +112,52 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. + """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6f7f39c40eb5a..d3cb100a9efa5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,8 +40,8 @@ class Identifiable(object): def __init__(self): #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: concatenates the class name, "_", and 8 random hex chars. + self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] def __repr__(self): return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 31a66b3d2f730..394f23c5e9b12 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -64,7 +64,7 @@ def _transfer_params_to_java(self, params, java_obj): :param params: additional params (overwriting embedded values) :param java_obj: Java object to receive the params """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: java_obj.set(param.name, paramMap[param]) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index f2ef573fe9f6f..07507b2ad0d05 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -18,6 +18,7 @@ """ Python bindings for MLlib. """ +from __future__ import absolute_import # MLlib currently needs NumPy 1.4+, so complain if lower @@ -29,7 +30,9 @@ 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys -import rand as random -random.__name__ = 'random' -random.RandomRDDs.__module__ = __name__ + '.random' -sys.modules[__name__ + '.random'] = random +from . import rand as random +modname = __name__ + '.random' +random.__name__ = modname +random.RandomRDDs.__module__ = modname +sys.modules[modname] = random +del modname, sys diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2466e8ac43458..eda0b60f8b1e7 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -510,9 +510,10 @@ def save(self, sc, path): def load(cls, sc, path): java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( sc._jsc.sc(), path) - py_labels = _java2py(sc, java_model.labels()) - py_pi = _java2py(sc, java_model.pi()) - py_theta = _java2py(sc, java_model.theta()) + # Can not unpickle array.array from Pyrolite in Python3 with "bytes" + py_labels = _java2py(sc, java_model.labels(), "latin1") + py_pi = _java2py(sc, java_model.pi(), "latin1") + py_theta = _java2py(sc, java_model.theta(), "latin1") return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 464f49aeee3cd..abbb7cf60eece 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,6 +15,12 @@ # limitations under the License. # +import sys +import array as pyarray + +if sys.version > '3': + xrange = range + from numpy import array from pyspark import RDD @@ -55,8 +61,8 @@ class KMeansModel(Saveable, Loader): True >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3]) True - >>> type(model.clusterCenters) - + >>> isinstance(model.clusterCenters, list) + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -90,7 +96,7 @@ def predict(self, x): return best def save(self, sc, path): - java_centers = _py2java(sc, map(_convert_to_vector, self.centers)) + java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) java_model.save(sc._jsc.sc(), path) @@ -133,7 +139,7 @@ class GaussianMixtureModel(object): ... 5.7048, 4.6567, 5.5026, ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, - ... maxIterations=150, seed=10) + ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1]==labels[2] True @@ -168,8 +174,8 @@ def predictSoft(self, x): if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - self.weights, means, sigmas) - return membership_matrix + _convert_to_vector(self.weights), means, sigmas) + return membership_matrix.map(lambda x: pyarray.array('d', x)) class GaussianMixture(object): diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a539d2f2846f9..ba6058978880a 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,6 +15,11 @@ # limitations under the License. # +import sys +if sys.version >= '3': + long = int + unicode = str + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -36,7 +41,7 @@ def _new_smart_decode(obj): if isinstance(obj, float): - s = unicode(obj) + s = str(obj) return _float_str_mapping.get(s, s) return _old_smart_decode(obj) @@ -74,15 +79,15 @@ def _py2java(sc, obj): obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, basestring)): + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): pass else: - bytes = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.SerDe.loads(bytes) + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(data) return obj -def _java2py(sc, r): +def _java2py(sc, r, encoding="bytes"): if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -102,8 +107,8 @@ def _java2py(sc, r): except Py4JJavaError: pass # not pickable - if isinstance(r, bytearray): - r = PickleSerializer().loads(str(r)) + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) return r diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 8be819aceec24..1140539a24e95 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -23,12 +23,17 @@ import sys import warnings import random +import binascii +if sys.version >= '3': + basestring = str + unicode = str from py4j.protocol import Py4JJavaError -from pyspark import RDD, SparkContext +from pyspark import SparkContext +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -206,7 +211,7 @@ class HashingTF(object): >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) - SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): """ @@ -360,6 +365,7 @@ def getVectors(self): return self.call("getVectors") +@ignore_unicode_prefix class Word2Vec(object): """ Word2Vec creates vector representation of words in a text corpus. @@ -382,7 +388,7 @@ class Word2Vec(object): >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) - >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] @@ -400,7 +406,7 @@ def __init__(self): self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = random.randint(0, sys.maxint) + self.seed = random.randint(0, sys.maxsize) self.minCount = 5 def setVectorSize(self, vectorSize): @@ -459,7 +465,7 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed), + int(self.numIterations), int(self.seed), int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 3aa6d79d7093c..628ccc01cf3cc 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -16,12 +16,14 @@ # from pyspark import SparkContext +from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel'] @inherit_doc +@ignore_unicode_prefix class FPGrowthModel(JavaModelWrapper): """ diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index a80320c52d1d0..38b3aa3ad460e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -25,7 +25,13 @@ import sys import array -import copy_reg + +if sys.version >= '3': + basestring = str + xrange = range + import copyreg as copy_reg +else: + import copy_reg import numpy as np @@ -57,7 +63,7 @@ def fast_pickle_array(ar): def _convert_to_vector(l): if isinstance(l, Vector): return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple): + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" @@ -88,7 +94,7 @@ def _vector_size(v): """ if isinstance(v, Vector): return len(v) - elif type(v) in (array.array, list, tuple): + elif type(v) in (array.array, list, tuple, xrange): return len(v) elif type(v) == np.ndarray: if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): @@ -193,7 +199,7 @@ class DenseVector(Vector): DenseVector([1.0, 0.0]) """ def __init__(self, ar): - if isinstance(ar, basestring): + if isinstance(ar, bytes): ar = np.frombuffer(ar, dtype=np.float64) elif not isinstance(ar, np.ndarray): ar = np.array(ar, dtype=np.float64) @@ -321,11 +327,13 @@ def func(self, other): __sub__ = _delegate("__sub__") __mul__ = _delegate("__mul__") __div__ = _delegate("__div__") + __truediv__ = _delegate("__truediv__") __mod__ = _delegate("__mod__") __radd__ = _delegate("__radd__") __rsub__ = _delegate("__rsub__") __rmul__ = _delegate("__rmul__") __rdiv__ = _delegate("__rdiv__") + __rtruediv__ = _delegate("__rtruediv__") __rmod__ = _delegate("__rmod__") @@ -344,12 +352,12 @@ def __init__(self, size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. - >>> print SparseVector(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> SparseVector(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" @@ -361,8 +369,8 @@ def __init__(self, size, *args): self.indices = np.array([p[0] for p in pairs], dtype=np.int32) self.values = np.array([p[1] for p in pairs], dtype=np.float64) else: - if isinstance(args[0], basestring): - assert isinstance(args[1], str), "values should be string too" + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" if args[0]: self.indices = np.frombuffer(args[0], np.int32) self.values = np.frombuffer(args[1], np.float64) @@ -591,12 +599,12 @@ def sparse(size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. - >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ return SparseVector(size, *args) @@ -645,7 +653,7 @@ def _convert_to_array(array_like, dtype): """ Convert Matrix attributes which are array-like or buffer to array. """ - if isinstance(array_like, basestring): + if isinstance(array_like, bytes): return np.frombuffer(array_like, dtype=dtype) return np.asarray(array_like, dtype=dtype) @@ -677,7 +685,7 @@ def toArray(self): def toSparse(self): """Convert to SparseMatrix""" indices = np.nonzero(self.values)[0] - colCounts = np.bincount(indices / self.numRows) + colCounts = np.bincount(indices // self.numRows) colPtrs = np.cumsum(np.hstack( (0, colCounts, np.zeros(self.numCols - colCounts.size)))) values = self.values[indices] diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py index 20ee9d78bf5b0..06fbc0eb6aef0 100644 --- a/python/pyspark/mllib/rand.py +++ b/python/pyspark/mllib/rand.py @@ -88,10 +88,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): :param seed: Random seed (default: a random long integer). :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - 0.0) < 0.1 True >>> abs(stats.stdev() - 1.0) < 0.1 @@ -118,10 +118,10 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L) + >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> from math import sqrt @@ -145,10 +145,10 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -171,10 +171,10 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). >>> mean = 2.0 - >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -202,10 +202,10 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L) + >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> abs(stats.stdev() - expStd) < 0.5 @@ -254,7 +254,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -286,8 +286,8 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \ - 100, 100, seed=1L).collect()) + >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() + >>> mat = np.matrix(m) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 @@ -315,7 +315,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -345,7 +345,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No >>> import numpy as np >>> mean = 0.5 - >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -380,8 +380,7 @@ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed= >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \ - 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index c5c4c13dae105..80e0a356bb78a 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import array from collections import namedtuple from pyspark import SparkContext @@ -104,14 +105,14 @@ def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() assert len(first) == 2, "user_product should be RDD of (user, product)" - user_product = user_product.map(lambda (u, p): (int(u), int(p))) + user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) def userFeatures(self): - return self.call("getUserFeatures") + return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) def productFeatures(self): - return self.call("getProductFeatures") + return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) @classmethod def load(cls, sc, path): diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 1d83e9d483f8e..b475be4b4d953 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark import RDD +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -38,7 +38,7 @@ def variance(self): return self.call("variance").toArray() def count(self): - return self.call("count") + return int(self.call("count")) def numNonzeros(self): return self.call("numNonzeros").toArray() @@ -78,7 +78,7 @@ def colStats(rdd): >>> cStats.variance() array([ 4., 13., 0., 25.]) >>> cStats.count() - 3L + 3 >>> cStats.numNonzeros() array([ 3., 2., 0., 3.]) >>> cStats.max() @@ -124,20 +124,20 @@ def corr(x, y=None, method=None): >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) >>> pearsonCorr = Statistics.corr(rdd) - >>> print str(pearsonCorr).replace('nan', 'NaN') + >>> print(str(pearsonCorr).replace('nan', 'NaN')) [[ 1. 0.05564149 NaN 0.40047142] [ 0.05564149 1. NaN 0.91359586] [ NaN NaN 1. NaN] [ 0.40047142 0.91359586 NaN 1. ]] >>> spearmanCorr = Statistics.corr(rdd, method="spearman") - >>> print str(spearmanCorr).replace('nan', 'NaN') + >>> print(str(spearmanCorr).replace('nan', 'NaN')) [[ 1. 0.10540926 NaN 0.4 ] [ 0.10540926 1. NaN 0.9486833 ] [ NaN NaN 1. NaN] [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") - ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... print("Method name as second argument without 'method=' shouldn't be allowed.") ... except TypeError: ... pass """ @@ -153,6 +153,7 @@ def corr(x, y=None, method=None): return callMLlibFunc("corr", x.map(float), y.map(float), method) @staticmethod + @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ .. note:: Experimental @@ -188,11 +189,11 @@ def chiSqTest(observed, expected=None): >>> from pyspark.mllib.linalg import Vectors, Matrices >>> observed = Vectors.dense([4, 6, 5]) >>> pearson = Statistics.chiSqTest(observed) - >>> print pearson.statistic + >>> print(pearson.statistic) 0.4 >>> pearson.degreesOfFreedom 2 - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.8187 >>> pearson.method u'pearson' @@ -202,12 +203,12 @@ def chiSqTest(observed, expected=None): >>> observed = Vectors.dense([21, 38, 43, 80]) >>> expected = Vectors.dense([3, 5, 7, 20]) >>> pearson = Statistics.chiSqTest(observed, expected) - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.0027 >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - >>> print round(chi.statistic, 4) + >>> print(round(chi.statistic, 4)) 21.9958 >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), @@ -218,9 +219,9 @@ def chiSqTest(observed, expected=None): ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] >>> rdd = sc.parallelize(data, 4) >>> chi = Statistics.chiSqTest(rdd) - >>> print chi[0].statistic + >>> print(chi[0].statistic) 0.75 - >>> print chi[1].statistic + >>> print(chi[1].statistic) 1.5 """ if isinstance(observed, RDD): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 8eaddcf8b9b5e..c6ed5acd1770e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -72,11 +72,11 @@ class VectorTests(PySparkTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) + nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs))) + nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): @@ -412,11 +412,11 @@ def test_col_norms(self): self.assertEqual(10, len(summary.normL1())) self.assertEqual(10, len(summary.normL2())) - data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) summary2 = Statistics.colStats(data2) self.assertEqual(array([45.0]), summary2.normL1()) import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) @@ -438,11 +438,11 @@ def test_serialization(self): def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema + df = rdd.toDF() + schema = df.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) - vectors = srdd.map(lambda p: p.features).collect() + vectors = df.map(lambda p: p.features).collect() self.assertEqual(len(vectors), 2) for v in vectors: if isinstance(v, SparseVector): @@ -695,7 +695,7 @@ def test_right_number_of_results(self): class SerDeTest(PySparkTestCase): def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) self.assertEqual(_to_java_object_rdd(data).count(), 10) @@ -771,7 +771,7 @@ def test_model_transform(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a7a4d2aaf855b..0fe6e4fabe43a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -163,14 +163,16 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) - >>> print model, # it already has newline + >>> print(model) DecisionTreeModel classifier of depth 1 with 3 nodes - >>> print model.toDebugString(), # it already has newline + + >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.0) Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 + >>> model.predict(array([1.0])) 1.0 >>> model.predict(array([0.0])) @@ -318,9 +320,10 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 3 >>> model.totalNumNodes() 7 - >>> print model, + >>> print(model) TreeEnsembleModel classifier with 3 trees - >>> print model.toDebugString(), + + >>> print(model.toDebugString()) TreeEnsembleModel classifier with 3 trees Tree 0: @@ -335,6 +338,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Predict: 0.0 Else (feature 0 > 1.0) Predict: 1.0 + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) @@ -483,8 +487,9 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, 100 >>> model.totalNumNodes() 300 - >>> print model, # it already has newline + >>> print(model) # it already has newline TreeEnsembleModel classifier with 100 trees + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index c5c3468eb95e9..16a90db146ef0 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -15,10 +15,14 @@ # limitations under the License. # +import sys import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc +if sys.version > '3': + xrange = range + +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -94,22 +98,16 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() - >>> type(examples[0]) == LabeledPoint - True - >>> print examples[0] - (1.0,(6,[0,2,4],[1.0,2.0,3.0])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (-1.0,(6,[],[])) - >>> type(examples[2]) == LabeledPoint - True - >>> print examples[2] - (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) + >>> examples[0] + LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0])) + >>> examples[1] + LabeledPoint(-1.0, (6,[],[])) + >>> examples[2] + LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0])) """ from pyspark.mllib.regression import LabeledPoint if multiclass is not None: diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 4408996db0790..d18daaabfcb3c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -84,11 +84,11 @@ class Profiler(object): >>> from pyspark import BasicProfiler >>> class MyCustomProfiler(BasicProfiler): ... def show(self, id): - ... print "My custom profiles for RDD:%s" % id + ... print("My custom profiles for RDD:%s" % id) ... >>> conf = SparkConf().set("spark.python.profile", "true") >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) - >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.show_profiles() My custom profiles for RDD:1 @@ -111,9 +111,9 @@ def show(self, id): """ Print the profile stats to stdout, id is the RDD id """ stats = self.stats() if stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 + print("=" * 60) + print("Profile of RDD" % id) + print("=" * 60) stats.sort_stats("time", "cumulative").print_stats() def dump(self, id, path): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c9ac95d117574..d9cdbb666f92a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -16,21 +16,29 @@ # import copy -from collections import defaultdict -from itertools import chain, ifilter, imap -import operator import sys +import os +import re +import operator import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread import warnings import heapq import bisect import random import socket +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread +from collections import defaultdict +from itertools import chain +from functools import reduce from math import sqrt, log, isinf, isnan, pow, ceil +if sys.version > '3': + basestring = unicode = str +else: + from itertools import imap as map, ifilter as filter + from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -50,20 +58,21 @@ __all__ = ["RDD"] -# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized -# hash for string def portable_hash(x): """ - This function returns consistant hash code for builtin types, especially + This function returns consistent hash code for builtin types, especially for None and tuple with None. - The algrithm is similar to that one used by CPython 2.7 + The algorithm is similar to that one used by CPython 2.7 >>> portable_hash(None) 0 >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") + if x is None: return 0 if isinstance(x, tuple): @@ -71,7 +80,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= sys.maxint + h &= sys.maxsize h ^= len(x) if h == -1: h = -2 @@ -123,6 +132,19 @@ def _load_from_socket(port, serializer): sock.close() +def ignore_unicode_prefix(f): + """ + Ignore the 'u' prefix of string in doc tests, to make it works + in both python 2 and 3 + """ + if sys.version >= '3': + # the representation of unicode string in Python 3 does not have prefix 'u', + # so remove the prefix 'u' for doc tests + literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE) + f.__doc__ = literal_re.sub(r'\1\2', f.__doc__) + return f + + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -251,7 +273,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -266,7 +288,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -329,7 +351,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -341,7 +363,7 @@ def distinct(self, numPartitions=None): """ return self.map(lambda x: (x, None)) \ .reduceByKey(lambda x, _: x, numPartitions) \ - .map(lambda (x, _): x) + .map(lambda x: x[0]) def sample(self, withReplacement, fraction, seed=None): """ @@ -354,8 +376,8 @@ def sample(self, withReplacement, fraction, seed=None): :param seed: seed for the random number generator >>> rdd = sc.parallelize(range(100), 4) - >>> rdd.sample(False, 0.1, 81).count() - 10 + >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 + True """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -368,12 +390,14 @@ def randomSplit(self, weights, seed=None): :param seed: random seed :return: split RDDs in a list - >>> rdd = sc.parallelize(range(5), 1) + >>> rdd = sc.parallelize(range(500), 1) >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) - >>> rdd1.collect() - [1, 3] - >>> rdd2.collect() - [0, 2, 4] + >>> len(rdd1.collect() + rdd2.collect()) + 500 + >>> 150 < rdd1.count() < 250 + True + >>> 250 < rdd2.count() < 350 + True """ s = float(sum(weights)) cweights = [0.0] @@ -416,7 +440,7 @@ def takeSample(self, withReplacement, num, seed=None): rand.shuffle(samples) return samples - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) if num > maxSampleSize: raise ValueError( "Sample size cannot be greater than %d." % maxSampleSize) @@ -430,7 +454,7 @@ def takeSample(self, withReplacement, num, seed=None): # See: scala/spark/RDD.scala while len(samples) < num: # TODO: add log warning for when more than one iteration was run - seed = rand.randint(0, sys.maxint) + seed = rand.randint(0, sys.maxsize) samples = self.sample(withReplacement, fraction, seed).collect() rand.shuffle(samples) @@ -507,7 +531,7 @@ def intersection(self, other): """ return self.map(lambda v: (v, None)) \ .cogroup(other.map(lambda v: (v, None))) \ - .filter(lambda (k, vs): all(vs)) \ + .filter(lambda k_vs: all(k_vs[1])) \ .keys() def _reserialize(self, serializer=None): @@ -549,7 +573,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -579,7 +603,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: if self.getNumPartitions() > 1: @@ -594,12 +618,12 @@ def sortPartition(iterator): return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect() samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary - bounds = [samples[len(samples) * (i + 1) / numPartitions] + bounds = [samples[int(len(samples) * (i + 1) / numPartitions)] for i in range(0, numPartitions - 1)] def rangePartitioner(k): @@ -662,12 +686,13 @@ def groupBy(self, f, numPartitions=None): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + @ignore_unicode_prefix def pipe(self, command, env={}): """ Return an RDD created by piping elements to a forked external process. >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect() - ['1', '2', '', '3'] + [u'1', u'2', u'', u'3'] """ def func(iterator): pipe = Popen( @@ -675,17 +700,18 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') + s = str(obj).rstrip('\n') + '\n' + out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in iter(pipe.stdout.readline, '')) + return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b'')) return self.mapPartitions(func) def foreach(self, f): """ Applies a function to all elements of this RDD. - >>> def f(x): print x + >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ def processPartition(iterator): @@ -700,7 +726,7 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: - ... print x + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): @@ -874,7 +900,7 @@ def aggregatePartition(iterator): # aggregation. while numPartitions > scale + numPartitions / scale: numPartitions /= scale - curNumPartitions = numPartitions + curNumPartitions = int(numPartitions) def mapPartition(i, iterator): for obj in iterator: @@ -984,7 +1010,7 @@ def histogram(self, buckets): (('a', 'b', 'c'), [2, 2]) """ - if isinstance(buckets, (int, long)): + if isinstance(buckets, int): if buckets < 1: raise ValueError("number of buckets must be >= 1") @@ -1020,6 +1046,7 @@ def minmax(a, b): raise ValueError("Can not generate buckets with infinite value") # keep them as integer if possible + inc = int(inc) if inc * buckets != maxv - minv: inc = (maxv - minv) * 1.0 / buckets @@ -1137,7 +1164,7 @@ def countPartition(iterator): yield counts def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) @@ -1197,7 +1224,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1260,7 +1287,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1378,8 +1405,8 @@ def saveAsPickleFile(self, path, batchSize=10): >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3) - >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) - [1, 2, 'rdd', 'spark'] + >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect()) + ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: ser = AutoBatchedSerializer(PickleSerializer()) @@ -1387,6 +1414,7 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) + @ignore_unicode_prefix def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. @@ -1418,12 +1446,13 @@ def saveAsTextFile(self, path, compressionCodecClass=None): >>> codec = "org.apache.hadoop.io.compress.GzipCodec" >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) >>> from fileinput import input, hook_compressed - >>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))) - 'bar\\nfoo\\n' + >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)) + >>> b''.join(result).decode('utf-8') + u'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: - if not isinstance(x, basestring): + if not isinstance(x, (unicode, bytes)): x = unicode(x) if isinstance(x, unicode): x = x.encode("utf-8") @@ -1458,7 +1487,7 @@ def keys(self): >>> m.collect() [1, 3] """ - return self.map(lambda (k, v): k) + return self.map(lambda x: x[0]) def values(self): """ @@ -1468,7 +1497,7 @@ def values(self): >>> m.collect() [2, 4] """ - return self.map(lambda (k, v): v) + return self.map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ @@ -1507,7 +1536,7 @@ def reducePartition(iterator): yield m def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1604,8 +1633,8 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) + >>> len(set(sets[0]).intersection(set(sets[1]))) + 0 """ if numPartitions is None: numPartitions = self._defaultReducePartitions() @@ -1637,22 +1666,22 @@ def add_shuffle_key(split, iterator): if (c % 1000 == 0 and get_used_memory() > limit or c > batch): n, size = len(buckets), 0 - for split in buckets.keys(): + for split in list(buckets.keys()): yield pack_long(split) d = outputSerializer.dumps(buckets[split]) del buckets[split] yield d size += len(d) - avg = (size / n) >> 20 + avg = int(size / n) >> 20 # let 1M < avg < 10M if avg < 1: batch *= 1.5 elif avg > 10: - batch = max(batch / 1.5, 1) + batch = max(int(batch / 1.5), 1) c = 0 - for split, items in buckets.iteritems(): + for split, items in buckets.items(): yield pack_long(split) yield outputSerializer.dumps(items) @@ -1707,7 +1736,7 @@ def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) @@ -1716,7 +1745,7 @@ def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) - return merger.iteritems() + return merger.items() return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) @@ -1745,7 +1774,7 @@ def foldByKey(self, zeroValue, func, numPartitions=None): >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add - >>> rdd.foldByKey(0, add).collect() + >>> sorted(rdd.foldByKey(0, add).collect()) [('a', 2), ('b', 1)] """ def createZero(): @@ -1769,10 +1798,10 @@ def groupByKey(self, numPartitions=None): sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.groupByKey().mapValues(len).collect()) + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.groupByKey().mapValues(len).collect()) [('a', 2), ('b', 1)] - >>> sorted(x.groupByKey().mapValues(list).collect()) + >>> sorted(rdd.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ def createCombiner(x): @@ -1795,7 +1824,7 @@ def combine(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) @@ -1804,7 +1833,7 @@ def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer)\ if spill else InMemoryMerger(agg) merger.mergeCombiners(it) - return merger.iteritems() + return merger.items() return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) @@ -1819,7 +1848,7 @@ def flatMapValues(self, f): >>> x.flatMapValues(f).collect() [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')] """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def mapValues(self, f): @@ -1833,7 +1862,7 @@ def mapValues(self, f): >>> x.mapValues(f).collect() [('a', 3), ('b', 1)] """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def groupWith(self, other, *others): @@ -1844,8 +1873,7 @@ def groupWith(self, other, *others): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) >>> z = sc.parallelize([("b", 42)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ - sorted(list(w.groupWith(x, y, z).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))] [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] """ @@ -1860,7 +1888,7 @@ def cogroup(self, other, numPartitions=None): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))] [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup((self, other), numPartitions) @@ -1896,8 +1924,9 @@ def subtractByKey(self, other, numPartitions=None): >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] """ - def filter_func((key, vals)): - return vals[0] and not vals[1] + def filter_func(pair): + key, (val1, val2) = pair + return val1 and not val2 return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) def subtract(self, other, numPartitions=None): @@ -1919,8 +1948,8 @@ def keyBy(self, f): >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) >>> y = sc.parallelize(zip(range(0,5), range(0,5))) - >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect())) - [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + >>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())] + [(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])] """ return self.map(lambda x: (f(x), x)) @@ -2049,17 +2078,18 @@ def name(self): """ Return the name of this RDD. """ - name_ = self._jrdd.name() - if name_: - return name_.encode('utf-8') + n = self._jrdd.name() + if n: + return n + @ignore_unicode_prefix def setName(self, name): """ Assign a name to this RDD. - >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1 = sc.parallelize([1, 2]) >>> rdd1.setName('RDD1').name() - 'RDD1' + u'RDD1' """ self._jrdd.setName(name) return self @@ -2121,7 +2151,7 @@ def lookup(self, key): >>> sorted.lookup(1024) [] """ - values = self.filter(lambda (k, v): k == key).values() + values = self.filter(lambda kv: kv[0] == key).values() if self.partitioner is not None: return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) @@ -2159,7 +2189,7 @@ def sumApprox(self, timeout, confidence=0.95): or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) + >>> r = sum(range(1000)) >>> (rdd.sumApprox(1000) - r) / r < 0.05 True """ @@ -2176,7 +2206,7 @@ def meanApprox(self, timeout, confidence=0.95): or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) / 1000.0 + >>> r = sum(range(1000)) / 1000.0 >>> (rdd.meanApprox(1000) - r) / r < 0.05 True """ @@ -2201,10 +2231,10 @@ def countApproxDistinct(self, relativeSD=0.05): It must be greater than 0.000017. >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() - >>> 950 < n < 1050 + >>> 900 < n < 1100 True >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() - >>> 18 < n < 22 + >>> 16 < n < 24 True """ if relativeSD < 0.000017: @@ -2223,8 +2253,7 @@ def toLocalIterator(self): >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - partitions = xrange(self.getNumPartitions()) - for partition in partitions: + for partition in range(self.getNumPartitions()): rows = self.context.runJob(self, lambda x: x, [partition]) for row in rows: yield row @@ -2235,11 +2264,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): ser = CloudPickleSerializer() pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # tracking the life cycle by obj - if obj is not None: - obj._broadcast = broadcast broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) @@ -2294,12 +2321,9 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None - self._broadcast = None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 459e1427803cb..fe8f87324804b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -23,7 +23,7 @@ class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - self._seed = seed if seed is not None else random.randint(0, sys.maxint) + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) self._withReplacement = withReplacement self._random = None @@ -31,7 +31,7 @@ def initRandomGenerator(self, split): self._random = random.Random(self._seed ^ split) # mixing because the initial seeds are close to each other - for _ in xrange(10): + for _ in range(10): self._random.randint(0, 1) def getUniformSample(self): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4afa82f4b2973..d8cdcda3a3783 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -49,16 +49,24 @@ >>> sc.stop() """ -import cPickle -from itertools import chain, izip, product +import sys +from itertools import chain, product import marshal import struct -import sys import types import collections import zlib import itertools +if sys.version < '3': + import cPickle as pickle + protocol = 2 + from itertools import izip as zip +else: + import pickle + protocol = 3 + xrange = range + from pyspark import cloudpickle @@ -97,7 +105,7 @@ def _load_stream_without_unbatching(self, stream): # subclasses should override __eq__ as appropriate. def __eq__(self, other): - return isinstance(other, self.__class__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -212,10 +220,6 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): return self.serializer.load_stream(stream) - def __eq__(self, other): - return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer and other.batchSize == self.batchSize) - def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) @@ -233,14 +237,14 @@ def __init__(self, serializer, batchSize=10): def _batched(self, iterator): n = self.batchSize for key, values in iterator: - for i in xrange(0, len(values), n): + for i in range(0, len(values), n): yield key, values[i:i + n] def load_stream(self, stream): return self.serializer.load_stream(stream) def __repr__(self): - return "FlattenedValuesSerializer(%d)" % self.batchSize + return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -270,12 +274,8 @@ def dump_stream(self, iterator, stream): elif size > best * 10 and batch > 1: batch /= 2 - def __eq__(self, other): - return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer and other.bestSize == self.bestSize) - def __repr__(self): - return "AutoBatchedSerializer(%s)" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % self.serializer class CartesianDeserializer(FramedSerializer): @@ -285,6 +285,7 @@ class CartesianDeserializer(FramedSerializer): """ def __init__(self, key_ser, val_ser): + FramedSerializer.__init__(self) self.key_ser = key_ser self.val_ser = val_ser @@ -293,7 +294,7 @@ def prepare_keys_values(self, stream): val_stream = self.val_ser._load_stream_without_unbatching(stream) key_is_batched = isinstance(self.key_ser, BatchedSerializer) val_is_batched = isinstance(self.val_ser, BatchedSerializer) - for (keys, vals) in izip(key_stream, val_stream): + for (keys, vals) in zip(key_stream, val_stream): keys = keys if key_is_batched else [keys] vals = vals if val_is_batched else [vals] yield (keys, vals) @@ -303,10 +304,6 @@ def load_stream(self, stream): for pair in product(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, CartesianDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -318,22 +315,14 @@ class PairDeserializer(CartesianDeserializer): Deserializes the JavaRDD zip() of two PythonRDDs. """ - def __init__(self, key_ser, val_ser): - self.key_ser = key_ser - self.val_ser = val_ser - def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): if len(keys) != len(vals): raise ValueError("Can not deserialize RDD with different number of items" " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in izip(keys, vals): + for pair in zip(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, PairDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) @@ -382,8 +371,8 @@ def _hijack_namedtuple(): global _old_namedtuple # or it will put in closure def _copy_func(f): - return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + return types.FunctionType(f.__code__, f.__globals__, f.__name__, + f.__defaults__, f.__closure__) _old_namedtuple = _copy_func(collections.namedtuple) @@ -392,15 +381,15 @@ def namedtuple(*args, **kwargs): return _hack_namedtuple(cls) # replace namedtuple with new one - collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.iteritems(): + for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple and hasattr(o, "_fields") and "__reduce__" not in o.__dict__): @@ -413,7 +402,7 @@ def namedtuple(*args, **kwargs): class PickleSerializer(FramedSerializer): """ - Serializes objects using Python's cPickle serializer: + Serializes objects using Python's pickle serializer: http://docs.python.org/2/library/pickle.html @@ -422,10 +411,14 @@ class PickleSerializer(FramedSerializer): """ def dumps(self, obj): - return cPickle.dumps(obj, 2) + return pickle.dumps(obj, protocol) - def loads(self, obj): - return cPickle.loads(obj) + if sys.version >= '3': + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) + else: + def loads(self, obj, encoding=None): + return pickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -454,7 +447,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol automatically + Choose marshal or pickle as serialization protocol automatically """ def __init__(self): @@ -463,19 +456,19 @@ def __init__(self): def dumps(self, obj): if self._type is not None: - return 'P' + cPickle.dumps(obj, -1) + return b'P' + pickle.dumps(obj, -1) try: - return 'M' + marshal.dumps(obj) + return b'M' + marshal.dumps(obj) except Exception: - self._type = 'P' - return 'P' + cPickle.dumps(obj, -1) + self._type = b'P' + return b'P' + pickle.dumps(obj, -1) def loads(self, obj): _type = obj[0] - if _type == 'M': + if _type == b'M': return marshal.loads(obj[1:]) - elif _type == 'P': - return cPickle.loads(obj[1:]) + elif _type == b'P': + return pickle.loads(obj[1:]) else: raise ValueError("invalid sevialization type: %s" % _type) @@ -495,8 +488,8 @@ def dumps(self, obj): def loads(self, obj): return self.serializer.loads(zlib.decompress(obj)) - def __eq__(self, other): - return isinstance(other, CompressedSerializer) and self.serializer == other.serializer + def __repr__(self): + return "CompressedSerializer(%s)" % self.serializer class UTF8Deserializer(Serializer): @@ -505,7 +498,7 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ - def __init__(self, use_unicode=False): + def __init__(self, use_unicode=True): self.use_unicode = use_unicode def loads(self, stream): @@ -526,13 +519,13 @@ def load_stream(self, stream): except EOFError: return - def __eq__(self, other): - return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode + def __repr__(self): + return "UTF8Deserializer(%s)" % self.use_unicode def read_long(stream): length = stream.read(8) - if length == "": + if not length: raise EOFError return struct.unpack("!q", length)[0] @@ -547,7 +540,7 @@ def pack_long(value): def read_int(stream): length = stream.read(4) - if length == "": + if not length: raise EOFError return struct.unpack("!i", length)[0] diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 81aa970a32f76..144cdf0b0cdd5 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -21,13 +21,6 @@ This file is designed to be launched as a PYTHONSTARTUP script. """ -import sys -if sys.version_info[0] != 2: - print("Error: Default Python used is Python%s" % sys.version_info.major) - print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.") - sys.exit(1) - - import atexit import os import platform @@ -53,9 +46,14 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = sqlContext = HiveContext(sc) + sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = sqlContext = SQLContext(sc) + sqlContext = SQLContext(sc) +except TypeError: + sqlContext = SQLContext(sc) + +# for compatibility +sqlCtx = sqlContext print("""Welcome to ____ __ diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8a6fc627eb383..b54baa57ec28a 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -78,8 +78,8 @@ def _get_local_dirs(sub): # global stats -MemoryBytesSpilled = 0L -DiskBytesSpilled = 0L +MemoryBytesSpilled = 0 +DiskBytesSpilled = 0 class Aggregator(object): @@ -126,7 +126,7 @@ def mergeCombiners(self, iterator): """ Merge the combined items by mergeCombiner """ raise NotImplementedError - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ raise NotImplementedError @@ -156,9 +156,9 @@ def mergeCombiners(self, iterator): for k, v in iterator: d[k] = comb(d[k], v) if k in d else v - def iteritems(self): - """ Return the merged items as iterator """ - return self.data.iteritems() + def items(self): + """ Return the merged items ad iterator """ + return iter(self.data.items()) def _compressed_serializer(self, serializer=None): @@ -208,15 +208,15 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N))) + >>> merger.mergeValues(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) + >>> sum(v for k,v in merger.items()) 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) + >>> merger.mergeCombiners(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) + >>> sum(v for k,v in merger.items()) 49995000 """ @@ -335,10 +335,10 @@ def _spill(self): # above limit at the first time. # open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch @@ -354,9 +354,9 @@ def _spill(self): else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(iter(self.pdata[i].items()), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) @@ -364,10 +364,10 @@ def _spill(self): gc.collect() # release the memory as much as possible MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 - def iteritems(self): + def items(self): """ Return all merged items as iterator """ if not self.pdata and not self.spills: - return self.data.iteritems() + return iter(self.data.items()) return self._external_items() def _external_items(self): @@ -398,7 +398,8 @@ def _merged_items(self, index): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) # limit the total partitions if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS @@ -408,7 +409,7 @@ def _merged_items(self, index): gc.collect() # release the memory as much as possible return self._recursive_merged_items(index) - return self.data.iteritems() + return self.data.items() def _recursive_merged_items(self, index): """ @@ -426,7 +427,8 @@ def _recursive_merged_items(self, index): for j in range(self.spills): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) - m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + with open(p, 'rb') as f: + m.mergeCombiners(self.serializer.load_stream(f), 0) if get_used_memory() > limit: m._spill() @@ -451,7 +453,7 @@ class ExternalSorter(object): >>> sorter = ExternalSorter(1) # 1M >>> import random - >>> l = range(1024) + >>> l = list(range(1024)) >>> random.shuffle(l) >>> sorted(l) == list(sorter.sorted(l)) True @@ -499,9 +501,16 @@ def sorted(self, iterator, key=None, reverse=False): # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) - with open(path, 'w') as f: + with open(path, 'wb') as f: self.serializer.dump_stream(current_chunk, f) - chunks.append(self.serializer.load_stream(open(path))) + + def load(f): + for v in self.serializer.load_stream(f): + yield v + # close the file explicit once we consume all the items + # to avoid ResourceWarning in Python3 + f.close() + chunks.append(load(open(path, 'rb'))) current_chunk = [] gc.collect() limit = self._next_limit() @@ -527,7 +536,7 @@ class ExternalList(object): ExternalList can have many items which cannot be hold in memory in the same time. - >>> l = ExternalList(range(100)) + >>> l = ExternalList(list(range(100))) >>> len(l) 100 >>> l.append(10) @@ -555,11 +564,11 @@ def __init__(self, values): def __getstate__(self): if self._file is not None: self._file.flush() - f = os.fdopen(os.dup(self._file.fileno())) - f.seek(0) - serialized = f.read() + with os.fdopen(os.dup(self._file.fileno()), "rb") as f: + f.seek(0) + serialized = f.read() else: - serialized = '' + serialized = b'' return self.values, self.count, serialized def __setstate__(self, item): @@ -575,7 +584,7 @@ def __iter__(self): if self._file is not None: self._file.flush() # read all items from disks first - with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + with os.fdopen(os.dup(self._file.fileno()), 'rb') as f: f.seek(0) for v in self._ser.load_stream(f): yield v @@ -598,11 +607,16 @@ def _open_file(self): d = dirs[id(self) % len(dirs)] if not os.path.exists(d): os.makedirs(d) - p = os.path.join(d, str(id)) - self._file = open(p, "w+", 65536) + p = os.path.join(d, str(id(self))) + self._file = open(p, "wb+", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) + def __del__(self): + if self._file: + self._file.close() + self._file = None + def _spill(self): """ dump the values into disk """ global MemoryBytesSpilled, DiskBytesSpilled @@ -651,33 +665,28 @@ class GroupByKey(object): """ Group a sorted iterator as [(k1, it1), (k2, it2), ...] - >>> k = [i/3 for i in range(6)] + >>> k = [i // 3 for i in range(6)] >>> v = [[i] for i in range(6)] - >>> g = GroupByKey(iter(zip(k, v))) + >>> g = GroupByKey(zip(k, v)) >>> [(k, list(it)) for k, it in g] [(0, [0, 1, 2]), (1, [3, 4, 5])] """ def __init__(self, iterator): - self.iterator = iter(iterator) - self.next_item = None + self.iterator = iterator def __iter__(self): - return self - - def next(self): - key, value = self.next_item if self.next_item else next(self.iterator) - values = ExternalListOfList([value]) - try: - while True: - k, v = next(self.iterator) - if k != key: - self.next_item = (k, v) - break + key, values = None, None + for k, v in self.iterator: + if values is not None and k == key: values.append(v) - except StopIteration: - self.next_item = None - return key, values + else: + if values is not None: + yield (key, values) + key = k + values = ExternalListOfList([v]) + if values is not None: + yield (key, values) class ExternalGroupBy(ExternalMerger): @@ -744,7 +753,7 @@ def _spill(self): # above limit at the first time. # open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] # If the number of keys is small, then the overhead of sort is small @@ -756,7 +765,7 @@ def _spill(self): h = self._partition(k) self.serializer.dump_stream([(k, self.data[k])], streams[h]) else: - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) self.serializer.dump_stream([(k, v)], streams[h]) @@ -771,14 +780,14 @@ def _spill(self): else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch if self._sorted: # sort by key only (stable) - sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0)) self.serializer.dump_stream(sorted_items, f) else: - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(self.pdata[i].items(), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) @@ -792,7 +801,7 @@ def _merged_items(self, index): # if the memory can not hold all the partition, # then use sort based merge. Because of compression, # the data on disks will be much smaller than needed memory - if (size >> 20) >= self.memory_limit / 10: + if size >= self.memory_limit << 17: # * 1M / 8 return self._merge_sorted_items(index) self.data = {} @@ -800,15 +809,18 @@ def _merged_items(self, index): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), 0) - return self.data.iteritems() + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + return self.data.items() def _merge_sorted_items(self, index): """ load a partition from disk, then sort and group by key """ def load_partition(j): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) - return self.serializer.load_stream(open(p, 'r', 65536)) + with open(p, 'rb', 65536) as f: + for v in self.serializer.load_stream(f): + yield v disk_items = [load_partition(j) for j in range(self.spills)] diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 65abb24eed823..6d54b9e49ed10 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -37,9 +37,22 @@ - L{types} List of data types available. """ +from __future__ import absolute_import + +# fix the module name conflict for Python 3+ +import sys +from . import _types as types +modname = __name__ + '.types' +types.__name__ = modname +# update the __module__ for all objects, make them picklable +for v in types.__dict__.values(): + if hasattr(v, "__module__") and v.__module__.endswith('._types'): + v.__module__ = modname +sys.modules[modname] = types +del modname, sys -from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row +from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions __all__ = [ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/_types.py similarity index 97% rename from python/pyspark/sql/types.py rename to python/pyspark/sql/_types.py index ef76d84c00481..492c0cbdcf693 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/_types.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import decimal import datetime import keyword @@ -25,6 +26,9 @@ from array import array from operator import itemgetter +if sys.version >= "3": + long = int + unicode = str __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", @@ -410,7 +414,7 @@ def fromJson(cls, json): split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) + m = __import__(pyModule, globals(), locals(), [pyClass]) UDT = getattr(m, pyClass) return UDT() @@ -419,10 +423,9 @@ def __eq__(self, other): _all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - + for v in list(globals().values()) + if (type(v) is type or type(v) is PrimitiveTypeSingleton) + and v.__base__ == PrimitiveType) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -486,10 +489,10 @@ def _parse_datatype_json_string(json_string): def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: + if not isinstance(json_value, dict): if json_value in _all_primitive_types.keys(): return _all_primitive_types[json_value]() - elif json_value == u'decimal': + elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): m = _FIXED_DECIMAL.match(json_value) @@ -511,10 +514,8 @@ def _parse_datatype_json_value(json_value): type(None): NullType, bool: BooleanType, int: LongType, - long: LongType, float: DoubleType, str: StringType, - unicode: StringType, bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.date: DateType, @@ -522,6 +523,12 @@ def _parse_datatype_json_value(json_value): datetime.time: TimestampType, } +if sys.version < "3": + _type_mappings.update({ + unicode: StringType, + long: LongType, + }) + def _infer_type(obj): """Infer the DataType from obj @@ -541,7 +548,7 @@ def _infer_type(obj): return dataType() if isinstance(obj, dict): - for key, value in obj.iteritems(): + for key, value in obj.items(): if key is not None and value is not None: return MapType(_infer_type(key), _infer_type(value), True) else: @@ -565,10 +572,10 @@ def _infer_schema(row): items = sorted(row.items()) elif isinstance(row, (tuple, list)): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__fields__"): # Row + if hasattr(row, "__fields__"): # Row items = zip(row.__fields__, tuple(row)) + elif hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) else: names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) @@ -647,7 +654,7 @@ def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__fields__"): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) @@ -733,12 +740,12 @@ def _create_converter(dataType): if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) + return lambda row: [conv(v) for v in row] elif isinstance(dataType, MapType): kconv = _create_converter(dataType.keyType) vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) + return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items()) elif isinstance(dataType, NullType): return lambda x: None @@ -881,7 +888,7 @@ def _infer_schema_type(obj, dataType): >>> _infer_schema_type(row, schema) StructType...a,ArrayType...b,MapType(StringType,...c,LongType... """ - if dataType is NullType(): + if isinstance(dataType, NullType): return _infer_type(obj) if not obj: @@ -892,7 +899,7 @@ def _infer_schema_type(obj, dataType): return ArrayType(eType, True) elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() + k, v = next(iter(obj.items())) return MapType(_infer_schema_type(k, dataType.keyType), _infer_schema_type(v, dataType.valueType)) @@ -935,7 +942,7 @@ def _verify_type(obj, dataType): >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) >>> _verify_type(0, LongType()) - >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(list(range(3)), ArrayType(ShortType())) >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... @@ -976,7 +983,7 @@ def _verify_type(obj, dataType): _verify_type(i, dataType.elementType) elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): + for k, v in obj.items(): _verify_type(k, dataType.keyType) _verify_type(v, dataType.valueType) @@ -1213,6 +1220,8 @@ def __getattr__(self, item): return self[idx] except IndexError: raise AttributeError(item) + except ValueError: + raise AttributeError(item) def __reduce__(self): if hasattr(self, "__fields__"): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e8529a8f8e3a4..c90afc326ca0e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -15,14 +15,19 @@ # limitations under the License. # +import sys import warnings import json -from itertools import imap + +if sys.version >= '3': + basestring = unicode = str +else: + from itertools import imap as map from py4j.protocol import Py4JError from py4j.java_collections import MapConverter -from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter @@ -62,31 +67,27 @@ class SQLContext(object): A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. - When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, - which could be used to convert an RDD into a DataFrame, it's a shorthand for - :func:`SQLContext.createDataFrame`. - :param sparkContext: The :class:`SparkContext` backing this SQLContext. :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. """ + @ignore_unicode_prefix def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime >>> sqlContext = SQLContext(sc) - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -122,6 +123,7 @@ def udf(self): """Returns a :class:`UDFRegistration` for UDF registration.""" return UDFRegistration(self) + @ignore_unicode_prefix def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -147,7 +149,7 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ - func = lambda _, it: imap(lambda x: f(*x), it) + func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) @@ -185,6 +187,7 @@ def _inferSchema(self, rdd, samplingRatio=None): schema = rdd.map(_infer_schema).reduce(_merge_type) return schema + @ignore_unicode_prefix def inferSchema(self, rdd, samplingRatio=None): """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ @@ -195,6 +198,7 @@ def inferSchema(self, rdd, samplingRatio=None): return self.createDataFrame(rdd, None, samplingRatio) + @ignore_unicode_prefix def applySchema(self, rdd, schema): """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ @@ -208,6 +212,7 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): """ Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, @@ -380,6 +385,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): df = self._ssql_ctx.jsonFile(path, scala_datatype) return DataFrame(df, self) + @ignore_unicode_prefix def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. @@ -477,6 +483,7 @@ def createExternalTable(self, tableName, path=None, source=None, joptions) return DataFrame(df, self) + @ignore_unicode_prefix def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. @@ -497,6 +504,7 @@ def table(self, tableName): """ return DataFrame(self._ssql_ctx.table(tableName), self) + @ignore_unicode_prefix def tables(self, dbName=None): """Returns a :class:`DataFrame` containing names of tables in the given database. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef91a9c4f522d..326d22e72f104 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -16,14 +16,19 @@ # import sys -import itertools import warnings import random +if sys.version >= '3': + basestring = unicode = str + long = int +else: + from itertools import imap as map + from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD, _load_from_socket +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -65,19 +70,20 @@ def __init__(self, jdf, sql_ctx): self._sc = sql_ctx and sql_ctx._sc self.is_cached = False self._schema = None # initialized lazily + self._lazy_rdd = None @property def rdd(self): """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ - if not hasattr(self, '_lazy_rdd'): + if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) schema = self.schema def applySchema(it): cls = _create_cls(schema) - return itertools.imap(cls, it) + return map(cls, it) self._lazy_rdd = rdd.mapPartitions(applySchema) @@ -89,13 +95,14 @@ def na(self): """ return DataFrameNaFunctions(self) - def toJSON(self, use_unicode=False): + @ignore_unicode_prefix + def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() - '{"age":2,"name":"Alice"}' + u'{"age":2,"name":"Alice"}' """ rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) @@ -228,7 +235,7 @@ def printSchema(self): |-- name: string (nullable = true) """ - print (self._jdf.schema().treeString()) + print(self._jdf.schema().treeString()) def explain(self, extended=False): """Prints the (logical and physical) plans to the console for debugging purpose. @@ -250,9 +257,9 @@ def explain(self, extended=False): == RDD == """ if extended: - print self._jdf.queryExecution().toString() + print(self._jdf.queryExecution().toString()) else: - print self._jdf.queryExecution().executedPlan().toString() + print(self._jdf.queryExecution().executedPlan().toString()) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally @@ -270,7 +277,7 @@ def show(self, n=20): 2 Alice 5 Bob """ - print self._jdf.showString(n).encode('utf8', 'ignore') + print(self._jdf.showString(n)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -279,10 +286,11 @@ def count(self): """Returns the number of rows in this :class:`DataFrame`. >>> df.count() - 2L + 2 """ - return self._jdf.count() + return int(self._jdf.count()) + @ignore_unicode_prefix def collect(self): """Returns all the records as a list of :class:`Row`. @@ -295,6 +303,7 @@ def collect(self): cls = _create_cls(self.schema) return [cls(r) for r in rs] + @ignore_unicode_prefix def limit(self, num): """Limits the result count to the number specified. @@ -306,6 +315,7 @@ def limit(self, num): jdf = self._jdf.limit(num) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. @@ -314,6 +324,7 @@ def take(self, num): """ return self.limit(num).collect() + @ignore_unicode_prefix def map(self, f): """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. @@ -324,6 +335,7 @@ def map(self, f): """ return self.rdd.map(f) + @ignore_unicode_prefix def flatMap(self, f): """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. @@ -353,7 +365,7 @@ def foreach(self, f): This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): - ... print person.name + ... print(person.name) >>> df.foreach(f) """ return self.rdd.foreach(f) @@ -365,7 +377,7 @@ def foreachPartition(self, f): >>> def f(people): ... for person in people: - ... print person.name + ... print(person.name) >>> df.foreachPartition(f) """ return self.rdd.foreachPartition(f) @@ -412,7 +424,7 @@ def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() - 2L + 2 """ return DataFrame(self._jdf.distinct(), self.sql_ctx) @@ -420,10 +432,10 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 97).count() - 1L + 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) + seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) @@ -437,6 +449,7 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property + @ignore_unicode_prefix def columns(self): """Returns all column names as a list. @@ -445,6 +458,7 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @ignore_unicode_prefix def join(self, other, joinExprs=None, joinType=None): """Joins with another :class:`DataFrame`, using the given join expression. @@ -456,7 +470,7 @@ def join(self, other, joinExprs=None, joinType=None): One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] """ if joinExprs is None: @@ -470,13 +484,18 @@ def join(self, other, joinExprs=None, joinType=None): jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) return DataFrame(jdf, self.sql_ctx) - def sort(self, *cols): + @ignore_unicode_prefix + def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: list of :class:`Column` to sort by. + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: sort by ascending order or not, could be bool, int + or list of bool, int (default: True). >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sort("age", ascending=False).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.orderBy(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> from pyspark.sql.functions import * @@ -484,16 +503,42 @@ def sort(self, *cols): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.orderBy(desc("age"), "name").collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: raise ValueError("should sort by at least one column") - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) + + jdf = self._jdf.sort(self._jseq(jcols)) return DataFrame(jdf, self.sql_ctx) orderBy = sort + def _jseq(self, cols, converter=None): + """Return a JVM Seq of Columns from a list of Column or names""" + return _to_seq(self.sql_ctx._sc, cols, converter) + + def _jcols(self, *cols): + """Return a JVM Seq of Columns from a list of Column or column names + + If `cols` has only one list in it, cols[0] will be used as the list. + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return self._jseq(cols, _to_java_column) + def describe(self, *cols): """Computes statistics for numeric columns. @@ -508,11 +553,10 @@ def describe(self, *cols): min 2 max 5 """ - cols = ListConverter().convert(cols, - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def head(self, n=None): """ Returns the first ``n`` rows as a list of :class:`Row`, @@ -528,6 +572,7 @@ def head(self, n=None): return rs[0] if rs else None return self.take(n) + @ignore_unicode_prefix def first(self): """Returns the first row as a :class:`Row`. @@ -536,6 +581,7 @@ def first(self): """ return self.head() + @ignore_unicode_prefix def __getitem__(self, item): """Returns the column as a :class:`Column`. @@ -545,16 +591,23 @@ def __getitem__(self, item): [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df[ df.age > 3 ].collect() [Row(age=5, name=u'Bob')] + >>> df[df[0] > 3].collect() + [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): + if item not in self.columns: + raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): return self.filter(item) - elif isinstance(item, list): + elif isinstance(item, (list, tuple)): return self.select(*item) + elif isinstance(item, int): + jc = self._jdf.apply(self.columns[item]) + return Column(jc) else: - raise IndexError("unexpected index: %s" % item) + raise TypeError("unexpected type: %s" % type(item)) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. @@ -562,11 +615,12 @@ def __getattr__(self, name): >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ - if name.startswith("__"): - raise AttributeError(name) + if name not in self.columns: + raise AttributeError("No such column: %s" % name) jc = self._jdf.apply(name) return Column(jc) + @ignore_unicode_prefix def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. @@ -581,9 +635,7 @@ def select(self, *cols): >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): @@ -594,10 +646,12 @@ def selectExpr(self, *expr): >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] """ - jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) - jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + if len(expr) == 1 and isinstance(expr[0], list): + expr = expr[0] + jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def filter(self, condition): """Filters rows using the given condition. @@ -626,26 +680,31 @@ def filter(self, condition): where = filter + @ignore_unicode_prefix def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :func:`groupby` is an alias for :func:`groupBy`. + :param cols: list of columns to group by. Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + >>> df.groupBy(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.groupBy(self._jcols(*cols)) return GroupedData(jdf, self.sql_ctx) + groupby = groupBy + def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). @@ -716,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None): if thresh is None: thresh = len(subset) if how == 'any' else 1 - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. @@ -771,10 +828,9 @@ def fillna(self, value, subset=None): elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -786,6 +842,7 @@ def withColumn(self, colName, col): """ return self.select('*', col.alias(colName)) + @ignore_unicode_prefix def withColumnRenamed(self, existing, new): """REturns a new :class:`DataFrame` by renaming an existing column. @@ -832,10 +889,8 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): - jargs = ListConverter().convert(args, - self.sql_ctx._sc._gateway._gateway_client) name = f.__name__ - jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -852,6 +907,7 @@ def __init__(self, jdf, sql_ctx): self._jdf = jdf self.sql_ctx = sql_ctx + @ignore_unicode_prefix def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. @@ -867,11 +923,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(MIN(age)=5), Row(MIN(age)=2)] + [Row(MIN(age)=2), Row(MIN(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -881,9 +937,8 @@ def agg(self, *exprs): else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @dfapi @@ -975,6 +1030,19 @@ def _to_java_column(col): return jcol +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + return sc._jvm.PythonUtils.toSeq(jcols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): @@ -1041,11 +1109,13 @@ def __init__(self, jc): __sub__ = _bin_op("minus") __mul__ = _bin_op("multiply") __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") __mod__ = _bin_op("mod") __radd__ = _bin_op("plus") __rsub__ = _reverse_op("minus") __rmul__ = _bin_op("multiply") __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") # logistic operators @@ -1067,7 +1137,39 @@ def __init__(self, jc): # container operators __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") - getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + l[0] d[key] + 1 value + >>> df.select(df.l[0], df.d["key"]).show() + l[0] d[key] + 1 value + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + r.b + b + >>> df.select(df.r.a).show() + r.a + 1 + """ + return Column(self._jc.getField(name)) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) # string methods rlike = _bin_op("rlike") @@ -1075,6 +1177,7 @@ def __init__(self, jc): startswith = _bin_op("startsWith") endswith = _bin_op("endsWith") + @ignore_unicode_prefix def substr(self, startPos, length): """ Return a :class:`Column` which is a substring of the column @@ -1097,6 +1200,7 @@ def substr(self, startPos, length): __getslice__ = substr + @ignore_unicode_prefix def inSet(self, *cols): """ A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments. @@ -1110,8 +1214,7 @@ def inSet(self, *cols): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) return Column(jc) # order @@ -1131,6 +1234,7 @@ def alias(self, alias): """ return Column(getattr(self._jc, "as")(alias)) + @ignore_unicode_prefix def cast(self, dataType): """ Convert the column into type `dataType` diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index daeb6916b58bc..bb47923f24b82 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,16 +18,16 @@ """ A collections of builtin functions """ +import sys -from itertools import imap - -from py4j.java_collections import ListConverter +if sys.version < "3": + from itertools import imap as map from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column +from pyspark.sql.dataframe import Column, _to_java_column, _to_seq __all__ = ['countDistinct', 'approxCountDistinct', 'udf'] @@ -85,8 +85,7 @@ def countDistinct(col, *cols): [Row(c=2)] """ sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) - jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -116,7 +115,7 @@ def __init__(self, func, returnType): def _create_judf(self): f = self.func # put it in closure `func` - func = lambda _, it: imap(lambda x: f(*x), it) + func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) sc = SparkContext._active_spark_context @@ -136,9 +135,7 @@ def __del__(self): def __call__(self, *cols): sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + jc = self._judf.apply(_to_seq(sc, cols, _to_java_column)) return Column(jc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b3a6a2c6a9229..aa3aa1d164d9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,13 +157,13 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] + d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) self.sqlCtx.createDataFrame(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) + self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) def test_broadcast_in_udf(self): @@ -266,7 +266,7 @@ def test_infer_nested_schema(self): def test_apply_schema(self): from datetime import date, datetime - rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, (2,), [1, 2, 3], None)]) schema = StructType([ @@ -282,7 +282,7 @@ def test_apply_schema(self): StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) - df = self.sqlCtx.applySchema(rdd, schema) + df = self.sqlCtx.createDataFrame(rdd, schema) results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), @@ -309,7 +309,7 @@ def test_apply_schema(self): def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] df = self.sc.parallelize(d).toDF() - k, v = df.head().m.items()[0] + k, v = list(df.head().m.items())[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -426,6 +426,24 @@ def test_help_command(self): pydoc.render_doc(df.foo) pydoc.render_doc(df.take(1)) + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + def test_infer_long_type(self): longrow = [Row(f1='a', f2=100000000000000)] df = self.sc.parallelize(longrow).toDF() @@ -554,6 +572,9 @@ def setUpClass(cls): except py4j.protocol.Py4JError: cls.sqlCtx = None return + except TypeError: + cls.sqlCtx = None + return os.unlink(cls.tempdir.name) _scala_HiveContext =\ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 1e597d64e03fe..944fa414b0c0e 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -31,7 +31,7 @@ class StatCounter(object): def __init__(self, values=[]): - self.n = 0L # Running count of our values + self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") @@ -87,7 +87,7 @@ def copy(self): return copy.deepcopy(self) def count(self): - return self.n + return int(self.n) def mean(self): return self.mu diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 2c73083c9f9a8..4590c58839266 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from __future__ import print_function + import os import sys @@ -157,7 +160,7 @@ def getOrCreate(cls, checkpointPath, setupFunc): try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: - print >>sys.stderr, "failed to load StreamingContext from checkpoint" + print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise jsc = jssc.sparkContext() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 3fa42444239f7..ff097985fae3e 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -15,11 +15,15 @@ # limitations under the License. # -from itertools import chain, ifilter, imap +import sys import operator import time +from itertools import chain from datetime import datetime +if sys.version < "3": + from itertools import imap as map, ifilter as filter + from py4j.protocol import Py4JJavaError from pyspark import RDD @@ -76,7 +80,7 @@ def filter(self, f): Return a new DStream containing only the elements that satisfy predicate. """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def flatMap(self, f, preservesPartitioning=False): @@ -85,7 +89,7 @@ def flatMap(self, f, preservesPartitioning=False): this DStream, and then flattening the results """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def map(self, f, preservesPartitioning=False): @@ -93,7 +97,7 @@ def map(self, f, preservesPartitioning=False): Return a new DStream by applying a function to each element of DStream. """ def func(iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitions(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -150,7 +154,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) @@ -165,14 +169,14 @@ def pprint(self, num=10): """ def takeAndPrint(time, rdd): taken = rdd.take(num + 1) - print "-------------------------------------------" - print "Time: %s" % time - print "-------------------------------------------" + print("-------------------------------------------") + print("Time: %s" % time) + print("-------------------------------------------") for record in taken[:num]: - print record + print(record) if len(taken) > num: - print "..." - print + print("...") + print() self.foreachRDD(takeAndPrint) @@ -181,7 +185,7 @@ def mapValues(self, f): Return a new DStream by applying a map function to the value of each key-value pairs in this DStream without changing the key. """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): @@ -189,7 +193,7 @@ def flatMapValues(self, f): Return a new DStream by applying a flatmap function to the value of each key-value pairs in this DStream without changing the key. """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def glom(self): @@ -286,10 +290,10 @@ def transform(self, func): `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: oldfunc = func func = lambda t, rdd: oldfunc(rdd) - assert func.func_code.co_argcount == 2, "func should take one or two arguments" + assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) def transformWith(self, func, other, keepSerializer=False): @@ -300,10 +304,10 @@ def transformWith(self, func, other, keepSerializer=False): `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three arguments of (`time`, `rdd_a`, `rdd_b`) """ - if func.func_code.co_argcount == 2: + if func.__code__.co_argcount == 2: oldfunc = func func = lambda t, a, b: oldfunc(a, b) - assert func.func_code.co_argcount == 3, "func should take two or three arguments" + assert func.__code__.co_argcount == 3, "func should take two or three arguments" jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) @@ -460,7 +464,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) - return reduced.map(lambda (k, v): v) + return reduced.map(lambda kv: kv[1]) def countByWindow(self, windowDuration, slideDuration): """ @@ -489,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda (k, v): v > 0).count() + return counted.filter(lambda kv: kv[1] > 0).count() def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ @@ -548,7 +552,8 @@ def reduceFunc(t, a, b): def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: @@ -579,9 +584,9 @@ def reduceFunc(t, a, b): g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: g = a.cogroup(b.partitionBy(numPartitions), numPartitions) - g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) - state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) - return state.filter(lambda (k, v): v is not None) + g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None)) + state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1])) + return state.filter(lambda k_v: k_v[1] is not None) jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index f083ed149effb..7a7b6e1d9a527 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -67,10 +67,10 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) - except Py4JJavaError, e: + except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print """ + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -88,8 +88,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version) +""" % (ssc.sparkContext.version, ssc.sparkContext.version)) raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) - return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 9b4635e49020b..06d22154373bc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -22,6 +22,7 @@ import unittest import tempfile import struct +from functools import reduce from py4j.java_collections import MapConverter @@ -51,7 +52,7 @@ def wait_for(self, result, n): while len(result) < n and time.time() - start_time < self.timeout: time.sleep(0.01) if len(result) < n: - print "timeout after", self.timeout + print("timeout after", self.timeout) def _take(self, dstream, n): """ @@ -131,7 +132,7 @@ def test_map(self): def func(dstream): return dstream.map(str) - expected = map(lambda x: map(str, x), input) + expected = [list(map(str, x)) for x in input] self._test_func(input, func, expected) def test_flatMap(self): @@ -140,8 +141,8 @@ def test_flatMap(self): def func(dstream): return dstream.flatMap(lambda x: (x, x * 2)) - expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), - input) + expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x)))) + for x in input] self._test_func(input, func, expected) def test_filter(self): @@ -150,7 +151,7 @@ def test_filter(self): def func(dstream): return dstream.filter(lambda x: x % 2 == 0) - expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + expected = [[y for y in x if y % 2 == 0] for x in input] self._test_func(input, func, expected) def test_count(self): @@ -159,7 +160,7 @@ def test_count(self): def func(dstream): return dstream.count() - expected = map(lambda x: [len(x)], input) + expected = [[len(x)] for x in input] self._test_func(input, func, expected) def test_reduce(self): @@ -168,7 +169,7 @@ def test_reduce(self): def func(dstream): return dstream.reduce(operator.add) - expected = map(lambda x: [reduce(operator.add, x)], input) + expected = [[reduce(operator.add, x)] for x in input] self._test_func(input, func, expected) def test_reduceByKey(self): @@ -185,27 +186,27 @@ def func(dstream): def test_mapValues(self): """Basic operation test for DStream.mapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 2), (3, 3)], + [(0, 4), (1, 1), (2, 2), (3, 3)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.mapValues(lambda x: x + 10) expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], - [("", 14), (1, 11), (2, 12), (3, 13)], + [(0, 14), (1, 11), (2, 12), (3, 13)], [(1, 11), (2, 11), (3, 11), (4, 11)]] self._test_func(input, func, expected, sort=True) def test_flatMapValues(self): """Basic operation test for DStream.flatMapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 1), (3, 1)], + [(0, 4), (1, 1), (2, 1), (3, 1)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.flatMapValues(lambda x: (x, x + 10)) expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), ("c", 1), ("c", 11), ("d", 1), ("d", 11)], - [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] self._test_func(input, func, expected) @@ -233,7 +234,7 @@ def f(iterator): def test_countByValue(self): """Basic operation test for DStream.countByValue.""" - input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]] def func(dstream): return dstream.countByValue() @@ -285,7 +286,7 @@ def test_union(self): def func(d1, d2): return d1.union(d2) - expected = [range(6), range(6), range(6)] + expected = [list(range(6)), list(range(6)), list(range(6))] self._test_func(input1, func, expected, input2=input2) def test_cogroup(self): @@ -424,7 +425,7 @@ class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 def _add_input_stream(self): - inputs = map(lambda x: range(1, x), range(101)) + inputs = [range(1, x) for x in range(101)] stream = self.ssc.queueStream(inputs) self._collect(stream, 1, block=False) @@ -441,7 +442,7 @@ def test_stop_multiple_times(self): self.ssc.stop() def test_queue_stream(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) result = self._collect(dstream, 3) self.assertEqual(input, result) @@ -457,13 +458,13 @@ def test_text_file_stream(self): with open(os.path.join(d, name), "w") as f: f.writelines(["%d\n" % i for i in range(10)]) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], result) + self.assertEqual([list(range(10)), list(range(10))], result) def test_binary_records_stream(self): d = tempfile.mkdtemp() self.ssc = StreamingContext(self.sc, self.duration) dstream = self.ssc.binaryRecordsStream(d, 10).map( - lambda v: struct.unpack("10b", str(v))) + lambda v: struct.unpack("10b", bytes(v))) result = self._collect(dstream, 2, block=False) self.ssc.start() for name in ('a', 'b'): @@ -471,10 +472,10 @@ def test_binary_records_stream(self): with open(os.path.join(d, name), "wb") as f: f.write(bytearray(range(10))) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result)) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) def test_union(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 86ee5aa04f252..34291f30a5652 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -91,9 +91,9 @@ def dumps(self, id): except Exception: traceback.print_exc() - def loads(self, bytes): + def loads(self, data): try: - f, deserializers = self.serializer.loads(str(bytes)) + f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() @@ -116,7 +116,7 @@ def rddToFileName(prefix, suffix, timestamp): """ if isinstance(timestamp, datetime): seconds = time.mktime(timestamp.timetuple()) - timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + timestamp = int(seconds * 1000) + timestamp.microsecond // 1000 if suffix is None: return prefix + "-" + str(timestamp) else: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b938b9ce12395..75f39d9e75f38 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,8 +19,8 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ + from array import array -from fileinput import input from glob import glob import os import re @@ -45,6 +45,9 @@ sys.exit(1) else: import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str from pyspark.conf import SparkConf @@ -52,7 +55,9 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ + PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ + FlattenedValuesSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -81,7 +86,7 @@ class MergerTests(unittest.TestCase): def setUp(self): self.N = 1 << 12 self.l = [i for i in xrange(self.N)] - self.data = zip(self.l, self.l) + self.data = list(zip(self.l, self.l)) self.agg = Aggregator(lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) @@ -89,45 +94,45 @@ def setUp(self): def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 30) + m = ExternalMerger(self.agg, 20) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10, partitions=3) - m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.iteritems()), + self.assertEqual(sum(len(v) for k, v in m.items()), self.N * 10) m._cleanup() @@ -144,55 +149,55 @@ def gen_gs(N, step=1): self.assertEqual(1, len(list(gen_gs(1)))) self.assertEqual(2, len(list(gen_gs(2)))) self.assertEqual(100, len(list(gen_gs(100)))) - self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) for k, vs in gen_gs(50002, 10000): self.assertEqual(k, len(vs)) - self.assertEqual(range(k), list(vs)) + self.assertEqual(list(range(k)), list(vs)) ser = PickleSerializer() l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) for k, vs in l: self.assertEqual(k, len(vs)) - self.assertEqual(range(k), list(vs)) + self.assertEqual(list(range(k)), list(vs)) class SorterTests(unittest.TestCase): def test_in_memory_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1024) - self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1) - self.assertEquals(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") sc = SparkContext(conf=conf) - l = range(10240) + l = list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 10) - self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect()) + rdd = sc.parallelize(l, 2) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -200,11 +205,11 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from cPickle import dumps, loads + from pickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) - self.assertEquals(p1, p2) + self.assertEqual(p1, p2) def test_itemgetter(self): from operator import itemgetter @@ -246,7 +251,7 @@ def test_pickling_file_handles(self): ser = CloudPickleSerializer() out1 = sys.stderr out2 = ser.loads(ser.dumps(out1)) - self.assertEquals(out1, out2) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -263,19 +268,36 @@ def __reduce__(self): def foo(): sys.exit(0) - self.assertTrue("exit" in foo.func_code.co_names) + self.assertTrue("exit" in foo.__code__.co_names) ser.dumps(foo) def test_compressed_serializer(self): ser = CompressedSerializer(PickleSerializer()) - from StringIO import StringIO + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO io = StringIO() ser.dump_stream(["abc", u"123", range(5)], io) io.seek(0) self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) ser.dump_stream(range(1000), io) io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) class PySparkTestCase(unittest.TestCase): @@ -340,7 +362,7 @@ def test_checkpoint_and_restore(self): self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), flatMappedRDD._jrdd_deserializer) - self.assertEquals([1, 2, 3, 4], recovered.collect()) + self.assertEqual([1, 2, 3, 4], recovered.collect()) class AddFileTests(PySparkTestCase): @@ -356,8 +378,7 @@ def test_add_py_file(self): def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, - self.sc.parallelize(range(2)).map(func).first) + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) log4j.LogManager.getRootLogger().setLevel(old_level) # Add the file, so the job should now succeed: @@ -372,7 +393,7 @@ def test_add_file_locally(self): download_path = SparkFiles.get("hello.txt") self.assertNotEqual(path, download_path) with open(download_path) as test_file: - self.assertEquals("Hello World!\n", test_file.readline()) + self.assertEqual("Hello World!\n", test_file.readline()) def test_add_py_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -381,7 +402,7 @@ def func(): from userlibrary import UserClass self.assertRaises(ImportError, func) path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addFile(path) + self.sc.addPyFile(path) from userlibrary import UserClass self.assertEqual("Hello World!", UserClass().hello()) @@ -391,7 +412,7 @@ def test_add_egg_file_locally(self): def func(): from userlib import UserClass self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg") + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") self.sc.addPyFile(path) from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) @@ -427,8 +448,9 @@ def test_save_as_textfile_with_unicode(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) def test_save_as_textfile_with_utf8(self): x = u"\u00A1Hola, mundo!" @@ -436,19 +458,20 @@ def test_save_as_textfile_with_utf8(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) rdd2 = self.sc.parallelize([3, 4]) cart = rdd1.cartesian(rdd2) - result = cart.map(lambda (x, y): x + y).collect() + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() def test_transforming_pickle_file(self): # Regression test for SPARK-2601 - data = self.sc.parallelize(["Hello", "World!"]) + data = self.sc.parallelize([u"Hello", u"World!"]) tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsPickleFile(tempFile.name) @@ -461,13 +484,13 @@ def test_cartesian_on_textfile(self): a = self.sc.textFile(path) result = a.cartesian(a).collect() (x, y) = result[0] - self.assertEqual("Hello World!", x.strip()) - self.assertEqual("Hello World!", y.strip()) + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name) filtered_data = data.filter(lambda x: True) @@ -510,21 +533,21 @@ def test_namedtuple_in_rdd(self): jon = Person(1, "Jon", "Doe") jane = Person(2, "Jane", "Doe") theDoes = self.sc.parallelize([jon, jane]) - self.assertEquals([jon, jane], theDoes.collect()) + self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): N = 100000 data = [[float(i) for i in range(300)] for i in range(N)] bdata = self.sc.broadcast(data) # 270MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEquals(N, m) + self.assertEqual(N, m) def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = range(1 << 15) + r = list(range(1 << 15)) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -535,7 +558,7 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -549,11 +572,9 @@ def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + self.assertEqual(N, rdd.first()) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) @@ -592,15 +613,15 @@ def test_zip_with_different_number_of_items(self): # same total number of items, but different distributions a = self.sc.parallelize([2, 3], 2).flatMap(range) b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEquals(a.count(), b.count()) + self.assertEqual(a.count(), b.count()) self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): rdd = self.sc.parallelize(range(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) self.assertTrue(18 < rdd.countApproxDistinct() < 22) @@ -614,59 +635,59 @@ def test_count_approx_distinct(self): def test_histogram(self): # empty rdd = self.sc.parallelize([]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) self.assertRaises(ValueError, lambda: rdd.histogram(1)) # out of range rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) # in range with one bucket rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals([4], rdd.histogram([0, 10])[1]) - self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) # in range with one bucket exact match - self.assertEquals([4], rdd.histogram([1, 4])[1]) + self.assertEqual([4], rdd.histogram([1, 4])[1]) # out of range with two buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) # out of range with two uneven buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) # in range with two buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two bucket and None rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two uneven buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) # mixed range with two uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) # mixed range with four uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # mixed range with uneven buckets and NaN rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # out of range with infinite buckets rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) # invalid buckets self.assertRaises(ValueError, lambda: rdd.histogram([])) @@ -676,25 +697,25 @@ def test_histogram(self): # without buckets rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 4], [4]), rdd.histogram(1)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) # without buckets single element rdd = self.sc.parallelize([1]) - self.assertEquals(([1, 1], [1]), rdd.histogram(1)) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) # without bucket no range rdd = self.sc.parallelize([1] * 4) - self.assertEquals(([1, 1], [4]), rdd.histogram(1)) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) # without buckets basic two rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) # without buckets with more requested than elements rdd = self.sc.parallelize([1, 2]) buckets = [1 + 0.2 * i for i in range(6)] hist = [1, 0, 0, 0, 1] - self.assertEquals((buckets, hist), rdd.histogram(5)) + self.assertEqual((buckets, hist), rdd.histogram(5)) # invalid RDDs rdd = self.sc.parallelize([1, float('inf')]) @@ -704,15 +725,8 @@ def test_histogram(self): # string rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - # mixed RDD - rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2) - self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1]) - self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) def test_repartitionAndSortWithinPartitions(self): @@ -720,31 +734,31 @@ def test_repartitionAndSortWithinPartitions(self): repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) partitions = repartitioned.glom().collect() - self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEquals(rdd.getNumPartitions(), 10) - self.assertEquals(rdd.distinct().count(), 3) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) result = rdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "5m") + self.sc._conf.set("spark.python.worker.memory", "1m") N = 200001 kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) gkv = kv.groupByKey().cache() self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda (k, vs): k == 1) + filtered = gkv.filter(lambda kv: kv[0] == 1) self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) - self.assertEqual([(N/3, N/3)], + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) result = filtered.collect()[0][1] - self.assertEqual(N/3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -769,7 +783,7 @@ def test_null_in_rdd(self): rdd = RDD(jrdd, self.sc, UTF8Deserializer()) self.assertEqual([u"a", None, u"b"], rdd.collect()) rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual(["a", None, "b"], rdd.collect()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) def test_multiple_python_java_RDD_conversions(self): # Regression test for SPARK-5361 @@ -815,14 +829,14 @@ def test_narrow_dependency_in_join(self): self.sc.setJobGroup("test3", "test", True) d = sorted(parted.cogroup(parted).collect()) self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], map(list, d[0][1])) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) jobId = tracker.getJobIdsForGroup("test3")[0] self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) self.sc.setJobGroup("test4", "test", True) d = sorted(parted.cogroup(rdd).collect()) self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], map(list, d[0][1])) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) jobId = tracker.getJobIdsForGroup("test4")[0] self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) @@ -908,6 +922,7 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", @@ -956,15 +971,16 @@ def test_sequencefiles(self): en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] self.assertEqual(nulls, en) - maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: u'aa'}), (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] - self.assertEqual(maps, em) + for v in maps: + self.assertTrue(v in em) # arrays get pickled to tuples by default tuples = sorted(self.sc.sequenceFile( @@ -1091,8 +1107,8 @@ def test_converters(self): def test_binary_files(self): path = os.path.join(self.tempdir.name, "binaryfiles") os.mkdir(path) - data = "short binary data" - with open(os.path.join(path, "part-0000"), 'w') as f: + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: f.write(data) [(p, d)] = self.sc.binaryFiles(path).collect() self.assertTrue(p.endswith("part-0000")) @@ -1105,7 +1121,7 @@ def test_binary_records(self): for i in range(100): f.write('%04d' % i) result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(range(100), result) + self.assertEqual(list(range(100)), result) class OutputFormatTests(ReusedPySparkTestCase): @@ -1117,6 +1133,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir.name, ignore_errors=True) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] @@ -1157,8 +1174,9 @@ def test_sequencefiles(self): (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) - self.assertEqual(maps, em) + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1170,12 +1188,13 @@ def test_oldhadoop(self): "org.apache.hadoop.mapred.SequenceFileOutputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable") - result = sorted(self.sc.hadoopFile( + result = self.sc.hadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) - self.assertEqual(result, dict_data) + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -1185,12 +1204,13 @@ def test_oldhadoop(self): } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) input_conf = {"mapred.input.dir": basepath + "/olddataset/"} - old_dataset = sorted(self.sc.hadoopRDD( + result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect()) - self.assertEqual(old_dataset, dict_data) + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) def test_newhadoop(self): basepath = self.tempdir.name @@ -1225,6 +1245,7 @@ def test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) + @unittest.skipIf(sys.version >= "3", "serialize of array") def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -1305,7 +1326,7 @@ def test_reserialization(self): basepath = self.tempdir.name x = range(1, 5) y = range(1001, 1005) - data = zip(x, y) + data = list(zip(x, y)) rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) @@ -1356,7 +1377,7 @@ def connect(self, port): sock = socket(AF_INET, SOCK_STREAM) sock.connect(('127.0.0.1', port)) # send a split index of -1 to shutdown the worker - sock.send("\xFF\xFF\xFF\xFF") + sock.send(b"\xFF\xFF\xFF\xFF") sock.close() return True @@ -1397,7 +1418,6 @@ def test_termination_sigterm(self): class WorkerTests(PySparkTestCase): - def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1412,7 +1432,7 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1)).foreach(sleep) + self.sc.parallelize(range(1), 1).foreach(sleep) import threading t = threading.Thread(target=run) t.daemon = True @@ -1421,7 +1441,8 @@ def run(): daemon_pid, worker_pid = 0, 0 while True: if os.path.exists(path): - data = open(path).read().split(' ') + with open(path) as f: + data = f.read().split(' ') daemon_pid, worker_pid = map(int, data) break time.sleep(0.1) @@ -1457,7 +1478,7 @@ def raise_exception(_): def test_after_jvm_exception(self): tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name, 1) filtered_data = data.filter(lambda x: True) @@ -1579,12 +1600,12 @@ def test_single_script(self): |from pyspark import SparkContext | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) def test_script_with_local_functions(self): """Submit and test a single script file calling a global function""" @@ -1595,12 +1616,12 @@ def test_script_with_local_functions(self): | return x * 3 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) def test_module_dependency(self): """Submit and test a script with a dependency on another module""" @@ -1609,7 +1630,7 @@ def test_module_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1619,7 +1640,7 @@ def test_module_dependency(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_module_dependency_on_cluster(self): """Submit and test a script with a dependency on another module on a cluster""" @@ -1628,7 +1649,7 @@ def test_module_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1639,7 +1660,7 @@ def test_module_dependency_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_package_dependency(self): """Submit and test a script with a dependency on a Spark Package""" @@ -1648,14 +1669,14 @@ def test_package_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", "file:" + self.programDir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_package_dependency_on_cluster(self): """Submit and test a script with a dependency on a Spark Package on a cluster""" @@ -1664,7 +1685,7 @@ def test_package_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", @@ -1672,7 +1693,7 @@ def test_package_dependency_on_cluster(self): "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" @@ -1683,7 +1704,7 @@ def test_single_script_on_cluster(self): | return x * 2 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf @@ -1692,7 +1713,7 @@ def test_single_script_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) class ContextTests(unittest.TestCase): @@ -1767,7 +1788,7 @@ class SciPyTests(PySparkTestCase): def test_serialize(self): from scipy.special import gammaln x = range(1, 5) - expected = map(gammaln, x) + expected = list(map(gammaln, x)) observed = self.sc.parallelize(x).map(gammaln).collect() self.assertEqual(expected, observed) @@ -1788,11 +1809,11 @@ def test_statcounter_array(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: - print "NOTE: Skipping NumPy tests as it does not seem to be installed" + print("NOTE: Skipping NumPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: - print "NOTE: NumPy tests were skipped as it does not seem to be installed" + print("NOTE: NumPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 452d6fabdcc17..fbdaf3a5814cd 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,6 +18,7 @@ """ Worker that receives input from Piped RDD. """ +from __future__ import print_function import os import sys import time @@ -37,9 +38,9 @@ def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) - write_long(1000 * boot, outfile) - write_long(1000 * init, outfile) - write_long(1000 * finish, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) def add_path(path): @@ -72,6 +73,9 @@ def main(infile, outfile): for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) add_path(os.path.join(spark_files_dir, filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) @@ -106,14 +110,14 @@ def process(): except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) + write_with_length(traceback.format_exc().encode("utf-8"), outfile) except IOError: # JVM close the socket pass except Exception: # Write the error to stderr if it happened while serializing - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() + print("PySpark worker failed with exception:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) diff --git a/python/run-tests b/python/run-tests index f3a07d8aba562..ed3e819ef30c1 100755 --- a/python/run-tests +++ b/python/run-tests @@ -66,7 +66,7 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/types.py" + run_test "pyspark/sql/_types.py" run_test "pyspark/sql/context.py" run_test "pyspark/sql/dataframe.py" run_test "pyspark/sql/functions.py" @@ -136,6 +136,19 @@ run_mllib_tests run_ml_tests run_streaming_tests +# Try to test with Python 3 +if [ $(which python3.4) ]; then + export PYSPARK_PYTHON="python3.4" + echo "Testing with Python3.4 version:" + $PYSPARK_PYTHON --version + + run_core_tests + run_sql_tests + run_mllib_tests + run_ml_tests + run_streaming_tests +fi + # Try to test with PyPy if [ $(which pypy) ]; then export PYSPARK_PYTHON="pypy" diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg deleted file mode 100644 index 1674c9cb2227e..0000000000000 Binary files a/python/test_support/userlib-0.1-py2.7.egg and /dev/null differ diff --git a/python/test_support/userlib-0.1.zip b/python/test_support/userlib-0.1.zip new file mode 100644 index 0000000000000..496e1349aa967 Binary files /dev/null and b/python/test_support/userlib-0.1.zip differ diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index d8e0facb81169..de762acc8fa0e 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -129,7 +129,7 @@ run_command() { if [ -f "$pid" ]; then TARGET_ID="$(cat "$pid")" - if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then echo "$command running as process $TARGET_ID. Stop it first." exit 1 fi @@ -163,7 +163,7 @@ run_command() { echo "$newpid" > "$pid" sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see - if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then + if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then echo "failed to launch $command:" tail -2 "$log" | sed 's/^/ /' echo "full log in $log" diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 5a6de11afdd3d..4c919ff76a8f5 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -18,15 +18,68 @@ # # Starts a slave on the machine this script is executed on. +# +# Environment Variables +# +# SPARK_WORKER_INSTANCES The number of worker instances to run on this +# slave. Default is 1. +# SPARK_WORKER_PORT The base port number for the first worker. If set, +# subsequent workers will increment this number. If +# unset, Spark will find a valid port number, but +# with no guarantee of a predictable pattern. +# SPARK_WORKER_WEBUI_PORT The base port for the web interface of the first +# worker. Subsequent workers will increment this +# number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" +usage="Usage: start-slave.sh where is like spark://localhost:7077" -if [ $# -lt 2 ]; then +if [ $# -lt 1 ]; then echo $usage + echo Called as start-slave.sh $* exit 1 fi sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +# First argument should be the master; we need to store it aside because we may +# need to insert arguments between it and the other arguments +MASTER=$1 +shift + +# Determine desired worker port +if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then + SPARK_WORKER_WEBUI_PORT=8081 +fi + +# Start up the appropriate number of workers on this machine. +# quick local function to start a worker +function start_instance { + WORKER_NUM=$1 + shift + + if [ "$SPARK_WORKER_PORT" = "" ]; then + PORT_FLAG= + PORT_NUM= + else + PORT_FLAG="--port" + PORT_NUM=$(( $SPARK_WORKER_PORT + $WORKER_NUM - 1 )) + fi + WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) + + "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" +} + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + start_instance 1 "$@" +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + start_instance $(( 1 + $i )) "$@" + done +fi + diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 4356c03657109..24d6268815ed3 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -59,13 +59,4 @@ if [ "$START_TACHYON" == "true" ]; then fi # Launch the slaves -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" -else - if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then - SPARK_WORKER_WEBUI_PORT=8081 - fi - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" diff --git a/sbin/stop-slave.sh b/sbin/stop-slave.sh new file mode 100755 index 0000000000000..3d1da5b254f2a --- /dev/null +++ b/sbin/stop-slave.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to stop all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: stop-slave.sh +# Stops all slaves on this worker machine + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index 7c2201100ef97..54c9bd46803a9 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -29,10 +29,4 @@ if [ -e "$sbin"/../tachyon/bin/tachyon ]; then "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker fi -if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 -else - for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) - done -fi +"$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/stop-slave.sh diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d794f034f5578..ac8a782976465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{StructType, DateUtils} +import org.apache.spark.sql.types.StructType object Row { /** @@ -257,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 91976fef6dc0d..d4f9fdacda4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -77,6 +77,9 @@ object CatalystTypeConverters { } new GenericRowWithSchema(ar, structType) + case (d: String, _) => + UTF8String(d) + case (d: BigDecimal, _) => Decimal(d) @@ -175,6 +178,11 @@ object CatalystTypeConverters { case other => other } + case dataType: StringType => (item: Any) => extractOption(item) match { + case s: String => UTF8String(s) + case other => other + } + case _ => (item: Any) => extractOption(item) match { case d: BigDecimal => Decimal(d) @@ -184,6 +192,26 @@ object CatalystTypeConverters { } } + /** + * Converts Scala objects to catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter @@ -211,6 +239,9 @@ object CatalystTypeConverters { case (i: Int, DateType) => DateUtils.toJavaDate(i) + case (s: UTF8String, StringType) => + s.toString() + case (other, _) => other } @@ -262,6 +293,12 @@ object CatalystTypeConverters { case other => other } + case StringType => + (item: Any) => item match { + case s: UTF8String => s.toString() + case other => other + } + case other => (item: Any) => item } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 01d5c1512201a..d9521953cad73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -138,6 +138,7 @@ trait ScalaReflection { // The data type can be determined without ambiguity. case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType + case obj: String => StringType case obj: StringType.JvmType => StringType case obj: ByteType.JvmType => ByteType case obj: ShortType.JvmType => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index bc8d3751f6616..0af969cc5cc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -121,14 +121,14 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } protected lazy val start: Parser[LogicalPlan] = - ( (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - | insert - | cte + start1 | insert | cte + + protected lazy val start1: Parser[LogicalPlan] = + (select | ("(" ~> select <~ ")")) * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) protected lazy val select: Parser[LogicalPlan] = @@ -159,7 +159,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } protected lazy val cte: Parser[LogicalPlan] = - WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start <~ ")"), ",") ~ start ^^ { + WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) } @@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ UnresolvedAttribute + | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) + case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b68b0df35f48..cb49e5ad5586f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -297,14 +297,15 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && + resolver(nameParts(0), VirtualColumn.groupingIdName) && q.isInstanceOf[GroupingAnalytics] => // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid - case u @ UnresolvedAttribute(name) => + case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => @@ -383,12 +384,12 @@ class Analyzer( child: LogicalPlan, grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { // Find any attributes that remain unresolved in the sort. - val unresolved: Seq[String] = - ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val unresolved: Seq[Seq[String]] = + ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) // Create a map from name, to resolved attributes, when the desired name can be found // prior to the projection. - val resolved: Map[String, NamedExpression] = + val resolved: Map[Seq[String], NamedExpression] = unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap // Construct a set that contains all of the attributes that we need to evaluate the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fa02111385c06..1155dac28fc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -46,8 +46,12 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => if (operator.childrenResolved) { + val nameParts = a match { + case UnresolvedAttribute(nameParts) => nameParts + case _ => Seq(a.name) + } // Throw errors for specific problems with get field. - operator.resolveChildren(a.name, resolver, throwErrors = true) + operator.resolveChildren(nameParts, resolver, throwErrors = true) } val from = operator.inputSet.map(_.name).mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3aeb964994d37..35c7f00d4e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -115,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal.create("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -563,6 +563,10 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 300e9ba187bc5..3f567e3e8b2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,7 +49,12 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) + extends Attribute with trees.LeafNode[Expression] { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this - override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. override def eval(input: Row = null): EvaluatedType = @@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } +object UnresolvedAttribute { + def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) +} + case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 31f1a5fdc7e53..adf941ab2a45f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. */ @@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Int](_, d => DateUtils.toString(d)) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s)) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => @@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { + buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => @@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 47b6f358ed1b1..3475ed05f4454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { + values(ordinal).update(value) + } } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) - override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 14a855054b94d..f3830c6d3bcf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) => + case DecimalType.Fixed(_, _) | DecimalType.Unlimited => // Turn the child to unlimited decimals for calculation, before going back to fixed val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1f6526ef66c56..566b34f7c3a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } +case class MinOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable && right.nullable + + override def children: Seq[Expression] = left :: right :: Nil + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + lazy val ordering = left.dataType match { + case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case other => sys.error(s"Type $other does not support ordered operations") + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + val evalE2 = right.eval(input) + if (evalE1 == null) { + evalE2 + } else if (evalE2 == null) { + evalE1 + } else { + if (ordering.compare(evalE1, evalE2) < 0) { + evalE1 + } else { + evalE2 + } + } + } + + override def toString: String = s"MinOf($left, $right)" +} + /** * A function that get the absolute value of the numeric value. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index aac56e1568332..be2c101d63a63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: ${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1: BinaryType, e2: BinaryType) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } @@ -524,6 +537,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case MinOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + case UnscaledValue(child) => val childEval = expressionEvaluator(child) @@ -573,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -584,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -595,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { + case StringType => q"$destinationRow.update($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } @@ -618,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0db29eb404bd1..fc2a2b60703e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of @@ -43,6 +43,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val evalB = expressionEvaluator(order.child) val compare = order.child.dataType match { + case BinaryType => + q""" + val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} + val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} + var i = 0 + while (i < x.length && i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + i = i+1 + } + return x.length - y.length + """ case _: NumericType => q""" val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a8880..6f572ff959fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } val specificMutatorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..67caadb839ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ /** @@ -85,8 +85,11 @@ case class UserDefinedGenerator( override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + // Convert the objects into Scala Type before calling function, we need schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(inputRow(input)) + function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0e2d593e94124..18cba4cc46707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ object Literal { @@ -29,7 +30,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -42,7 +43,9 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } - def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) + def create(v: Any, dataType: DataType): Literal = { + Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7e47cb3fffe12..fcd6352079b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.compare( - l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 0a275b84086cf..b6ec7d3417ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} - +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -37,6 +36,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -114,9 +114,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - values(i).asInstanceOf[String] + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 @@ -189,8 +195,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } @@ -233,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def forSchema(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index acfbbace608ef..d597bf7ce756a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -60,38 +57,17 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ @@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.asInstanceOf[UTF8String]) + } + } +} + /** * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def toString: String = s"Upper($child)" } @@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def toString: String = s"Lower($child)" } @@ -162,15 +159,16 @@ trait StringComparison { override def nullable: Boolean = left.nullable || right.nullable - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } @@ -184,7 +182,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.contains(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** @@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.startsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** @@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.endsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers @@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case other => slice(other.toString, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 93e69d409cb91..7c80634d2c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal.create(utf, StringType)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 579a0fb8d3f93..ae4620a4e5abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.types.{ArrayType, StructType, StructField} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -111,10 +110,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolveChildren( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver, throwErrors) + resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this @@ -122,10 +121,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * `[scope].AttributeName.[nested].[fields]...`. */ def resolve( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, output, resolver, throwErrors) + resolve(nameParts, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsTableColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { assert(nameParts.length > 1) @@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { if (resolver(attribute.name, nameParts.head)) { @@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( - name: String, + nameParts: Seq[String], input: Seq[Attribute], resolver: Resolver, throwErrors: Boolean): Option[NamedExpression] = { - val parts = name.split("\\.") - // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list // of parts that are to be resolved. @@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // and the second element will be List("c"). var candidates: Seq[(Attribute, List[String])] = { // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (parts.length > 1) { + if (nameParts.length > 1) { input.flatMap { option => - resolveAsTableColumn(parts, resolver, option) + resolveAsTableColumn(nameParts, resolver, option) } } else { Seq.empty @@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // If none of attributes match `table.column` pattern, we try to resolve it as a column. if (candidates.isEmpty) { candidates = input.flatMap { candidate => - resolveAsColumn(parts, resolver, candidate) + resolveAsColumn(nameParts, resolver, candidate) } } + def name = UnresolvedAttribute(nameParts).name + candidates.distinct match { // One match, no nested fields, use it. case Seq((a, Nil)) => Some(a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..fb4217a44807b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. */ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index a2df51e598a2b..97502ed3afe72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -85,7 +85,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param f the function to be applied to each node in the tree. */ def foreachUp(f: BaseType => Unit): Unit = { - children.foreach(_.foreach(f)) + children.foreach(_.foreachUp(f)) f(this) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 504fb05842505..d36a49159b87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -40,6 +40,7 @@ object DateUtils { millisToDays(d.getTime) } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala new file mode 100644 index 0000000000000..fc02ba6c9c43e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -0,0 +1,214 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import java.util.Arrays + +/** + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. + */ + +final class UTF8String extends Ordered[UTF8String] with Serializable { + + private[this] var bytes: Array[Byte] = _ + + /** + * Update the UTF8String with String. + */ + def set(str: String): UTF8String = { + bytes = str.getBytes("utf-8") + this + } + + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ + def set(bytes: Array[Byte]): UTF8String = { + this.bytes = bytes + this + } + + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 + } + + /** + * Return the number of code points in it. + * + * This is only used by Substring() when `start` is negative. + */ + def length(): Int = { + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + i += numOfBytes(bytes(i)) + len += 1 + } + len + } + + def getBytes: Array[Byte] = { + bytes + } + + /** + * Return a substring of this, + * @param start the position of first code point + * @param until the position after last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (until <= start || start >= bytes.length || bytes == null) { + new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + i += numOfBytes(bytes(i)) + c += 1 + } + var j = i + while (c < until && j < bytes.length) { + j += numOfBytes(bytes(j)) + c += 1 + } + UTF8String(Arrays.copyOfRange(bytes, i, j)) + } + + def contains(sub: UTF8String): Boolean = { + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In worst case, it's O(N*K), but should works fine with SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false + } + + def startsWith(prefix: UTF8String): Boolean = { + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) + } + + def endsWith(suffix: UTF8String): Boolean = { + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) + } + + def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. + UTF8String(toString().toUpperCase) + } + + def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. + UTF8String(toString().toLowerCase) + } + + override def toString(): String = { + new String(bytes, "utf-8") + } + + override def clone(): UTF8String = new UTF8String().set(this.bytes) + + override def compare(other: UTF8String): Int = { + var i: Int = 0 + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) + if (res != 0) return res + i += 1 + } + bytes.length - b.length + } + + override def compareTo(other: UTF8String): Int = { + compare(other) + } + + override def equals(other: Any): Boolean = other match { + case s: UTF8String => + Arrays.equals(bytes, s.getBytes) + case s: String => + // This is only used for Catalyst unit tests + // fail fast + bytes.length >= s.length && length() == s.length && toString() == s + case _ => + false + } + + override def hashCode(): Int = { + Arrays.hashCode(bytes) + } +} + +object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) + + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index cdf2bc68d9c5e..c6fb22c26bd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = String + private[sql] type JvmType = UTF8String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum * - * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - * where we need to convert Any to UserType. + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. */ def serialize(obj: Any): Any diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6e3d6b9263e86..e10ddfdf5127c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import scala.collection.immutable - class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index d2b1090a0cdd5..76298f03c94ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite { class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { + def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) @@ -233,6 +238,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) } + test("MinOf") { + checkEvaluation(MinOf(1, 2), 1) + checkEvaluation(MinOf(2, 1), 1) + checkEvaluation(MinOf(1L, 2L), 1L) + checkEvaluation(MinOf(2L, 1L), 1L) + + checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) + checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) @@ -255,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, - new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { @@ -303,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -753,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -793,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -836,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - )) + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil @@ -899,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -924,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) @@ -948,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -978,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) @@ -999,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) @@ -1043,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // substring(null, _, _) -> null checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, new GenericRow(Array[Any](null))) + null, create_row(null)) // substring(_, null, _) -> null checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), @@ -1092,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebcd..bcc0c404d2cfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 4eb8708335dcf..6b393327cc97a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -117,5 +117,17 @@ class TreeNodeSuite extends FunSuite { assert(transformed.origin.startPosition.isDefined) } + test("foreach up") { + val actual = new ArrayBuffer[String]() + val expected = Seq("1", "2", "3", "4", "-", "*", "+") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression foreachUp { + case b: BinaryExpression => actual.append(b.symbol); + case l: Literal => actual.append(l.toString); + } + + assert(expected === actual) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 0000000000000..a22aa6f244c48 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +// scalastyle:off +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + + test("prefix") { + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + + test("suffix") { + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + + test("slice") { + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index ca4a127120b37..18584c2dcf797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData(dataIndex).cachedRepresentation.uncache(blocking) cachedData.remove(dataIndex) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3cd7adf8cab5e..edb229c059e6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging { def rlike(literal: String): Column = RLike(expr, lit(literal).expr) /** - * An expression that gets an item at position `ordinal` out of an array. + * An expression that gets an item at position `ordinal` out of an array, + * or gets a value by key `key` in a [[MapType]]. * * @group expr_ops */ - def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) + def getItem(key: Any): Column = GetItem(expr, Literal(key)) /** - * An expression that gets a field by name in a [[StructField]]. + * An expression that gets a field by name in a [[StructType]]. * * @group expr_ops */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 94ae2d65fd0e4..17c21f6e3a0e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -158,7 +158,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -166,7 +166,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get } } @@ -908,6 +908,20 @@ class DataFrame private[sql]( schema, needsConversion = false) } + /** + * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @group rdd + */ + override def coalesce(numPartitions: Int): DataFrame = { + sqlContext.createDataFrame( + queryExecution.toRdd.coalesce(numPartitions), + schema, + needsConversion = false) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * @group dfops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala index ba4373f0124b4..63dbab19947c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -61,5 +61,7 @@ private[sql] trait RDDApi[T] { def repartition(numPartitions: Int): DataFrame + def coalesce(numPartitions: Int): DataFrame + def distinct: DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ee641bdfeb2d7..5c65f04ee8497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -47,6 +47,7 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -128,6 +129,13 @@ private[sql] class SQLConf extends Serializable { /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c25ef58e6f62a..f9f3eb2e03817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,8 +873,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * passed to this function. * * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` to retrieve - * @param upperBound the maximum value of `columnName` to retrieve + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil } protected[sql] def openSession(): SQLSession = { @@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c881747751520..00ed70430b84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -153,6 +153,7 @@ private[sql] object ColumnBuilder { val builder: ColumnBuilder = dataType match { case IntegerType => new IntColumnBuilder case LongType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 87a6631da8300..b0f983c180673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats { } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c47497e0662d9..1b9e0df2dcb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer): String = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eee0c86d6a1c..d9b6fb43ab83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation( child: SparkPlan, tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null) + private var _statistics: Statistics = null, + private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats = - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + if (_batchStats == null) { + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + } else { + _batchStats + } val partitionStatistics = new PartitionStatistics(output) @@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated) + _cachedColumnBuffers, statisticsToBePropagated, batchStats) } override def children: Seq[LogicalPlan] = Seq.empty @@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - statisticsToBePropagated).asInstanceOf[this.type] + statisticsToBePropagated, + batchStats).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated) + Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) + + private[sql] def uncache(blocking: Boolean): Unit = { + Accumulators.remove(batchStats.id) + cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } } private[sql] case class InMemoryColumnarTableScan( @@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan( } } + lazy val enableAccumulators: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + // Accumulators used for testing purposes - val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning override def execute(): RDD[Row] = { - readPartitions.setValue(0) - readBatches.setValue(0) + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( @@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan( } } - if (rows.hasNext) { + if (rows.hasNext && enableAccumulators) { readPartitions += 1 } @@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan( logInfo(s"Skipping partition based on stats $statsString") false } else { - readBatches += 1 + if (enableAccumulators) { + readBatches += 1 + } true } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..69a620e1ec929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,24 +19,42 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair +object Exchange { + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. + */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + newOrdering: Seq[SortOrder], + child: SparkPlan) + extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning + override def outputOrdering: Seq[SortOrder] = newOrdering + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ @@ -45,7 +63,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + override def execute(): RDD[Row] = attachTree(this , "execute") { + lazy val sparkConf = child.sqlContext.sparkContext.getConf + newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -56,7 +90,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -69,12 +105,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => @@ -87,9 +128,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._1) case SinglePartition => @@ -107,7 +152,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") @@ -120,27 +165,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. */ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. def meetsRequirements: Boolean = - !operator.requiredChildDistribution.zip(operator.children).map { + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child.outputOrdering + } + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -157,28 +209,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( + partitioning: Partitioning, + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) + } else { + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } + + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } + } else { + withShuffle + } + + withSort + } + } - if (meetsRequirements && compatible) { + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. - val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 656bdd7212f56..1fd387eec7e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -54,6 +54,33 @@ object RDDConversions { } } } + + /** + * Convert the objects inside Row into the types Catalyst expected. + */ + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) + val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = converters(i)(r(i)) + i += 1 + } + + mutableRow + } + } + } + } } /** Logical plan node for scanning data from an RDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 95176e425132d..b1ef6556de1e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -153,51 +153,6 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val currentSum = AttributeReference("currentSum", calcType, nullable = false)() - val initialCount = Literal(0L) - val initialSum = Cast(Literal(0L), calcType) - - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val toCount = expr match { - case UnscaledValue(e) => e - case _ => expr - } - - val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil) - - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - If(EqualTo(currentCount, Literal(0L)), - Literal.create(null, a.dataType), - Cast(Divide( - Cast(currentSum, DecimalType.Unlimited), - Cast(currentCount, DecimalType.Unlimited)), a.dataType)) - case _ => - If(EqualTo(currentCount, Literal(0L)), - Literal.create(null, a.dataType), - Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) - } - - AggregateEvaluation( - currentCount :: currentSum :: Nil, - initialCount :: initialSum :: Nil, - updateCount :: updateSum :: Nil, - result - ) - case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) @@ -209,6 +164,17 @@ case class GeneratedAggregate( updateMax :: Nil, currentMax) + case m @ Min(expr) => + val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)() + val initialValue = Literal.create(null, expr.dataType) + val updateMin = MinOf(currentMin, expr) + + AggregateEvaluation( + currentMin :: Nil, + initialValue :: Nil, + updateMin :: Nil, + currentMin) + case CollectHashSet(Seq(expr)) => val set = AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)() @@ -233,6 +199,8 @@ case class GeneratedAggregate( initialValue :: Nil, collectSets :: Nil, CountSet(set)) + + case o => sys.error(s"$o can't be codegened.") } val computationSchema = computeFunctions.flatMap(_.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fabcf6b4a0570..e159ffe66cb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 914f387dec78f..eea15aff5dbcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co private[execution] class KryoResourcePool(size: Int) extends ResourcePool[SerializerInstance](size) { - val ser: KryoSerializer = { + val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. - new KryoSerializer(sparkConf) + new SparkSqlSerializer(sparkConf) } def newInstance(): SerializerInstance = ser.newInstance() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f0d92ffffcda3..e687d01f57520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -155,7 +163,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false @@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f8221f41bc6c3..d286fe81bee5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val resuableProjection = buildProjection() iter.map(resuableProjection) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -117,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan) } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } } @@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -172,6 +178,8 @@ case class Sort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -202,6 +210,8 @@ case class ExternalSort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index fad7a281dc1e2..99f24910fd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e916e68e58b5d..710787096e6cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,7 +164,7 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => - case (_: String, StringType) => + case (_: UTF8String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 0000000000000..b5123668ba11e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. + */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + override def execute(): RDD[Row] = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. + */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 5b308d88d4cdf..7a43bfd8bc8d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,6 +140,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other @@ -192,7 +193,8 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString + case (c: String, StringType) => UTF8String(c) + case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 463e1dcc268bc..b9022fcd9e3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -233,7 +233,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } @@ -349,12 +349,14 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + // TODO(davies): convert Date into Int case DateConversion => mutableRow.update(i, rs.getDate(pos)) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7e..5f480083d5a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Data corresponding to one partition of a JDBCRDD. @@ -99,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider { val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) Class.forName(driver) + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { @@ -130,6 +131,8 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { + override val needConversion: Boolean = false + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34f864f5fda7a..d4e0abc040bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} +import org.apache.spark.Logging import org.apache.spark.sql.types._ package object jdbc { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f4c99b4b56606..e3352d02787fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row - -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} private[sql] class DefaultSource @@ -113,6 +113,8 @@ private[sql] case class JSONRelation( // TODO: Support partitioned JSON relation. private def baseRDD = sqlContext.sparkContext.textFile(path) + override val needConversion: Boolean = false + override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b1e8521383756..29de7401dda71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala deleted file mode 100644 index 25a66cb488103..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - -import parquet.Log -import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter} - -private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - - override def getWorkPath(): Path = outputPath - override def abortTask(taskContext: TaskAttemptContext): Unit = {} - override def commitTask(taskContext: TaskAttemptContext): Unit = {} - override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true - override def setupJob(jobContext: JobContext): Unit = {} - override def setupTask(taskContext: TaskAttemptContext): Unit = {} - - override def commitJob(jobContext: JobContext) { - try { - val configuration = ContextUtil.getConfiguration(jobContext) - val fileSystem = outputPath.getFileSystem(configuration) - val outputStatus = fileSystem.getFileStatus(outputPath) - val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) - try { - ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { - val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) - fileSystem.create(successPath).close() - } - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } - } - } - } catch { - case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) - } - } - -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 43ca359b51735..bc108e37dfb0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, readTimestamp(value)) @@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter( private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - private[this] var dict: Array[String] = null + private[this] var dict: Array[Array[Byte]] = null override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) + parent.updateString(fieldIndex, value.getBytes) } private[parquet] object CatalystArrayConverter { @@ -714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 0357dcc4688be..5eb1c6abc2432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 3724bda829d30..1c868da23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -379,8 +379,6 @@ private[sql] case class InsertIntoParquetTable( */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { - var committer: OutputCommitter = null - // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} @@ -405,26 +403,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } - - // override to create output committer from configuration - override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - if (committer == null) { - val output = getOutputPath(context) - val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] - } - committer - } - - // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 - private def getOutputPath(context: TaskAttemptContext): Path = { - context.getConfiguration().get("mapred.output.dir") match { - case null => null - case name => new Path(name) - } - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 5a1b15490d273..e05a4c20b0d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) @@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 20fdf5e58ef82..af7b3c81ae7b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} - import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.metadata.CompressionCodecName @@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} /** * Allows creation of Parquet based tables using the syntax: @@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2( file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + // Skip type conversion + override val needConversion: Boolean = false + // TODO Should calculate per scan size // It's common that a query only scans a fraction of a large Parquet file. Returning size of the // whole Parquet file disables some optimizations in this case (e.g. broadcast join). @@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2( baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => p.values + case p if split.getPath.getParent.toString == p.path => + CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] }.get val requiredPartOrdinal = partitionKeyLocations.keys.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 34d048e426d10..b3d71f687a60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. @@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) - case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringStartsWith(a.name, v)) + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) - case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringEndsWith(a.name, v)) + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) - case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringContains(a.name, v)) + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 319de710fbc3e..2e861b84b7133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import scala.util.matching.Regex import scala.language.implicitConversions import org.apache.spark.Logging @@ -155,7 +156,19 @@ private[sql] class DDLParser( protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8f9946a5a801e..ca53dcdb92c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -126,6 +126,16 @@ abstract class BaseRelation { * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. + */ + def needConversion: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index f7b5f08beb92f..01e3b8671071e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ @@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest { sql("Clear CACHE") assert(cacheManager.isEmpty) } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + val accsSize = Accumulators.originals.size + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + uncacheTable("t1") + uncacheTable("t2") + + assert(accsSize >= Accumulators.originals.size) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b26e22f6229fe..3250ab476aeb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest { TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } + test("access complex data") { + assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) + assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) + } + test("table scan") { checkAnswer( testData, @@ -172,6 +178,14 @@ class DataFrameSuite extends QueryTest { testData.select('key).collect().toSeq) } + test("coalesce") { + assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + + checkAnswer( + testData.select('key).coalesce(1).select('key), + testData.select('key).collect().toSeq) + } + test("groupBy") { checkAnswer( testData2.groupBy("a").agg($"a", sum($"b")), @@ -531,4 +545,13 @@ class DataFrameSuite extends QueryTest { val df = TestSQLContext.createDataFrame(rowRDD, schema) df.rdd.collect() } + + test("SPARK-6899") { + val originalValue = TestSQLContext.conf.codegenEnabled + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e4dee87849fd4..037d392c1f929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } } test("broadcasted hash join operator selection") { cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9a81fc5d72819..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 36465cc2fa11a..bf6cf1321a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e453e05e2ac7..9e02e69fda3f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.execution.GeneratedAggregate -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. @@ -172,6 +167,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { testCodeGen( "SELECT max(key) FROM testData3x", Row(100) :: Nil) + // MIN + testCodeGen( + "SELECT value, min(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT min(key) FROM testData3x", + Row(1) :: Nil) // Some combinations. testCodeGen( """ @@ -179,16 +181,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | value, | sum(key), | max(key), + | min(key), | avg(key), | count(key), | count(distinct key) |FROM testData3x |GROUP BY value """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1))) + (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) testCodeGen( - "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x", - Row(100, 50.5, 300, 100) :: Nil) + "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( "SELECT sum('a'), avg('a'), count(null) FROM testData", @@ -395,6 +398,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.EXTERNAL_SORT, before.toString) } + test("SPARK-6927 sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.EXTERNAL_SORT, "false") + setConf(SQLConf.CODEGEN_ENABLED, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + + test("SPARK-6927 external sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") + setConf(SQLConf.EXTERNAL_SORT, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"), @@ -423,6 +446,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } + test("Allow only a single WITH clause per query") { + intercept[RuntimeException] { + sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + } + } + test("date row") { checkAnswer(sql( """select cast("2015-01-28" as date) from testData limit 1"""), @@ -1115,7 +1144,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1215,4 +1244,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } + + test("SPARK-6898: complete support for special chars in column names") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 637f59b2e68ca..225b51bd73d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test._ import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test._ case class TestData(key: Int, value: String) @@ -199,11 +198,11 @@ object TestData { Salary(1, 1000.0) :: Nil).toDF() salary.registerTempTable("salary") - case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) val complexData = TestSQLContext.sparkContext.parallelize( - ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 5f08834f73c6b..b48bed1871c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import java.sql.Timestamp +import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.serializer.KryoRegistrator import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -65,7 +68,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) @@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType[BooleanType.type]( @@ -108,8 +111,8 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { - val bytes = string.getBytes("utf-8") + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -117,7 +120,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( @@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging { } } + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() + + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + } + } + def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, @@ -229,3 +267,23 @@ class ColumnTypeSuite extends FunSuite with Logging { } } } + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a,b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index b301818a008e7..f76314b9dab5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{Decimal, DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -48,7 +48,7 @@ object ColumnarTestUtils { case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => Random.nextString(Random.nextInt(32)) + case STRING => UTF8String(Random.nextString(Random.nextInt(32))) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 479210d1c9c43..56591d9dba29e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.columnar +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{DecimalType, Decimal} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -132,4 +134,59 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, (1 to i).toSeq)) + } + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. + checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index e57bb06e7263b..2a0b701cad7fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + // Enable in-memory table scan accumulators + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 4d0bf7cf99cdf..97c0f439acf13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,27 +381,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } - - test("SPARK-6352 DirectParquetOutputCommitter") { - // Write to a parquet file and let it fail. - // _temporary should be missing if direct output committer works. - try { - configuration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) - } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) - assert(!fs.exists(path)) - } - } - finally { - configuration.unset("spark.sql.parquet.output.committer.class") - } - } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 3f24a497390c1..ca25751b9583d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -25,17 +25,17 @@ class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) } } -case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, - new MetadataBuilder().putString("comment", "test comment").build()), + new MetadataBuilder().putString("comment", s"test comment $table").build()), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), @@ -73,7 +73,8 @@ class DDLTestSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | Table 'test1' |) """.stripMargin) } @@ -81,7 +82,7 @@ class DDLTestSuite extends DataSourceTest { sqlTest( "describe ddlPeople", Seq( - Row("intType", "int", "test comment"), + Row("intType", "int", "test comment test1"), Row("stringType", "string", ""), Row("dateType", "date", ""), Row("timestampType", "timestamp", ""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 60c8c00bda4d5..3b47b8adf313b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -74,7 +74,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -82,7 +82,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index a96b1ffc26966..f38c796241df1 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -44,7 +44,6 @@ com.google.guava guava - runtime ${hive.group} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6272cdedb3e48..85281c6d73a3b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -145,6 +145,9 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + // use the specified database if specified + cli.processSelectDatabase(sessionState); + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) @@ -264,7 +267,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || + proc.isInstanceOf[AddResourceProcessor]) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 6d1d7c3a4e698..b070fa8eaa469 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,22 +25,31 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { +class CliSuite extends FunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + } + def runCliWithin( timeout: FiniteDuration, extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = Utils.createTempDir() - warehousePath.delete() - val metastorePath = Utils.createTempDir() - metastorePath.delete() val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { @@ -95,8 +104,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin, cause) throw cause } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -124,4 +131,24 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { test("Single command with -e") { runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } + + test("Single command with --database") { + runCliWithin(1.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 0000000000000..65d070bd3cbde --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join is true. + */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 921c6194c7b76..74ae984f34866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,9 +239,10 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -284,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -340,7 +344,9 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) @@ -409,7 +415,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53a204b8c2932..fd305eb480e63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) + UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) case other => UnresolvedGetField(other, attr) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 1ccb0c279c60e..a6f4fbe8aba06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,24 +17,21 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -131,7 +128,7 @@ private[hive] trait HiveStrategies { val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index d35291543c9f9..e556c74ffb015 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.DateUtils +import org.apache.spark.util.Utils /** * A trait for subclasses that handle table scans. @@ -76,7 +77,9 @@ class HadoopTableReader( override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] = makeRDDForTable( hiveTable, - relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + Class.forName( + relation.tableDesc.getSerdeClassName, true, sc.sessionState.getConf.getClassLoader) + .asInstanceOf[Class[Deserializer]], filterOpt = None) /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8efed7f0299bf..cab0fdd35723a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} -import java.io.{DataInputStream, DataOutputStream, EOFException} +import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} import java.util.Properties import scala.collection.JavaConversions._ @@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -121,14 +121,13 @@ case class ScriptTransformation( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 99dc58646ddd6..a40a1e53117cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** * Analyzes the given table in the current database to generate statistics, which will be @@ -76,11 +76,17 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) - Seq.empty[Row] + Seq(Row(0)) } } diff --git a/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar new file mode 100644 index 0000000000000..37af9aafad8a4 Binary files /dev/null and b/sql/hive/src/test/resources/hive-hcatalog-core-0.13.1.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 1222fbabd8b33..300b1f7920473 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -813,6 +813,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("DROP TABLE alter1") } + test("ADD JAR command 2") { + // this is a test case from mapjoin_addjar.q + val testJar = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath + val testData = TestHive.getHiveFile("data/files/sample.json").getCanonicalPath + if (HiveShim.version == "0.13.1") { + sql(s"ADD JAR $testJar") + sql( + """CREATE TABLE t1(a string, b string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'""".stripMargin) + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") + sql("select * from src join t1 on src.key = t1.a") + sql("DROP TABLE t1") + } + } + test("ADD FILE command") { val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile sql(s"ADD FILE $testFile") diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 0ed93c2c5b1fa..33e96eaabfbf6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} private[hive] case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { @@ -135,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7577309900209..d331c210e8939 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,37 +17,35 @@ package org.apache.spark.sql.hive -import java.util -import java.util.{ArrayList => JArrayList} -import java.util.Properties import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} - +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** * This class provides the UDF creation and also the UDF instance serialization and @@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], classOf[Class[_]]) method.setAccessible(true) @@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method @@ -224,7 +218,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala new file mode 100644 index 0000000000000..df1c0a10704c3 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import scala.xml.Node + +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.ui.UIUtils + +private[ui] abstract class BatchTableBase(tableId: String) { + + protected def columns: Seq[Node] = { + + + + + } + + protected def baseRow(batch: BatchInfo): Seq[Node] = { + val batchTime = batch.batchTime.milliseconds + val formattedBatchTime = UIUtils.formatDate(batch.batchTime.milliseconds) + val eventCount = batch.receivedBlockInfo.values.map { + receivers => receivers.map(_.numRecords).sum + }.sum + val schedulingDelay = batch.schedulingDelay + val formattedSchedulingDelay = schedulingDelay.map(UIUtils.formatDuration).getOrElse("-") + val processingTime = batch.processingDelay + val formattedProcessingTime = processingTime.map(UIUtils.formatDuration).getOrElse("-") + + + + + + } + + private def batchTable: Seq[Node] = { +
    Batch TimeInput SizeScheduling DelayProcessing Time{formattedBatchTime}{eventCount.toString} events + {formattedSchedulingDelay} + + {formattedProcessingTime} +
    + + {columns} + + + {renderRows} + +
    + } + + def toNodeSeq: Seq[Node] = { + batchTable + } + + /** + * Return HTML for all rows of this table. + */ + protected def renderRows: Seq[Node] +} + +private[ui] class ActiveBatchTable(runningBatches: Seq[BatchInfo], waitingBatches: Seq[BatchInfo]) + extends BatchTableBase("active-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Status + + override protected def renderRows: Seq[Node] = { + // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display + // waiting batches before running batches + waitingBatches.flatMap(batch => {waitingBatchRow(batch)}) ++ + runningBatches.flatMap(batch => {runningBatchRow(batch)}) + } + + private def runningBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ processing + } + + private def waitingBatchRow(batch: BatchInfo): Seq[Node] = { + baseRow(batch) ++ queued + } +} + +private[ui] class CompletedBatchTable(batches: Seq[BatchInfo]) + extends BatchTableBase("completed-batches-table") { + + override protected def columns: Seq[Node] = super.columns ++ Total Delay + + override protected def renderRows: Seq[Node] = { + batches.flatMap(batch => {completedBatchRow(batch)}) + } + + private def completedBatchRow(batch: BatchInfo): Seq[Node] = { + val totalDelay = batch.totalDelay + val formattedTotalDelay = totalDelay.map(UIUtils.formatDuration).getOrElse("-") + baseRow(batch) ++ + + {formattedTotalDelay} + + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index b6dcb62bfeec8..07fa285642eec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -41,7 +41,8 @@ private[ui] class StreamingPage(parent: StreamingTab) generateBasicStats() ++

    ++

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ - generateBatchStatsTable() + generateBatchStatsTable() ++ + generateBatchListTables() } UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } @@ -49,9 +50,10 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Generate basic stats of the streaming program */ private def generateBasicStats(): Seq[Node] = { val timeSinceStart = System.currentTimeMillis() - startTime + // scalastyle:off
    • - Started at: {startTime.toString} + Started at: {UIUtils.formatDate(startTime)}
    • Time since start: {formatDurationVerbose(timeSinceStart)} @@ -63,18 +65,19 @@ private[ui] class StreamingPage(parent: StreamingTab) Batch interval: {formatDurationVerbose(listener.batchDuration)}
    • - Processed batches: {listener.numTotalCompletedBatches} + Completed batches: {listener.numTotalCompletedBatches}
    • - Waiting batches: {listener.numUnprocessedBatches} + Active batches: {listener.numUnprocessedBatches}
    • - Received records: {listener.numTotalReceivedRecords} + Received events: {listener.numTotalReceivedRecords}
    • - Processed records: {listener.numTotalProcessedRecords} + Processed events: {listener.numTotalProcessedRecords}
    + // scalastyle:on } /** Generate stats of data received by the receivers in the streaming program */ @@ -86,10 +89,10 @@ private[ui] class StreamingPage(parent: StreamingTab) "Receiver", "Status", "Location", - "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", - "Minimum rate\n[records/sec]", - "Median rate\n[records/sec]", - "Maximum rate\n[records/sec]", + "Events in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", + "Minimum rate\n[events/sec]", + "Median rate\n[events/sec]", + "Maximum rate\n[events/sec]", "Last Error" ) val dataRows = (0 until listener.numReceivers).map { receiverId => @@ -190,5 +193,26 @@ private[ui] class StreamingPage(parent: StreamingTab) } UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) } + + private def generateBatchListTables(): Seq[Node] = { + val runningBatches = listener.runningBatches.sortBy(_.batchTime.milliseconds).reverse + val waitingBatches = listener.waitingBatches.sortBy(_.batchTime.milliseconds).reverse + val completedBatches = listener.retainedCompletedBatches. + sortBy(_.batchTime.milliseconds).reverse + + val activeBatchesContent = { +

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ + new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq + } + + val completedBatchesContent = { +

    + Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) +

    ++ + new CompletedBatchTable(completedBatches).toNodeSeq + } + + activeBatchesContent ++ completedBatchesContent + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 998426ebb82e5..205ddf6dbe9b0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -75,6 +75,17 @@ class UISeleniumSuite val statisticText = findAll(cssSelector("li strong")).map(_.text).toSeq statisticText should contain("Network receivers:") statisticText should contain("Batch interval:") + + val h4Text = findAll(cssSelector("h4")).map(_.text).toSeq + h4Text should contain("Active Batches (0)") + h4Text should contain("Completed Batches (last 0 out of 0)") + + findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Status") + } + findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { + List("Batch Time", "Input Size", "Scheduling Delay", "Processing Time", "Total Delay") + } } ssc.stop(false) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1091ff54b0463..52e4dee46c535 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1052,8 +1052,7 @@ object Client extends Logging { if (isDriver) { conf.getBoolean("spark.driver.userClassPathFirst", false) } else { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) + conf.getBoolean("spark.executor.userClassPathFirst", false) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index c06c0105670c0..a18c94d4ab4a8 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -144,7 +144,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } // Enable this once fix SPARK-6700 - ignore("run Python application in yarn-cluster mode") { + test("run Python application in yarn-cluster mode") { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) val pyFile = new File(tempDir, "test2.py") @@ -282,10 +282,10 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit } -private class SaveExecutorInfo extends SparkListener { +private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() - override def onExecutorAdded(executor : SparkListenerExecutorAdded) { + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } } @@ -293,7 +293,6 @@ private class SaveExecutorInfo extends SparkListener { private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 - var listener: SaveExecutorInfo = null def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -306,10 +305,9 @@ private object YarnClusterDriver extends Logging with Matchers { System.exit(1) } - listener = new SaveExecutorInfo val sc = new SparkContext(new SparkConf() + .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) - sc.addSparkListener(listener) val status = new File(args(0)) var result = "failure" try { @@ -323,7 +321,12 @@ private object YarnClusterDriver extends Logging with Matchers { } // verify log urls are present - listener.addedExecutorInfos.values.foreach { info => + val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo] + assert(listeners.size === 1) + val listener = listeners(0) + val executorInfos = listener.addedExecutorInfos.values + assert(executorInfos.nonEmpty) + executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) } }