diff --git a/assembly/pom.xml b/assembly/pom.xml
index b2a9d0780ee2b..594fa0c779e1b 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -142,8 +142,10 @@
<exclude>com/google/common/base/Absent*</exclude>
+ <exclude>com/google/common/base/Function</exclude>
<exclude>com/google/common/base/Optional*</exclude>
<exclude>com/google/common/base/Present*</exclude>
+ <exclude>com/google/common/base/Supplier</exclude>
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 8f3b396ffd086..9e8d0b785194e 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -72,22 +72,25 @@ else
assembly_folder="$ASSEMBLY_DIR"
fi
-num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar$" | wc -l)"
-if [ "$num_jars" -eq "0" ]; then
- echo "Failed to find Spark assembly in $assembly_folder"
- echo "You need to build Spark before running this program."
- exit 1
-fi
+num_jars=0
+
+for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark assembly in $assembly_folder" 1>&2
+ echo "You need to build Spark before running this program." 1>&2
+ exit 1
+ fi
+ ASSEMBLY_JAR="$f"
+ num_jars=$((num_jars+1))
+done
+
if [ "$num_jars" -gt "1" ]; then
- jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar$")
- echo "Found multiple Spark assembly jars in $assembly_folder:"
- echo "$jars_list"
- echo "Please remove all but one jar."
+ echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2
+ ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
-ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)"
-
# Verify that versions of java used to build the jars and run Spark are compatible
jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1)
if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
diff --git a/bin/run-example b/bin/run-example
index 3d932509426fc..c567acf9a6b5c 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -35,17 +35,32 @@ else
fi
if [ -f "$FWDIR/RELEASE" ]; then
- export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`"
-elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
- export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`"
+ JAR_PATH="${FWDIR}/lib"
+else
+ JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}"
fi
-if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then
- echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
- echo "You need to build Spark before running this program" 1>&2
+JAR_COUNT=0
+
+for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
+ echo "You need to build Spark before running this program" 1>&2
+ exit 1
+ fi
+ SPARK_EXAMPLES_JAR="$f"
+ JAR_COUNT=$((JAR_COUNT+1))
+done
+
+if [ "$JAR_COUNT" -gt "1" ]; then
+ echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2
+ ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
+export SPARK_EXAMPLES_JAR
+
EXAMPLE_MASTER=${MASTER:-"local[*]"}
if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then
diff --git a/bin/spark-class b/bin/spark-class
index 0d58d95c1aee3..2f0441bb3c1c2 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -71,6 +72,8 @@ case "$1" in
'org.apache.spark.executor.MesosExecutorBackend')
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
+ export PYTHONPATH="$FWDIR/python:$PYTHONPATH"
+ export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
;;
# Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS +
@@ -118,8 +121,8 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
@@ -148,7 +151,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2
- echo "You need to build Spark before running $1." 1>&2
+ echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2
exit 1
fi
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
diff --git a/build/mvn b/build/mvn
index 43471f83e904c..f91e2b4bdcc02 100755
--- a/build/mvn
+++ b/build/mvn
@@ -68,10 +68,10 @@ install_app() {
# Install maven under the build/ folder
install_mvn() {
install_app \
- "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \
- "apache-maven-3.2.3-bin.tar.gz" \
- "apache-maven-3.2.3/bin/mvn"
- MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn"
+ "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
+ "apache-maven-3.2.5-bin.tar.gz" \
+ "apache-maven-3.2.5/bin/mvn"
+ MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
}
# Install zinc under the build/ folder
diff --git a/core/pom.xml b/core/pom.xml
index d9a49c9e08afc..1984682b9c099 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -372,8 +372,10 @@
<artifact>com.google.guava:guava</artifact>
<includes>
<include>com/google/common/base/Absent*</include>
+ <include>com/google/common/base/Function</include>
<include>com/google/common/base/Optional*</include>
<include>com/google/common/base/Present*</include>
+ <include>com/google/common/base/Supplier</include>
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
new file mode 100644
index 0000000000000..646496f313507
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerApplicationEnd;
+import org.apache.spark.scheduler.SparkListenerApplicationStart;
+import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
+import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
+import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorAdded;
+import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
+import org.apache.spark.scheduler.SparkListenerJobEnd;
+import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.scheduler.SparkListenerStageCompleted;
+import org.apache.spark.scheduler.SparkListenerStageSubmitted;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
+import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
+import org.apache.spark.scheduler.SparkListenerTaskStart;
+import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+
+/**
+ * Java clients should extend this class instead of implementing
+ * SparkListener directly. This is to prevent Java clients
+ * from breaking when new events are added to the SparkListener
+ * trait.
+ *
+ * This is a concrete class rather than an abstract one so that new
+ * events must be added to both the SparkListener trait and this
+ * adapter in lockstep.
+ */
+public class JavaSparkListener implements SparkListener {
+
+ @Override
+ public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { }
+
+ @Override
+ public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { }
+
+ @Override
+ public void onTaskStart(SparkListenerTaskStart taskStart) { }
+
+ @Override
+ public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { }
+
+ @Override
+ public void onTaskEnd(SparkListenerTaskEnd taskEnd) { }
+
+ @Override
+ public void onJobStart(SparkListenerJobStart jobStart) { }
+
+ @Override
+ public void onJobEnd(SparkListenerJobEnd jobEnd) { }
+
+ @Override
+ public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { }
+
+ @Override
+ public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { }
+
+ @Override
+ public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { }
+
+ @Override
+ public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { }
+
+ @Override
+ public void onApplicationStart(SparkListenerApplicationStart applicationStart) { }
+
+ @Override
+ public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { }
+
+ @Override
+ public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { }
+
+ @Override
+ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
+
+ @Override
+ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
+}
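As the class comment notes, Java clients are meant to extend JavaSparkListener and override only the callbacks they need, so that new SparkListener events do not break them. A minimal sketch of that usage (written in Scala for consistency with the rest of this patch; a Java client would do the same with a named or anonymous subclass, and the listener and app names here are illustrative):

    import org.apache.spark.JavaSparkListener
    import org.apache.spark.scheduler.SparkListenerTaskEnd
    import org.apache.spark.{SparkConf, SparkContext}

    // Override only the events of interest; the no-op defaults in
    // JavaSparkListener absorb callbacks added to SparkListener later.
    class TaskEndLogger extends JavaSparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        println(s"Task finished in stage ${taskEnd.stageId}")
      }
    }

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("listener-demo"))
    sc.addSparkListener(new TaskEndLogger)
    sc.parallelize(1 to 10).count()   // triggers onTaskEnd for each task
    sc.stop()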
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 89eec7d4b7f61..c99a61f63ea2b 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -10,3 +10,4 @@ log4j.logger.org.eclipse.jetty=WARN
log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
+log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 5751964b792ce..f23ba9dba167f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -19,6 +19,7 @@
height: 50px;
font-size: 15px;
margin-bottom: 15px;
+ min-width: 1200px
}
.navbar .navbar-inner {
@@ -39,12 +40,12 @@
.navbar .nav > li a {
height: 30px;
- line-height: 30px;
+ line-height: 2;
}
.navbar-text {
height: 50px;
- line-height: 50px;
+ line-height: 3.3;
}
table.sortable thead {
@@ -120,6 +121,14 @@ pre {
border: none;
}
+.description-input {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: nowrap;
+ display: block;
+}
+
.stacktrace-details {
max-height: 300px;
overflow-y: auto;
@@ -170,7 +179,7 @@ span.additional-metric-title {
}
.version {
- line-height: 30px;
+ line-height: 2.5;
vertical-align: bottom;
font-size: 12px;
padding: 0;
@@ -181,6 +190,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
-.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time {
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
+.getting_result_time {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 09eb9605fb799..3b684bbeceaf2 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -61,8 +61,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
@@ -95,8 +95,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non-optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 80da62c44edc5..a0c0372b7f0ef 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -44,7 +44,11 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(blockResult) =>
// Partition is already materialized, so just return its values
- context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics)
+ val inputMetrics = blockResult.inputMetrics
+ val existingMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(inputMetrics.readMethod)
+ existingMetrics.addBytesRead(inputMetrics.bytesRead)
+
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
case None =>
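The replacement above stops overwriting `taskMetrics.inputMetrics` with the cached block's metrics and instead accumulates the block's bytes into the metrics object registered for that read method. A standalone sketch of that accumulate-don't-replace pattern, using simplified stand-in types rather than Spark's own classes (`getInputMetricsForReadMethod` itself is defined later in this patch):

    import java.util.concurrent.atomic.AtomicLong

    sealed trait ReadMethod
    case object MemoryRead extends ReadMethod
    case object HadoopRead extends ReadMethod

    class SimpleInputMetrics(val readMethod: ReadMethod) {
      private val _bytesRead = new AtomicLong()
      def bytesRead: Long = _bytesRead.get()
      def addBytesRead(bytes: Long): Unit = _bytesRead.addAndGet(bytes)
    }

    class SimpleTaskMetrics {
      private var inputMetrics: Option[SimpleInputMetrics] = None
      // Reuse the recorded metrics when the read method matches; otherwise hand
      // back a fresh, unrecorded object (mirroring getInputMetricsForReadMethod).
      def forReadMethod(method: ReadMethod): SimpleInputMetrics = synchronized {
        inputMetrics match {
          case Some(m) if m.readMethod == method => m
          case Some(_) => new SimpleInputMetrics(method)
          case None =>
            val m = new SimpleInputMetrics(method)
            inputMetrics = Some(m)
            m
        }
      }
    }

    val tm = new SimpleTaskMetrics
    tm.forReadMethod(MemoryRead).addBytesRead(512)
    tm.forReadMethod(MemoryRead).addBytesRead(512)
    println(tm.forReadMethod(MemoryRead).bytesRead)   // 1024, accumulated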
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index a0ee2a7cbb2a2..b28da192c1c0d 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -158,7 +158,7 @@ private[spark] class ExecutorAllocationManager(
"shuffle service. You may enable this through spark.shuffle.service.enabled.")
}
if (tasksPerExecutor == 0) {
- throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores")
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.")
}
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index a0ce107f43b16..cd91c8f87547b 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,8 +17,11 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
/**
@@ -46,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ set(k, v)
}
}
@@ -63,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -129,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -163,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -224,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -240,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -265,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -345,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
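Backing SparkConf with a ConcurrentHashMap and routing the getters through getOption makes the object safe to read and write from multiple threads without changing its public behaviour. A quick sketch of the API these changes preserve (loadDefaults = false to avoid picking up spark.* system properties; values are illustrative):

    import org.apache.spark.SparkConf

    val conf = new SparkConf(false)
      .set("spark.app.name", "conf-demo")
      .setIfMissing("spark.ui.port", "4040")

    conf.getOption("spark.app.name")        // Some("conf-demo")
    conf.get("spark.master", "local[*]")    // falls back to the supplied default
    conf.contains("spark.local.dir")        // false
    conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
    println(conf.toDebugString)             // sorted key=value listing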
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ff5d796ee2766..4c4ee04cc515e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -520,12 +528,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/** Distribute a local Scala collection to form an RDD.
*
- * @note Parallelize acts lazily. If `seq` is a mutable collection and is
- * altered after the call to parallelize and before the first action on the
- * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
- * the argument to avoid this.
+ * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
+ * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
+ * modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -541,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -550,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -583,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -628,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -652,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -685,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -704,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -783,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -803,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -818,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -829,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -859,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -880,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -955,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1047,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1059,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1076,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1085,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1095,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1102,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1207,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1290,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1378,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1400,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1418,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1469,13 +1505,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /** Default min number of partitions for Hadoop RDDs when not given by user */
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
+ * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718
+ */
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
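With the new `stopped` flag, methods guarded by `assertNotStopped()` fail fast with an IllegalStateException instead of hitting a null scheduler later. A minimal sketch of the intended behaviour in local mode:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("stop-demo"))
    println(sc.parallelize(1 to 100).count())   // fine while the context is live
    sc.stop()

    // After stop(), guarded entry points such as parallelize, textFile and
    // broadcast are expected to throw IllegalStateException rather than an NPE.
    try {
      sc.parallelize(1 to 100).count()
    } catch {
      case e: IllegalStateException => println(s"Rejected as expected: ${e.getMessage}")
    }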
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4d418037bd33f..1264a8126153b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -326,6 +326,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index bd451634e53d2..62bf18d82d9b0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -38,6 +38,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+/**
+ * Defines operations common to several Java RDD implementations.
+ * Note that this trait is not intended to be implemented by user code.
+ */
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@@ -435,6 +439,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def first(): T = rdd.first()
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = rdd.isEmpty()
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
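The new `isEmpty()` simply delegates to the underlying Scala RDD, and as its scaladoc stresses, emptiness is about elements, not partitions. A small sketch (local mode, illustrative values):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("isEmpty-demo"))

    // An RDD can own partitions yet contain no elements.
    val empty = sc.parallelize(Seq.empty[Int], numSlices = 3)
    println(empty.partitions.length)   // 3
    println(empty.isEmpty())           // true
    println(sc.parallelize(Seq(1, 2, 3)).isEmpty())   // false

    sc.stop()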
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index bad40e6529f74..4ac666c54fbcd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -125,8 +125,8 @@ private[spark] class PythonRDD(
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
- context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 2e1e52906ceeb..e5873ce724b9f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
-import org.apache.spark.util.MemoryParam
+import org.apache.spark.util.{IntParam, MemoryParam}
/**
* Command-line parser for the driver client.
@@ -51,8 +51,8 @@ private[spark] class ClientArguments(args: Array[String]) {
parse(args.toList)
def parse(args: List[String]): Unit = args match {
- case ("--cores" | "-c") :: value :: tail =>
- cores = value.toInt
+ case ("--cores" | "-c") :: IntParam(value) :: tail =>
+ cores = value
parse(tail)
case ("--memory" | "-m") :: MemoryParam(value) :: tail =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 57f9faf5ddd1d..211e3ede53d9c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
@@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesWritten = f()
@@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging {
}
}
- private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
- val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
- val scheme = qualifiedPath.toUri().getScheme()
- val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
+ val stats = FileSystem.getAllStatistics()
stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
}
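The refactored callbacks drop the per-path scheme filter and simply sum the current thread's statistics across every registered FileSystem. A sketch of that aggregation against the Hadoop API directly (assumes Hadoop 2.5+, where `Statistics.getThreadStatistics()` exists; Spark itself goes through reflection so it can still compile against older Hadoop versions):

    import org.apache.hadoop.fs.FileSystem
    import scala.collection.JavaConverters._

    // Bytes read on the current thread, summed over all schemes (hdfs, file, s3n, ...),
    // mirroring what getFSBytesReadOnThreadCallback now measures.
    def threadBytesRead(): Long =
      FileSystem.getAllStatistics.asScala
        .map(_.getThreadStatistics.getBytesRead)
        .sum

    val baseline = threadBytesRead()
    // ... perform reads through any Hadoop FileSystem on this thread ...
    val bytesReadSinceBaseline = threadBytesRead() - baseline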
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 955cbd6dab96d..050ba91eb2bc3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -200,6 +200,7 @@ object SparkSubmit {
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
+ OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 47059b08a397f..81ec08cb6d501 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -108,6 +108,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
.orElse(sparkProperties.get("spark.driver.memory"))
.orElse(env.get("SPARK_DRIVER_MEMORY"))
.orNull
+ driverCores = Option(driverCores)
+ .orElse(sparkProperties.get("spark.driver.cores"))
+ .orNull
executorMemory = Option(executorMemory)
.orElse(sparkProperties.get("spark.executor.memory"))
.orElse(env.get("SPARK_EXECUTOR_MEMORY"))
@@ -406,6 +409,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| --total-executor-cores NUM Total cores for all executors.
|
| YARN-only:
+ | --driver-cores NUM Number of cores used by the driver, only in cluster mode
+ | (Default: 1).
| --executor-cores NUM Number of cores per executor (Default: 1).
| --queue QUEUE_NAME The YARN queue to submit to (Default: "default").
| --num-executors NUM Number of executors to launch (Default: 2).
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2b084a2d73b78..0ae45f4ad9130 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (!logInfos.isEmpty) {
val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
def addIfAbsent(info: FsApplicationHistoryInfo) = {
- if (!newApps.contains(info.id)) {
+ if (!newApps.contains(info.id) ||
+ newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) &&
+ !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) {
newApps += (info.id -> info)
}
}
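The widened condition lets a completed event log replace a previously recorded `.inprogress` log for the same application id while still ignoring plain duplicates. A standalone sketch of that rule (the types and paths here are simplified stand-ins, not the provider's own):

    import scala.collection.mutable

    case class AppLog(id: String, logPath: String)

    val IN_PROGRESS = ".inprogress"
    val newApps = new mutable.LinkedHashMap[String, AppLog]()

    def addIfAbsent(info: AppLog): Unit = {
      val storedInProgress = newApps.get(info.id).exists(_.logPath.endsWith(IN_PROGRESS))
      // Keep the first entry for an id, unless it was an in-progress log and we
      // now see the completed log for the same application.
      if (!newApps.contains(info.id) ||
          (storedInProgress && !info.logPath.endsWith(IN_PROGRESS))) {
        newApps += (info.id -> info)
      }
    }

    addIfAbsent(AppLog("app-1", "/logs/app-1.inprogress"))
    addIfAbsent(AppLog("app-1", "/logs/app-1"))     // replaces the in-progress entry
    println(newApps("app-1").logPath)               // /logs/app-1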
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index ad7d81747c377..ede0a9dbefb8d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -38,8 +38,8 @@ private[spark] class ApplicationInfo(
extends Serializable {
@transient var state: ApplicationState.Value = _
- @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
- @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _
+ @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
+ @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
@transient var coresGranted: Int = _
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
@@ -55,12 +55,12 @@ private[spark] class ApplicationInfo(
private def init() {
state = ApplicationState.WAITING
- executors = new mutable.HashMap[Int, ExecutorInfo]
+ executors = new mutable.HashMap[Int, ExecutorDesc]
coresGranted = 0
endTime = -1L
appSource = new ApplicationSource(this)
nextExecutorId = 0
- removedExecutors = new ArrayBuffer[ExecutorInfo]
+ removedExecutors = new ArrayBuffer[ExecutorDesc]
}
private def newExecutorId(useID: Option[Int] = None): Int = {
@@ -75,14 +75,14 @@ private[spark] class ApplicationInfo(
}
}
- def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = {
+ val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.id)) {
removedExecutors += executors(exec.id)
executors -= exec.id
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
similarity index 95%
rename from core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
rename to core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
index d417070c51016..5d620dfcabad5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.master
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
-private[spark] class ExecutorInfo(
+private[spark] class ExecutorDesc(
val id: Int,
val application: ApplicationInfo,
val worker: WorkerInfo,
@@ -37,7 +37,7 @@ private[spark] class ExecutorInfo(
override def equals(other: Any): Boolean = {
other match {
- case info: ExecutorInfo =>
+ case info: ExecutorDesc =>
fullId == info.fullId &&
worker.id == info.worker.id &&
cores == info.cores &&
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 4b631ec639071..d92d99310a583 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -581,7 +581,7 @@ private[spark] class Master(
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(masterUrl,
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 473ddc23ff0f3..e94aae93e4495 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -38,7 +38,7 @@ private[spark] class WorkerInfo(
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info
+ @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
@transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
@transient var state: WorkerState.Value = _
@transient var coresUsed: Int = _
@@ -70,13 +70,13 @@ private[spark] class WorkerInfo(
host + ":" + port
}
- def addExecutor(exec: ExecutorInfo) {
+ def addExecutor(exec: ExecutorDesc) {
executors(exec.fullId) = exec
coresUsed += exec.cores
memoryUsed += exec.memory
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.fullId)) {
executors -= exec.fullId
coresUsed -= exec.cores
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 4588c130ef439..3aae2b95d7396 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -27,7 +27,7 @@ import org.json4s.JValue
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.ExecutorInfo
+import org.apache.spark.deploy.master.ExecutorDesc
import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
@@ -109,7 +109,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
}
- private def executorRow(executor: ExecutorInfo): Seq[Node] = {
+ private def executorRow(executor: ExecutorDesc): Seq[Node] = {
<td>{executor.id}</td>
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 9a4adfbbb3d71..823825302658c 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case x: DisassociatedEvent =>
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
+ if (x.remoteAddress == driver.anchorPath.address) {
+ logError(s"Driver $x disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"Received irrelevant DisassociatedEvent $x")
+ }
case StopExecutor =>
logInfo("Driver commanded a shutdown")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b75c77b5b4457..d8c2e41a7c715 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
+ executorHostname: String,
env: SparkEnv,
isLocal: Boolean = false)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -58,12 +61,12 @@ private[spark] class Executor(
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
+ Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
@@ -203,10 +206,10 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - deserializeStartTime
- m.executorRunTime = taskFinish - taskStart
- m.jvmGCTime = gcTime - startGCTime
- m.resultSerializationTime = afterSerialization - beforeSerialization
+ m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
+ m.setExecutorRunTime(taskFinish - taskStart)
+ m.setJvmGCTime(gcTime - startGCTime)
+ m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
val accumUpdates = Accumulators.values
@@ -257,8 +260,8 @@ private[spark] class Executor(
val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
- m.executorRunTime = serviceTime
- m.jvmGCTime = gcTime - startGCTime
+ m.setExecutorRunTime(serviceTime)
+ m.setJvmGCTime(gcTime - startGCTime)
}
val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -376,10 +379,12 @@ private[spark] class Executor(
val curGCTime = gcTime
for (taskRunner <- runningTasks.values()) {
- if (!taskRunner.attemptedTask.isEmpty) {
+ if (taskRunner.attemptedTask.nonEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
- metrics.updateShuffleReadMetrics
- metrics.jvmGCTime = curGCTime - taskRunner.startGCTime
+ metrics.updateShuffleReadMetrics()
+ metrics.updateInputMetrics()
+ metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 51b5328cb4c8f..97912c68c5982 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,10 @@
package org.apache.spark.executor
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.executor.DataReadMethod.DataReadMethod
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
@@ -39,48 +43,78 @@ class TaskMetrics extends Serializable {
/**
* Host's name the task runs on
*/
- var hostname: String = _
-
+ private var _hostname: String = _
+ def hostname = _hostname
+ private[spark] def setHostname(value: String) = _hostname = value
+
/**
* Time taken on the executor to deserialize this task
*/
- var executorDeserializeTime: Long = _
-
+ private var _executorDeserializeTime: Long = _
+ def executorDeserializeTime = _executorDeserializeTime
+ private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value
+
+
/**
* Time the executor spends actually running the task (including fetching shuffle data)
*/
- var executorRunTime: Long = _
-
+ private var _executorRunTime: Long = _
+ def executorRunTime = _executorRunTime
+ private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value
+
/**
* The number of bytes this task transmitted back to the driver as the TaskResult
*/
- var resultSize: Long = _
+ private var _resultSize: Long = _
+ def resultSize = _resultSize
+ private[spark] def setResultSize(value: Long) = _resultSize = value
+
/**
* Amount of time the JVM spent in garbage collection while executing this task
*/
- var jvmGCTime: Long = _
+ private var _jvmGCTime: Long = _
+ def jvmGCTime = _jvmGCTime
+ private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value
/**
* Amount of time spent serializing the task result
*/
- var resultSerializationTime: Long = _
+ private var _resultSerializationTime: Long = _
+ def resultSerializationTime = _resultSerializationTime
+ private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value
/**
* The number of in-memory bytes spilled by this task
*/
- var memoryBytesSpilled: Long = _
+ private var _memoryBytesSpilled: Long = _
+ def memoryBytesSpilled = _memoryBytesSpilled
+ private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value
+ private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value
/**
* The number of on-disk bytes spilled by this task
*/
- var diskBytesSpilled: Long = _
+ private var _diskBytesSpilled: Long = _
+ def diskBytesSpilled = _diskBytesSpilled
+ def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value
+ def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value
/**
* If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
* are stored here.
*/
- var inputMetrics: Option[InputMetrics] = None
+ private var _inputMetrics: Option[InputMetrics] = None
+
+ def inputMetrics = _inputMetrics
+
+ /**
+ * This should only be used when recreating TaskMetrics, not when updating input metrics in
+ * executors
+ */
+ private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) {
+ _inputMetrics = inputMetrics
+ }
/**
* If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
@@ -133,19 +167,47 @@ class TaskMetrics extends Serializable {
readMetrics
}
+ /**
+ * Returns the input metrics object that the task should use. Currently, if
+ * there exists an input metric with the same readMethod, we return that one
+ * so the caller can accumulate bytes read. If the readMethod is different
+ * than previously seen by this task, we return a new InputMetric but don't
+ * record it.
+ *
+ * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
+ * we can store all the different inputMetrics (one per readMethod).
+ */
+ private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod):
+ InputMetrics = synchronized {
+ _inputMetrics match {
+ case None =>
+ val metrics = new InputMetrics(readMethod)
+ _inputMetrics = Some(metrics)
+ metrics
+ case Some(metrics @ InputMetrics(method)) if method == readMethod =>
+ metrics
+ case Some(InputMetrics(method)) =>
+ new InputMetrics(readMethod)
+ }
+ }
+
/**
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics() = synchronized {
val merged = new ShuffleReadMetrics()
for (depMetrics <- depsShuffleReadMetrics) {
- merged.fetchWaitTime += depMetrics.fetchWaitTime
- merged.localBlocksFetched += depMetrics.localBlocksFetched
- merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
- merged.remoteBytesRead += depMetrics.remoteBytesRead
+ merged.incFetchWaitTime(depMetrics.fetchWaitTime)
+ merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
+ merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
+ merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
}
_shuffleReadMetrics = Some(merged)
}
+
+ private[spark] def updateInputMetrics() = synchronized {
+ inputMetrics.foreach(_.updateBytesRead())
+ }
}
private[spark] object TaskMetrics {
@@ -179,10 +241,38 @@ object DataWriteMethod extends Enumeration with Serializable {
*/
@DeveloperApi
case class InputMetrics(readMethod: DataReadMethod.Value) {
+
+ private val _bytesRead: AtomicLong = new AtomicLong()
+
/**
* Total bytes read.
*/
- var bytesRead: Long = 0L
+ def bytesRead: Long = _bytesRead.get()
+ @volatile @transient var bytesReadCallback: Option[() => Long] = None
+
+ /**
+ * Adds additional bytes read for this read method.
+ */
+ def addBytesRead(bytes: Long) = {
+ _bytesRead.addAndGet(bytes)
+ }
+
+ /**
+ * Invoke the bytesReadCallback and mutate bytesRead.
+ */
+ def updateBytesRead() {
+ bytesReadCallback.foreach { c =>
+ _bytesRead.set(c())
+ }
+ }
+
+ /**
+ * Register a function that can be called to get up-to-date information on how many bytes the
+ * task has read from an input source.
+ */
+ def setBytesReadCallback(f: Option[() => Long]) {
+ bytesReadCallback = f
+ }
}
/**
@@ -194,7 +284,9 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
/**
* Total bytes written
*/
- var bytesWritten: Long = 0L
+ private var _bytesWritten: Long = _
+ def bytesWritten = _bytesWritten
+ private[spark] def setBytesWritten(value: Long) = _bytesWritten = value
}
/**
@@ -203,32 +295,45 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
*/
@DeveloperApi
class ShuffleReadMetrics extends Serializable {
- /**
- * Number of blocks fetched in this shuffle by this task (remote or local)
- */
- def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
-
/**
* Number of remote blocks fetched in this shuffle by this task
*/
- var remoteBlocksFetched: Int = _
-
+ private var _remoteBlocksFetched: Int = _
+ def remoteBlocksFetched = _remoteBlocksFetched
+ private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value
+ private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value
+
/**
* Number of local blocks fetched in this shuffle by this task
*/
- var localBlocksFetched: Int = _
+ private var _localBlocksFetched: Int = _
+ def localBlocksFetched = _localBlocksFetched
+ private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value
+ private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value
+
/**
* Time the task spent waiting for remote shuffle blocks. This only includes the time
* blocking on shuffle input data. For instance if block B is being fetched while the task is
* still not finished processing block A, it is not considered to be blocking on block B.
*/
- var fetchWaitTime: Long = _
-
+ private var _fetchWaitTime: Long = _
+ def fetchWaitTime = _fetchWaitTime
+ private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value
+ private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value
+
/**
* Total number of remote bytes read from the shuffle by this task
*/
- var remoteBytesRead: Long = _
+ private var _remoteBytesRead: Long = _
+ def remoteBytesRead = _remoteBytesRead
+ private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
+ private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
+
+ /**
+ * Number of blocks fetched in this shuffle by this task (remote or local)
+ */
+ def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched
}
/**
@@ -240,10 +345,18 @@ class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for the shuffle by this task
*/
- @volatile var shuffleBytesWritten: Long = _
-
+ @volatile private var _shuffleBytesWritten: Long = _
+ def shuffleBytesWritten = _shuffleBytesWritten
+ private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value
+ private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value
+
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
- @volatile var shuffleWriteTime: Long = _
+ @volatile private var _shuffleWriteTime: Long = _
+ def shuffleWriteTime = _shuffleWriteTime
+ private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value
+ private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value
+
+
}
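
The TaskMetrics hunks above replace public mutable vars with private fields exposed through read-only getters plus narrow inc/dec/set mutators, so only code inside org.apache.spark can move the counters. A minimal sketch of that shape, under a hypothetical package chosen only so that private[spark] compiles; this is an illustration, not Spark code:

package org.apache.spark.sketch  // hypothetical package, used only so private[spark] is legal here

class CounterMetricsSketch extends Serializable {
  private var _bytesWritten: Long = 0L

  // Read-only view for listeners, the web UI, and JSON serialization.
  def bytesWritten: Long = _bytesWritten

  // Narrow mutators: only code under org.apache.spark can change the counter.
  private[spark] def incBytesWritten(value: Long): Unit = _bytesWritten += value
  private[spark] def decBytesWritten(value: Long): Unit = _bytesWritten -= value
}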
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 45633e3de01dd..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -130,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 70edf191d928a..07398a6fa62f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -159,8 +159,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 37e0c13029d8b..c3e3931042de2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.mapred.JobID
import org.apache.hadoop.mapred.TaskAttemptID
import org.apache.hadoop.mapred.TaskID
+import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
@@ -213,18 +214,19 @@ class HadoopRDD[K, V](
logInfo("Input split: " + split.inputSplit)
val jobConf = getJobConf()
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
- split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
- } else {
- None
- }
- if (bytesReadCallback.isDefined) {
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.inputSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
}
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
@@ -237,8 +239,6 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
- var recordsSinceMetricsUpdate = 0
-
override def getNext() = {
try {
finished = !reader.next(key, value)
@@ -247,15 +247,6 @@ class HadoopRDD[K, V](
finished = true
}
- // Update bytes read metric every few records
- if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
- && bytesReadCallback.isDefined) {
- recordsSinceMetricsUpdate = 0
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else {
- recordsSinceMetricsUpdate += 1
- }
(key, value)
}
@@ -263,14 +254,13 @@ class HadoopRDD[K, V](
try {
reader.close()
if (bytesReadCallback.isDefined) {
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ inputMetrics.updateBytesRead()
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
+ split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.bytesRead = split.inputSplit.value.getLength
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ inputMetrics.addBytesRead(split.inputSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index e55d03d391e03..d86f95ac3e485 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
@@ -34,7 +34,7 @@ import org.apache.spark.Logging
import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
@@ -109,18 +109,19 @@ class NewHadoopRDD[K, V](
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
- split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
- } else {
- None
- }
- if (bytesReadCallback.isDefined) {
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.serializableHadoopSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
}
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
@@ -154,33 +155,20 @@ class NewHadoopRDD[K, V](
}
havePair = false
- // Update bytes read metric every few records
- if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
- && bytesReadCallback.isDefined) {
- recordsSinceMetricsUpdate = 0
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else {
- recordsSinceMetricsUpdate += 1
- }
-
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
-
- // Update metrics with final amount
if (bytesReadCallback.isDefined) {
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ inputMetrics.updateBytesRead()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e43e5066655b9..49b88a90ab5af 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
@@ -1007,7 +1007,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
1
} : Int
@@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
@@ -1079,18 +1079,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.close()
}
writer.commit()
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
- private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
- : (OutputMetrics, Option[() => Long]) = {
- val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
- .map(new Path(_))
- .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
if (bytesWrittenCallback.isDefined) {
context.taskMetrics.outputMetrics = Some(outputMetrics)
@@ -1102,7 +1099,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
&& bytesWrittenCallback.isDefined) {
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
}
}
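
initHadoopOutputMetrics no longer needs the output path or a Configuration: the bytes-written callback now comes straight from the thread's FileSystem statistics, and bytesWritten is refreshed every few records via setBytesWritten. A hedged sketch of that periodic-refresh guard; the interval constant is an assumption standing in for RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES:

object OutputMetricsSketch {
  private val RecordsBetweenUpdates = 256L  // assumed interval, for illustration only

  def maybeUpdateBytesWritten(
      recordsWritten: Long,
      bytesWrittenCallback: Option[() => Long])(setBytesWritten: Long => Unit): Unit = {
    if (recordsWritten % RecordsBetweenUpdates == 0) {
      bytesWrittenCallback.foreach(fn => setBytesWritten(fn()))
    }
  }
}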
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 87b22de6ae697..f12d0cffaba34 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -111,7 +111,8 @@ private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
- * it efficient to run Spark over RDDs representing large sets of numbers.
+ * it efficient to run Spark over RDDs representing large sets of numbers. If the collection is
+ * an inclusive Range, the last slice is built as an inclusive Range as well.
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
@@ -127,19 +128,15 @@ private object ParallelCollectionRDD {
})
}
seq match {
- case r: Range.Inclusive => {
- val sign = if (r.step < 0) {
- -1
- } else {
- 1
- }
- slice(new Range(
- r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
- }
case r: Range => {
- positions(r.length, numSlices).map({
- case (start, end) =>
+ positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
+ // If the range is inclusive, use inclusive range for the last slice
+ if (r.isInclusive && index == numSlices - 1) {
+ new Range.Inclusive(r.start + start * r.step, r.end, r.step)
+ }
+ else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+ }
}).toSeq.asInstanceOf[Seq[Seq[T]]]
}
case nr: NumericRange[_] => {
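
A standalone worked example of the slicing change above; positions is re-created here to mirror the private helper in ParallelCollectionRDD, so this is a sketch of the behaviour rather than the class itself:

object InclusiveSliceDemo extends App {
  def positions(length: Long, numSlices: Int): Seq[(Int, Int)] =
    (0 until numSlices).map { i =>
      (((i * length) / numSlices).toInt, (((i + 1) * length) / numSlices).toInt)
    }

  val r = 1 to 10            // an inclusive Range
  val numSlices = 3
  val slices = positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
    if (r.isInclusive && index == numSlices - 1) {
      Range.inclusive(r.start + start * r.step, r.end, r.step)  // last slice keeps the end value
    } else {
      Range(r.start + start * r.step, r.start + end * r.step, r.step)
    }
  }

  println(slices.map(_.toList))
  // List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))
}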
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ec96b12a4e0b2..ea4277a433b00 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -1189,6 +1206,12 @@ abstract class RDD[T: ClassTag](
* */
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
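
The RDD changes above add a driver-side guard that turns the SPARK-5063 nested-RDD mistake into an explicit SparkException, plus an isEmpty() action. A hedged usage sketch, assuming a live SparkContext named sc:

// Assumes an existing SparkContext `sc`; the data is illustrative.
val noRows   = sc.parallelize(Seq.empty[Int], numSlices = 4)
val someRows = sc.parallelize(1 to 3)

noRows.isEmpty()     // true: four partitions, but take(1) finds nothing
someRows.isEmpty()   // false

// The `sc` accessor makes the classic nested pattern fail fast with the message shown above
// instead of an opaque error on the executor:
//   someRows.map(x => noRows.count() * x)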
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 8cb15918baa8c..1cfe98673773a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
+import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import akka.actor._
-import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {
- import DAGScheduler._
-
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,14 +109,10 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
- private val dagSchedulerActorSupervisor =
- env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
-
// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
- private[scheduler] var eventProcessActor: ActorRef = _
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
@@ -127,27 +120,20 @@ class DAGScheduler(
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
- private def initializeEventProcessActor() {
- // blocking the thread until supervisor is started, which ensures eventProcessActor is
- // not null before any job is submitted
- implicit val timeout = Timeout(30 seconds)
- val initEventActorReply =
- dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
- eventProcessActor = Await.result(initEventActorReply, timeout.duration).
- asInstanceOf[ActorRef]
- }
+ private val messageScheduler =
+ Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
- initializeEventProcessActor()
+ private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- eventProcessActor ! BeginEvent(task, taskInfo)
+ eventProcessLoop.post(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
- eventProcessActor ! GettingResultEvent(taskInfo)
+ eventProcessLoop.post(GettingResultEvent(taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
- eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
+ eventProcessLoop.post(
+ CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
/**
@@ -180,18 +167,18 @@ class DAGScheduler(
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
- eventProcessActor ! ExecutorLost(execId)
+ eventProcessLoop.post(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
- eventProcessActor ! ExecutorAdded(execId, host)
+ eventProcessLoop.post(ExecutorAdded(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
- eventProcessActor ! TaskSetFailed(taskSet, reason)
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}
@@ -537,8 +524,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -547,19 +534,19 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
- eventProcessActor ! JobCancelled(jobId)
+ eventProcessLoop.post(JobCancelled(jobId))
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
- eventProcessActor ! JobGroupCancelled(groupId)
+ eventProcessLoop.post(JobGroupCancelled(groupId))
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
- eventProcessActor ! AllJobsCancelled
+ eventProcessLoop.post(AllJobsCancelled)
}
private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@ class DAGScheduler(
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
- eventProcessActor ! StageCancelled(stageId)
+ eventProcessLoop.post(StageCancelled(stageId))
}
/**
@@ -661,7 +648,7 @@ class DAGScheduler(
// completion events or stage abort
stageIdToStage -= s.id
jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult))
}
}
@@ -710,7 +697,7 @@ class DAGScheduler(
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -749,9 +736,11 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
val shouldRunLocally =
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
+ val jobSubmissionTime = clock.getTime()
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
runLocally(job)
} else {
jobIdToActiveJob(jobId) = job
@@ -759,7 +748,8 @@ class DAGScheduler(
finalStage.resultOfJob = Some(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
- listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
submitStage(finalStage)
}
}
@@ -965,7 +955,8 @@ class DAGScheduler(
if (job.numFinished == job.numPartitions) {
markStageAsFinished(stage)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
+ listenerBus.post(
+ SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded))
}
// taskSucceeded runs some user code that might throw an exception. Make sure
@@ -1059,16 +1050,15 @@ class DAGScheduler(
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
- } else if (failedStages.isEmpty && eventProcessActor != null) {
+ } else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled. eventProcessActor may be
- // null during unit tests.
+ // in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
- import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
- env.actorSystem.scheduler.scheduleOnce(
- RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
@@ -1234,7 +1224,7 @@ class DAGScheduler(
if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -1326,40 +1316,21 @@ class DAGScheduler(
def stop() {
logInfo("Stopping DAGScheduler")
- dagSchedulerActorSupervisor ! PoisonPill
+ eventProcessLoop.stop()
taskScheduler.stop()
}
-}
-private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
-
- override val supervisorStrategy =
- OneForOneStrategy() {
- case x: Exception =>
- logError("eventProcesserActor failed; shutting down SparkContext", x)
- try {
- dagScheduler.doCancelAllJobs()
- } catch {
- case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
- }
- dagScheduler.sc.stop()
- Stop
- }
-
- def receive = {
- case p: Props => sender ! context.actorOf(p)
- case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
- }
+ // Start the event thread at the end of the constructor
+ eventProcessLoop.start()
}
-private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
+private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
+ extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
/**
* The main event loop of the DAG scheduler.
*/
- def receive = {
+ override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1398,7 +1369,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}
- override def postStop() {
+ override def onError(e: Throwable): Unit = {
+ logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
+ try {
+ dagScheduler.doCancelAllJobs()
+ } catch {
+ case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
+ }
+ dagScheduler.sc.stop()
+ }
+
+ override def onStop() {
// Cancel any active jobs in postStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1408,9 +1389,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 200.milliseconds
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 10L
+ val RESUBMIT_TIMEOUT = 200
}
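
The DAGScheduler now posts events to a single-threaded event loop (start/post/onReceive/onError/onStop) and keeps a small ScheduledExecutorService only for the delayed ResubmitFailedStages message. A minimal sketch of that event-loop shape, written here for illustration rather than copied from Spark's util.EventLoop:

import java.util.concurrent.LinkedBlockingQueue
import scala.util.control.NonFatal

abstract class SimpleEventLoop[E](name: String) {
  private val queue = new LinkedBlockingQueue[E]()
  @volatile private var stopped = false

  private val thread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped) {
          val event = queue.take()
          try onReceive(event) catch { case NonFatal(t) => onError(t) }
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked take(); exit quietly
      }
    }
  }

  def start(): Unit = thread.start()
  def post(event: E): Unit = queue.put(event)
  def stop(): Unit = { stopped = true; thread.interrupt(); onStop() }

  protected def onReceive(event: E): Unit
  protected def onError(e: Throwable): Unit
  protected def onStop(): Unit = {}
}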
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 27bf4f1599076..30075c172bdb1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -168,6 +168,10 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
override def onApplicationEnd(event: SparkListenerApplicationEnd) =
logEvent(event, flushLogger = true)
+ override def onExecutorAdded(event: SparkListenerExecutorAdded) =
+ logEvent(event, flushLogger = true)
+ override def onExecutorRemoved(event: SparkListenerExecutorRemoved) =
+ logEvent(event, flushLogger = true)
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index b62b0c1312693..e5d1eb767e109 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Distribution, Utils}
@@ -58,6 +59,7 @@ case class SparkListenerTaskEnd(
@DeveloperApi
case class SparkListenerJobStart(
jobId: Int,
+ time: Long,
stageInfos: Seq[StageInfo],
properties: Properties = null)
extends SparkListenerEvent {
@@ -67,7 +69,11 @@ case class SparkListenerJobStart(
}
@DeveloperApi
-case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
+case class SparkListenerJobEnd(
+ jobId: Int,
+ time: Long,
+ jobResult: JobResult)
+ extends SparkListenerEvent
@DeveloperApi
case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]])
@@ -84,6 +90,14 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerExecutorRemoved(executorId: String)
+ extends SparkListenerEvent
+
/**
* Periodic updates from executors.
* @param execId executor id
@@ -109,7 +123,8 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent
/**
* :: DeveloperApi ::
* Interface for listening to events from the Spark scheduler. Note that this is an internal
- * interface which might change in different Spark releases.
+ * interface which might change in different Spark releases. Java clients should extend
+ * {@link JavaSparkListener}.
*/
@DeveloperApi
trait SparkListener {
@@ -183,6 +198,16 @@ trait SparkListener {
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
+
+ /**
+ * Called when the driver registers a new executor.
+ */
+ def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { }
+
+ /**
+ * Called when the driver removes an executor.
+ */
+ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
}
/**
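
With the two new callbacks, a SparkListener can track executor membership. A hedged example (println stands in for real handling; registration assumes a SparkContext named sc):

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded, SparkListenerExecutorRemoved}

class ExecutorTrackingListener extends SparkListener {
  override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
    val info = executorAdded.executorInfo
    println(s"executor ${executorAdded.executorId} added on ${info.executorHost} " +
      s"(${info.totalCores} cores)")
  }

  override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
    println(s"executor ${executorRemoved.executorId} removed")
  }
}

// sc.addSparkListener(new ExecutorTrackingListener())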
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index e79ffd7a3587d..e700c6af542f4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -70,6 +70,10 @@ private[spark] trait SparkListenerBus extends Logging {
foreachListener(_.onApplicationEnd(applicationEnd))
case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
foreachListener(_.onExecutorMetricsUpdate(metricsUpdate))
+ case executorAdded: SparkListenerExecutorAdded =>
+ foreachListener(_.onExecutorAdded(executorAdded))
+ case executorRemoved: SparkListenerExecutorRemoved =>
+ foreachListener(_.onExecutorRemoved(executorRemoved))
case SparkListenerShutdown =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2367f7e2cf67e..847a4912eec13 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -55,7 +55,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
- context.taskMetrics.hostname = Utils.localHostName()
+ context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 4896ec845bbc9..774f3d8cdb275 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -77,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
(deserializedResult, size)
}
- result.metrics.resultSize = size
+ result.metrics.setResultSize(size)
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a1dfb01062591..33a7aae5d3fcd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -168,7 +168,7 @@ private[spark] class TaskSchedulerImpl(
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
- "and have sufficient memory")
+ "and have sufficient resources")
} else {
this.cancel()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index fe9914b50bc54..5786d367464f4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -28,7 +28,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
-import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
@@ -66,6 +66,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Number of executors requested from the cluster manager that have not registered yet
private var numPendingExecutors = 0
+ private val listenerBus = scheduler.sc.listenerBus
+
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
@@ -106,6 +108,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
+ listenerBus.post(SparkListenerExecutorAdded(executorId, data))
makeOffers()
}
@@ -213,6 +216,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
totalCoreCount.addAndGet(-executorInfo.totalCores)
totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
+ listenerBus.post(SparkListenerExecutorRemoved(executorId))
case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index b71bd5783d6df..eb52ddfb1eab1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -31,7 +31,7 @@ import akka.actor.{Address, ActorRef}
private[cluster] class ExecutorData(
val executorActor: ActorRef,
val executorAddress: Address,
- val executorHost: String ,
+ override val executorHost: String,
var freeCores: Int,
- val totalCores: Int
-)
+ override val totalCores: Int
+) extends ExecutorInfo(executorHost, totalCores)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
new file mode 100644
index 0000000000000..b4738e64c9391
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Stores information about an executor to pass from the scheduler to SparkListeners.
+ */
+@DeveloperApi
+class ExecutorInfo(
+ val executorHost: String,
+ val totalCores: Int
+) {
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ExecutorInfo =>
+ (that canEqual this) &&
+ executorHost == that.executorHost &&
+ totalCores == that.totalCores
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ val state = Seq(executorHost, totalCores)
+ state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+ }
+}
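
ExecutorInfo is a plain class rather than a case class, so equals/hashCode are written out by hand; a quick hedged check of the intended value semantics (host names are made up), e.g. in a test or the REPL:

val a = new ExecutorInfo("host-1.example.com", 8)
val b = new ExecutorInfo("host-1.example.com", 8)
assert(a == b && a.hashCode == b.hashCode)               // same fields => equal and same hash
assert(a != new ExecutorInfo("host-2.example.com", 8))   // any differing field breaks equality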
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 75d8ddf375e27..79c9051e88691 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -27,9 +27,12 @@ import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState,
+ ExecutorInfo => MesosExecutorInfo, _}
+import org.apache.spark.executor.MesosExecutorBackend
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -62,6 +65,9 @@ private[spark] class MesosSchedulerBackend(
var classLoader: ClassLoader = null
+ // The listener bus to publish executor added/removed events.
+ val listenerBus = sc.listenerBus
+
@volatile var appId: String = _
override def start() {
@@ -87,7 +93,7 @@ private[spark] class MesosSchedulerBackend(
}
}
- def createExecutorInfo(execId: String): ExecutorInfo = {
+ def createExecutorInfo(execId: String): MesosExecutorInfo = {
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
@@ -118,14 +124,15 @@ private[spark] class MesosSchedulerBackend(
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
+ val executorBackendName = classOf[MesosExecutorBackend].getName
if (uri == null) {
- val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath
- command.setValue("%s %s".format(prefixEnv, executorPath))
+ val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath
+ command.setValue(s"$prefixEnv $executorPath $executorBackendName")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv))
+ command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val cpus = Resource.newBuilder()
@@ -141,7 +148,7 @@ private[spark] class MesosSchedulerBackend(
Value.Scalar.newBuilder()
.setValue(MemoryUtils.calculateTotalMemory(sc)).build())
.build()
- ExecutorInfo.newBuilder()
+ MesosExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
@@ -237,6 +244,7 @@ private[spark] class MesosSchedulerBackend(
}
val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
+ val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
@@ -260,6 +268,10 @@ private[spark] class MesosSchedulerBackend(
val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
mesosTasks.foreach { case (slaveId, tasks) =>
+ slaveIdToWorkerOffer.get(slaveId).foreach(o =>
+ listenerBus.post(SparkListenerExecutorAdded(slaveId,
+ new ExecutorInfo(o.host, o.cores)))
+ )
d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
@@ -315,7 +327,7 @@ private[spark] class MesosSchedulerBackend(
synchronized {
if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- slaveIdsWithExecutors -= taskIdToSlaveId(tid)
+ removeExecutor(taskIdToSlaveId(tid))
}
if (isFinished(status.getState)) {
taskIdToSlaveId.remove(tid)
@@ -344,12 +356,20 @@ private[spark] class MesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+ /**
+ * Remove executor associated with slaveId in a thread safe manner.
+ */
+ private def removeExecutor(slaveId: String) = {
+ synchronized {
+ listenerBus.post(SparkListenerExecutorRemoved(slaveId))
+ slaveIdsWithExecutors -= slaveId
+ }
+ }
+
private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
inClassLoader() {
logInfo("Mesos slave lost: " + slaveId.getValue)
- synchronized {
- slaveIdsWithExecutors -= slaveId.getValue
- }
+ removeExecutor(slaveId.getValue)
scheduler.executorLost(slaveId.getValue, reason)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
index 4416ce92ade25..5e7e6567a3e06 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
@@ -21,24 +21,29 @@ import java.nio.ByteBuffer
import org.apache.mesos.protobuf.ByteString
+import org.apache.spark.Logging
+
/**
* Wrapper for serializing the data sent when launching Mesos tasks.
*/
private[spark] case class MesosTaskLaunchData(
serializedTask: ByteBuffer,
- attemptNumber: Int) {
+ attemptNumber: Int) extends Logging {
def toByteString: ByteString = {
val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
dataBuffer.putInt(attemptNumber)
dataBuffer.put(serializedTask)
+ dataBuffer.rewind
+ logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]")
ByteString.copyFrom(dataBuffer)
}
}
-private[spark] object MesosTaskLaunchData {
+private[spark] object MesosTaskLaunchData extends Logging {
def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
val byteBuffer = byteString.asReadOnlyByteBuffer()
+ logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]")
val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
val serializedTask = byteBuffer.slice() // subsequence starting at the current position
MesosTaskLaunchData(serializedTask, attemptNumber)
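
The functional fix in this file is the dataBuffer.rewind before ByteString.copyFrom: after the puts, the buffer's position equals its limit, so without the rewind no bytes would be copied. A small standalone illustration:

import java.nio.ByteBuffer

object RewindDemo extends App {
  val payload = ByteBuffer.wrap("task-bytes".getBytes("UTF-8"))
  val buf = ByteBuffer.allocate(4 + payload.limit())
  buf.putInt(7)             // stands in for attemptNumber
  buf.put(payload)
  println(buf.remaining())  // 0 -- copying from the current position would copy nothing
  buf.rewind()
  println(buf.remaining())  // 14 -- the whole frame is visible again, as copyFrom expects
}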
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 662a7b91248aa..fa8a337ad63a8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -92,7 +92,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
+ new JavaDeserializationStream(s, defaultClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de72148ccc7ac..41bafabde05b9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -59,8 +59,8 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
sorter.iterator
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d7b184f8a10e9..8bc5a1cd18b64 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -34,10 +34,9 @@ import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
-import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -54,7 +53,7 @@ private[spark] class BlockResult(
readMethod: DataReadMethod.Value,
bytes: Long) {
val inputMetrics = new InputMetrics(readMethod)
- inputMetrics.bytesRead = bytes
+ inputMetrics.addBytesRead(bytes)
}
/**
@@ -120,7 +119,7 @@ private[spark] class BlockManager(
private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
- // standard BlockTranserService to directly connect to other Executors.
+ // standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 9c469370ffe1f..3198d766fca37 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -160,14 +160,14 @@ private[spark] class DiskBlockObjectWriter(
}
finalPosition = file.length()
// In certain compression codecs, more bytes are written after close() is called
- writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition)
+ writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
if (initialized) {
objOut.flush()
@@ -212,14 +212,14 @@ private[spark] class DiskBlockObjectWriter(
*/
private def updateBytesWritten() {
val pos = channel.position()
- writeMetrics.shuffleBytesWritten += (pos - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
- writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
}
// For testing
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 2499c11a65b0e..ab9ee4f0096bf 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -156,8 +156,8 @@ final class ShuffleBlockFetcherIterator(
// This needs to be released after use.
buf.retain()
results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
- shuffleMetrics.remoteBytesRead += buf.size
- shuffleMetrics.remoteBlocksFetched += 1
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
@@ -233,7 +233,7 @@ final class ShuffleBlockFetcherIterator(
val blockId = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
- shuffleMetrics.localBlocksFetched += 1
+ shuffleMetrics.incLocalBlocksFetched(1)
buf.retain()
results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
@@ -277,7 +277,7 @@ final class ShuffleBlockFetcherIterator(
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, size, _) => bytesInFlight -= size
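
The fetch-wait change times how long `results.take()` blocks and feeds the difference into the metrics. A small sketch of that timing pattern, with `recordFetchWait` standing in for `shuffleMetrics.incFetchWaitTime`:

```scala
import java.util.concurrent.LinkedBlockingQueue

// Time a blocking dequeue and report the wall-clock wait to a metrics callback.
def takeAndRecordWait[T](queue: LinkedBlockingQueue[T])(recordFetchWait: Long => Unit): T = {
  val startFetchWait = System.currentTimeMillis()
  val result = queue.take()  // blocks until a fetch result arrives
  recordFetchWait(System.currentTimeMillis() - startFetchWait)
  result
}
```
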
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 6f446c5a95a0a..4307029d44fbb 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,8 +24,10 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
- val TASK_DESERIALIZATION_TIME =
- """Time spent deserializating the task closure on the executor."""
+ val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor."
+
+ val SHUFFLE_READ_BLOCKED_TIME =
+ "Time that the task spent blocked waiting for shuffle data to be read from remote machines."
val INPUT = "Bytes read from Hadoop or from Spark storage."
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index b4677447c8872..fc1844600f1cb 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -22,20 +22,23 @@ import scala.util.Random
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.SchedulingMode
+// scalastyle:off
/**
* Continuously generates jobs that expose various features of the WebUI (internal testing tool).
*
- * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
+ * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)]
*/
+// scalastyle:on
private[spark] object UIWorkloadGenerator {
val NUM_PARTITIONS = 100
val INTER_JOB_WAIT_MS = 5000
def main(args: Array[String]) {
- if (args.length < 2) {
+ if (args.length < 3) {
println(
- "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
+ "[master] [FIFO|FAIR] [#job set (4 jobs per set)]")
System.exit(1)
}
@@ -45,6 +48,7 @@ private[spark] object UIWorkloadGenerator {
if (schedulingMode == SchedulingMode.FAIR) {
conf.set("spark.scheduler.mode", "FAIR")
}
+ val nJobSet = args(2).toInt
val sc = new SparkContext(conf)
def setProperties(s: String) = {
@@ -84,7 +88,7 @@ private[spark] object UIWorkloadGenerator {
("Job with delays", baseData.map(x => Thread.sleep(100)).count)
)
- while (true) {
+ (1 to nJobSet).foreach { _ =>
for ((desc, job) <- jobs) {
new Thread {
override def run() {
@@ -101,5 +105,6 @@ private[spark] object UIWorkloadGenerator {
Thread.sleep(INTER_JOB_WAIT_MS)
}
}
+ sc.stop()
}
}
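
The generator now takes a third argument, the number of job sets, and replaces the endless `while (true)` with a bounded loop followed by `sc.stop()`. A compact sketch of the argument handling and bounded loop, with `runJobSet` as a placeholder for the body that launches the four jobs:

```scala
// Sketch: parse the job-set count and run the workload a fixed number of times.
def runWorkload(args: Array[String])(runJobSet: () => Unit): Unit = {
  require(args.length >= 3,
    "usage: UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)]")
  val nJobSets = args(2).toInt
  (1 to nJobSets).foreach(_ => runJobSet())
  // The caller stops the SparkContext afterwards, as the diff does with sc.stop().
}
```
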
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 1d1c701878447..045c69da06feb 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -21,7 +21,6 @@ import scala.xml.{Node, NodeSeq}
import javax.servlet.http.HttpServletRequest
-import org.apache.spark.JobExecutionStatus
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData.JobUIData
@@ -51,13 +50,13 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("")
val duration: Option[Long] = {
- job.startTime.map { start =>
- val end = job.endTime.getOrElse(System.currentTimeMillis())
+ job.submissionTime.map { start =>
+ val end = job.completionTime.getOrElse(System.currentTimeMillis())
end - start
}
}
val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
- val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown")
+ val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown")
val detailUrl =
"%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
++
failedStagesTable.toNodeSeq
-
+ }
UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent)
}
}
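
With `startTime`/`endTime` renamed to `submissionTime`/`completionTime`, the duration shown on the jobs page is only defined for jobs that have actually been submitted, and a still-running job uses the current time as its end. A small sketch of that calculation:

```scala
// Duration is None until the job has a submission time; running jobs fall back to "now".
def jobDuration(submissionTime: Option[Long], completionTime: Option[Long]): Option[Long] =
  submissionTime.map { start =>
    val end = completionTime.getOrElse(System.currentTimeMillis())
    end - start
  }
```
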
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 72935beb3a34a..4d200eeda86b9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -56,6 +56,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val jobIdToData = new HashMap[JobId, JobUIData]
// Stages:
+ val pendingStages = new HashMap[StageId, StageInfo]
val activeStages = new HashMap[StageId, StageInfo]
val completedStages = ListBuffer[StageInfo]()
val skippedStages = ListBuffer[StageInfo]()
@@ -153,14 +154,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val jobData: JobUIData =
new JobUIData(
jobId = jobStart.jobId,
- startTime = Some(System.currentTimeMillis),
- endTime = None,
+ submissionTime = Option(jobStart.time).filter(_ >= 0),
stageIds = jobStart.stageIds,
jobGroup = jobGroup,
status = JobExecutionStatus.RUNNING)
+ jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x)
// Compute (a potential underestimate of) the number of tasks that will be run by this job.
// This may be an underestimate because the job start event references all of the result
- // stages's transitive stage dependencies, but some of these stages might be skipped if their
+ // stages' transitive stage dependencies, but some of these stages might be skipped if their
// output is available from earlier runs.
// See https://github.com/apache/spark/pull/3009 for a more extensive discussion.
jobData.numTasks = {
@@ -186,7 +187,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
logWarning(s"Job completed for unknown job ${jobEnd.jobId}")
new JobUIData(jobId = jobEnd.jobId)
}
- jobData.endTime = Some(System.currentTimeMillis())
+ jobData.completionTime = Option(jobEnd.time).filter(_ >= 0)
+
+ jobData.stageIds.foreach(pendingStages.remove)
jobEnd.jobResult match {
case JobSucceeded =>
completedJobs += jobData
@@ -257,7 +260,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
val stage = stageSubmitted.stageInfo
activeStages(stage.stageId) = stage
-
+ pendingStages.remove(stage.stageId)
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
@@ -309,7 +312,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val info = taskEnd.taskInfo
// If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task
- // compeletion event is for. Let's just drop it here. This means we might have some speculation
+ // completion event is for. Let's just drop it here. This means we might have some speculation
// tasks on the web ui that's never marked as complete.
if (info != null && taskEnd.stageAttemptId != -1) {
val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 09a936c2234c0..d8be1b20b3acd 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
Task Deserialization Time
+        {if (hasShuffleRead) {
+          <li>
+            <span data-toggle="tooltip"
+                  title={ToolTips.SHUFFLE_READ_BLOCKED_TIME} data-placement="right">
+              <input type="checkbox" name={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}/>
+              <span class="additional-metric-title">Shuffle Read Blocked Time</span>
+            </span>
+          </li>
+        }}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 2d13bb6ddde42..37cf2c207ba40 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs
private[spark] object TaskDetailsClassNames {
val SCHEDULER_DELAY = "scheduler_delay"
val TASK_DESERIALIZATION_TIME = "deserialization_time"
+ val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 48fd7caa1a1ed..01f7e23212c3d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -40,15 +40,15 @@ private[jobs] object UIData {
class JobUIData(
var jobId: Int = -1,
- var startTime: Option[Long] = None,
- var endTime: Option[Long] = None,
+ var submissionTime: Option[Long] = None,
+ var completionTime: Option[Long] = None,
var stageIds: Seq[Int] = Seq.empty,
var jobGroup: Option[String] = None,
var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN,
/* Tasks */
// `numTasks` is a potential underestimate of the true number of tasks that this job will run.
// This may be an underestimate because the job start event references all of the result
- // stages's transitive stage dependencies, but some of these stages might be skipped if their
+ // stages' transitive stage dependencies, but some of these stages might be skipped if their
// output is available from earlier runs.
// See https://github.com/apache/spark/pull/3009 for a more extensive discussion.
var numTasks: Int = 0,
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
new file mode 100644
index 0000000000000..b0ed908b84424
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.Logging
+
+/**
+ * An event loop to receive events from the caller and process all events in the event thread. It
+ * will start an exclusive event thread to process all events.
+ *
+ * Note: The event queue is unbounded, so subclasses should make sure `onReceive` handles
+ * events quickly enough to avoid a potential OOM.
+ */
+private[spark] abstract class EventLoop[E](name: String) extends Logging {
+
+ private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()
+
+ private val stopped = new AtomicBoolean(false)
+
+ private val eventThread = new Thread(name) {
+ setDaemon(true)
+
+ override def run(): Unit = {
+ try {
+ while (!stopped.get) {
+ val event = eventQueue.take()
+ try {
+ onReceive(event)
+ } catch {
+ case NonFatal(e) => {
+ try {
+ onError(e)
+ } catch {
+ case NonFatal(e) => logError("Unexpected error in " + name, e)
+ }
+ }
+ }
+ }
+ } catch {
+ case ie: InterruptedException => // exit even if eventQueue is not empty
+ case NonFatal(e) => logError("Unexpected error in " + name, e)
+ }
+ }
+
+ }
+
+ def start(): Unit = {
+ if (stopped.get) {
+ throw new IllegalStateException(name + " has already been stopped")
+ }
+ // Call onStart before starting the event thread to make sure it happens before onReceive
+ onStart()
+ eventThread.start()
+ }
+
+ def stop(): Unit = {
+ if (stopped.compareAndSet(false, true)) {
+ eventThread.interrupt()
+ eventThread.join()
+ // Call onStop after the event thread exits to make sure onReceive happens before onStop
+ onStop()
+ } else {
+ // Keep quiet to allow calling `stop` multiple times.
+ }
+ }
+
+ /**
+ * Put the event into the event queue. The event thread will process it later.
+ */
+ def post(event: E): Unit = {
+ eventQueue.put(event)
+ }
+
+ /**
+   * Return whether the event thread has been started and not yet stopped.
+ */
+ def isActive: Boolean = eventThread.isAlive
+
+ /**
+ * Invoked when `start()` is called but before the event thread starts.
+ */
+ protected def onStart(): Unit = {}
+
+ /**
+ * Invoked when `stop()` is called and the event thread exits.
+ */
+ protected def onStop(): Unit = {}
+
+ /**
+ * Invoked in the event thread when polling events from the event queue.
+ *
+   * Note: Avoid blocking calls in `onReceive`; otherwise the event thread is blocked and cannot
+   * process events in time. If blocking work is needed, run it in another thread.
+ */
+ protected def onReceive(event: E): Unit
+
+ /**
+   * Invoked if `onReceive` throws any non-fatal error. Any non-fatal error thrown from `onError`
+   * will be ignored.
+ */
+ protected def onError(e: Throwable): Unit
+
+}
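
To show how the new `EventLoop` is meant to be used, here is a minimal subclass sketch. It assumes the code lives where the `private[spark]` class is visible; `LogEvent`, `Message`, and the println bodies are illustrative only.

```scala
// A toy subclass of the EventLoop added above. onReceive runs on the dedicated event
// thread, so its body should stay cheap; onError handles non-fatal failures.
sealed trait LogEvent
case class Message(text: String) extends LogEvent

class LoggingEventLoop extends EventLoop[LogEvent]("logging-event-loop") {
  override protected def onReceive(event: LogEvent): Unit = event match {
    case Message(text) => println(s"received: $text")
  }
  override protected def onError(e: Throwable): Unit =
    println(s"error while handling event: ${e.getMessage}")
}

// Usage: post() can be called from any thread; events are handled on the event thread.
// val loop = new LoggingEventLoop
// loop.start()
// loop.post(Message("hello"))
// loop.stop()
```
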
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index d94e8252650d2..f896b5072e4fa 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -19,6 +19,8 @@ package org.apache.spark.util
import java.util.{Properties, UUID}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+
import scala.collection.JavaConverters._
import scala.collection.Map
@@ -30,6 +32,7 @@ import org.apache.spark.executor._
import org.apache.spark.scheduler._
import org.apache.spark.storage._
import org.apache.spark._
+import org.apache.hadoop.hdfs.web.JsonUtil
/**
* Serializes SparkListener events to/from JSON. This protocol provides strong backwards-
@@ -83,7 +86,10 @@ private[spark] object JsonProtocol {
applicationStartToJson(applicationStart)
case applicationEnd: SparkListenerApplicationEnd =>
applicationEndToJson(applicationEnd)
-
+ case executorAdded: SparkListenerExecutorAdded =>
+ executorAddedToJson(executorAdded)
+ case executorRemoved: SparkListenerExecutorRemoved =>
+ executorRemovedToJson(executorRemoved)
// These aren't used, but keeps compiler happy
case SparkListenerShutdown => JNothing
case SparkListenerExecutorMetricsUpdate(_, _) => JNothing
@@ -136,6 +142,7 @@ private[spark] object JsonProtocol {
val properties = propertiesToJson(jobStart.properties)
("Event" -> Utils.getFormattedClassName(jobStart)) ~
("Job ID" -> jobStart.jobId) ~
+ ("Submission Time" -> jobStart.time) ~
("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0
("Stage IDs" -> jobStart.stageIds) ~
("Properties" -> properties)
@@ -145,6 +152,7 @@ private[spark] object JsonProtocol {
val jobResult = jobResultToJson(jobEnd.jobResult)
("Event" -> Utils.getFormattedClassName(jobEnd)) ~
("Job ID" -> jobEnd.jobId) ~
+ ("Completion Time" -> jobEnd.time) ~
("Job Result" -> jobResult)
}
@@ -194,6 +202,16 @@ private[spark] object JsonProtocol {
("Timestamp" -> applicationEnd.time)
}
+ def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = {
+ ("Event" -> Utils.getFormattedClassName(executorAdded)) ~
+ ("Executor ID" -> executorAdded.executorId) ~
+ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo))
+ }
+
+ def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = {
+ ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~
+ ("Executor ID" -> executorRemoved.executorId)
+ }
/** ------------------------------------------------------------------- *
* JSON serialization methods for classes SparkListenerEvents depend on |
@@ -362,6 +380,10 @@ private[spark] object JsonProtocol {
("Disk Size" -> blockStatus.diskSize)
}
+ def executorInfoToJson(executorInfo: ExecutorInfo): JValue = {
+ ("Host" -> executorInfo.executorHost) ~
+ ("Total Cores" -> executorInfo.totalCores)
+ }
/** ------------------------------ *
* Util JSON serialization methods |
@@ -416,6 +438,8 @@ private[spark] object JsonProtocol {
val unpersistRDD = Utils.getFormattedClassName(SparkListenerUnpersistRDD)
val applicationStart = Utils.getFormattedClassName(SparkListenerApplicationStart)
val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd)
+ val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded)
+ val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
(json \ "Event").extract[String] match {
case `stageSubmitted` => stageSubmittedFromJson(json)
@@ -431,6 +455,8 @@ private[spark] object JsonProtocol {
case `unpersistRDD` => unpersistRDDFromJson(json)
case `applicationStart` => applicationStartFromJson(json)
case `applicationEnd` => applicationEndFromJson(json)
+ case `executorAdded` => executorAddedFromJson(json)
+ case `executorRemoved` => executorRemovedFromJson(json)
}
}
@@ -469,6 +495,8 @@ private[spark] object JsonProtocol {
def jobStartFromJson(json: JValue): SparkListenerJobStart = {
val jobId = (json \ "Job ID").extract[Int]
+ val submissionTime =
+ Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L)
val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int])
val properties = propertiesFromJson(json \ "Properties")
// The "Stage Infos" field was added in Spark 1.2.0
@@ -476,13 +504,15 @@ private[spark] object JsonProtocol {
.map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse {
stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown"))
}
- SparkListenerJobStart(jobId, stageInfos, properties)
+ SparkListenerJobStart(jobId, submissionTime, stageInfos, properties)
}
def jobEndFromJson(json: JValue): SparkListenerJobEnd = {
val jobId = (json \ "Job ID").extract[Int]
+ val completionTime =
+ Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L)
val jobResult = jobResultFromJson(json \ "Job Result")
- SparkListenerJobEnd(jobId, jobResult)
+ SparkListenerJobEnd(jobId, completionTime, jobResult)
}
def environmentUpdateFromJson(json: JValue): SparkListenerEnvironmentUpdate = {
@@ -523,6 +553,16 @@ private[spark] object JsonProtocol {
SparkListenerApplicationEnd((json \ "Timestamp").extract[Long])
}
+ def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = {
+ val executorId = (json \ "Executor ID").extract[String]
+ val executorInfo = executorInfoFromJson(json \ "Executor Info")
+ SparkListenerExecutorAdded(executorId, executorInfo)
+ }
+
+ def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = {
+ val executorId = (json \ "Executor ID").extract[String]
+ SparkListenerExecutorRemoved(executorId)
+ }
/** --------------------------------------------------------------------- *
* JSON deserialization methods for classes SparkListenerEvents depend on |
@@ -592,20 +632,20 @@ private[spark] object JsonProtocol {
return TaskMetrics.empty
}
val metrics = new TaskMetrics
- metrics.hostname = (json \ "Host Name").extract[String]
- metrics.executorDeserializeTime = (json \ "Executor Deserialize Time").extract[Long]
- metrics.executorRunTime = (json \ "Executor Run Time").extract[Long]
- metrics.resultSize = (json \ "Result Size").extract[Long]
- metrics.jvmGCTime = (json \ "JVM GC Time").extract[Long]
- metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long]
- metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
- metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
+ metrics.setHostname((json \ "Host Name").extract[String])
+ metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long])
+ metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long])
+ metrics.setResultSize((json \ "Result Size").extract[Long])
+ metrics.setJvmGCTime((json \ "JVM GC Time").extract[Long])
+ metrics.setResultSerializationTime((json \ "Result Serialization Time").extract[Long])
+ metrics.incMemoryBytesSpilled((json \ "Memory Bytes Spilled").extract[Long])
+ metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long])
metrics.setShuffleReadMetrics(
Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
metrics.shuffleWriteMetrics =
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
- metrics.inputMetrics =
- Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson)
+ metrics.setInputMetrics(
+ Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson))
metrics.outputMetrics =
Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson)
metrics.updatedBlocks =
@@ -621,31 +661,31 @@ private[spark] object JsonProtocol {
def shuffleReadMetricsFromJson(json: JValue): ShuffleReadMetrics = {
val metrics = new ShuffleReadMetrics
- metrics.remoteBlocksFetched = (json \ "Remote Blocks Fetched").extract[Int]
- metrics.localBlocksFetched = (json \ "Local Blocks Fetched").extract[Int]
- metrics.fetchWaitTime = (json \ "Fetch Wait Time").extract[Long]
- metrics.remoteBytesRead = (json \ "Remote Bytes Read").extract[Long]
+ metrics.incRemoteBlocksFetched((json \ "Remote Blocks Fetched").extract[Int])
+ metrics.incLocalBlocksFetched((json \ "Local Blocks Fetched").extract[Int])
+ metrics.incFetchWaitTime((json \ "Fetch Wait Time").extract[Long])
+ metrics.incRemoteBytesRead((json \ "Remote Bytes Read").extract[Long])
metrics
}
def shuffleWriteMetricsFromJson(json: JValue): ShuffleWriteMetrics = {
val metrics = new ShuffleWriteMetrics
- metrics.shuffleBytesWritten = (json \ "Shuffle Bytes Written").extract[Long]
- metrics.shuffleWriteTime = (json \ "Shuffle Write Time").extract[Long]
+ metrics.incShuffleBytesWritten((json \ "Shuffle Bytes Written").extract[Long])
+ metrics.incShuffleWriteTime((json \ "Shuffle Write Time").extract[Long])
metrics
}
def inputMetricsFromJson(json: JValue): InputMetrics = {
val metrics = new InputMetrics(
DataReadMethod.withName((json \ "Data Read Method").extract[String]))
- metrics.bytesRead = (json \ "Bytes Read").extract[Long]
+ metrics.addBytesRead((json \ "Bytes Read").extract[Long])
metrics
}
def outputMetricsFromJson(json: JValue): OutputMetrics = {
val metrics = new OutputMetrics(
DataWriteMethod.withName((json \ "Data Write Method").extract[String]))
- metrics.bytesWritten = (json \ "Bytes Written").extract[Long]
+ metrics.setBytesWritten((json \ "Bytes Written").extract[Long])
metrics
}
@@ -745,6 +785,11 @@ private[spark] object JsonProtocol {
BlockStatus(storageLevel, memorySize, diskSize, tachyonSize)
}
+ def executorInfoFromJson(json: JValue): ExecutorInfo = {
+ val executorHost = (json \ "Host").extract[String]
+ val totalCores = (json \ "Total Cores").extract[Int]
+ new ExecutorInfo(executorHost, totalCores)
+ }
/** -------------------------------- *
* Util JSON deserialization methods |
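
The JsonProtocol additions serialize an executor's host and core count and read them back. A self-contained sketch of the same round trip with json4s (which JsonProtocol builds on), using a stand-in case class rather than Spark's `ExecutorInfo`:

```scala
import org.json4s._
import org.json4s.JsonDSL._

// Stand-in for the executor info carried in the new listener events.
case class ExecutorInfoSketch(executorHost: String, totalCores: Int)

def executorInfoToJsonSketch(info: ExecutorInfoSketch): JValue =
  ("Host" -> info.executorHost) ~
  ("Total Cores" -> info.totalCores)

def executorInfoFromJsonSketch(json: JValue): ExecutorInfoSketch = {
  implicit val formats = DefaultFormats
  ExecutorInfoSketch(
    (json \ "Host").extract[String],
    (json \ "Total Cores").extract[Int])
}
```
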
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 15bda1c9cc29c..6ba03841f746b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -757,12 +757,12 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
if (curWriteMetrics != null) {
- m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten
- m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime
+ m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten)
+ m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime)
}
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 07b1e44d04be6..004de05c10ee1 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -606,6 +606,27 @@ public void take() {
rdd.takeSample(false, 2, 42);
}
+ @Test
+ public void isEmpty() {
+ Assert.assertTrue(sc.emptyRDD().isEmpty());
+ Assert.assertTrue(sc.parallelize(new ArrayList()).isEmpty());
+ Assert.assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty());
+ Assert.assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter(
+ new Function() {
+ @Override
+ public Boolean call(Integer i) {
+ return i < 0;
+ }
+ }).isEmpty());
+ Assert.assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter(
+ new Function() {
+ @Override
+ public Boolean call(Integer i) {
+ return i > 1;
+ }
+ }).isEmpty());
+ }
+
@Test
public void cartesian() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
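
The new Java tests exercise `isEmpty()` on empty, non-empty, and fully-filtered RDDs. Conceptually the check can be expressed through existing operations; the following Scala sketch is illustrative and not necessarily Spark's actual implementation:

```scala
import org.apache.spark.rdd.RDD

// An RDD is empty when taking a single element yields nothing.
// (Spark can short-circuit further, e.g. when there are no partitions at all.)
def isEmptySketch[T](rdd: RDD[T]): Boolean = rdd.take(1).length == 0
```
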
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 7584ae79fc920..21487bc24d58a 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
assert(jobB.get() === 100)
}
- ignore("two jobs sharing the same stage") {
+ test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
- // sem2: make sure the first stage is not finished until cancel is issued
+ // twoJobsSharingStageSemaphore:
+ // make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
- val sem2 = new Semaphore(0)
sc = new SparkContext("local[2]", "test")
sc.addSparkListener(new SparkListener {
@@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
     // Create two actions that would share the same stages.
val rdd = sc.parallelize(1 to 10, 2).map { i =>
- sem2.acquire()
+ JobCancellationSuite.twoJobsSharingStageSemaphore.acquire()
(i, i)
}.reduceByKey(_+_)
val f1 = rdd.collectAsync()
@@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
future {
sem1.acquire()
f1.cancel()
- sem2.release(10)
+ JobCancellationSuite.twoJobsSharingStageSemaphore.release(10)
}
- // Expect both to fail now.
- // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ // Expect f1 to fail due to cancellation,
intercept[SparkException] { f1.get() }
- intercept[SparkException] { f2.get() }
+ // but f2 should not be affected
+ f2.get()
}
def testCount() {
@@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
object JobCancellationSuite {
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
+ val twoJobsSharingStageSemaphore = new Semaphore(0)
}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index b0a70f012f1f3..af3272692d7a1 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}
+ test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
+ sc = new SparkContext("local", "test")
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ sc.broadcast(Seq(1, 2, 3))
+ }
+ assert(thrown.getMessage.toLowerCase.contains("stopped"))
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {
def runCallSiteTest(sc: SparkContext) {
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val broadcast = sc.broadcast(rdd)
+ val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 8379883e065e7..3fbc1a21d10ed 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers
list.size should be (1)
}
+ test("history file is renamed from inprogress to completed") {
+ val conf = new SparkConf()
+ .set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
+ .set("spark.testing", "true")
+ val provider = new FsHistoryProvider(conf)
+
+ val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS)
+ writeFile(logFile1, true, None,
+ SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"),
+ SparkListenerApplicationEnd(2L)
+ )
+ provider.checkForLogs()
+ val appListBeforeRename = provider.getListing()
+ appListBeforeRename.size should be (1)
+ appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS)
+
+ logFile1.renameTo(new File(testDir, "app1"))
+ provider.checkForLogs()
+ val appListAfterRename = provider.getListing()
+ appListAfterRename.size should be (1)
+ appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS)
+ }
+
private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec],
events: SparkListenerEvent*) = {
val out =
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
index 1a28a9a187cd7..372d7aa453008 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
@@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
@@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index f8bcde12a371a..81db66ae17464 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -17,73 +17,201 @@
package org.apache.spark.metrics
-import java.io.{FileWriter, PrintWriter, File}
+import java.io.{File, FileWriter, PrintWriter}
-import org.apache.spark.SharedSparkContext
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener}
+import scala.collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
-import org.scalatest.Matchers
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf,
+ LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter,
+ TextInputFormat => OldTextInputFormat}
+import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat,
+ CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader}
+import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader,
+ TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat,
+ CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit,
+ FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat}
-import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.SharedSparkContext
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.util.Utils
+
+class InputOutputMetricsSuite extends FunSuite with SharedSparkContext {
+
+ @transient var tmpDir: File = _
+ @transient var tmpFile: File = _
+ @transient var tmpFilePath: String = _
+
+ override def beforeAll() {
+ super.beforeAll()
+
+ tmpDir = Utils.createTempDir()
+ val testTempDir = new File(tmpDir, "test")
+ testTempDir.mkdir()
-class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with Matchers {
- test("input metrics when reading text file with single split") {
- val file = new File(getClass.getSimpleName + ".txt")
- val pw = new PrintWriter(new FileWriter(file))
- pw.println("some stuff")
- pw.println("some other stuff")
- pw.println("yet more stuff")
- pw.println("too much stuff")
+ tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt")
+ val pw = new PrintWriter(new FileWriter(tmpFile))
+ for (x <- 1 to 1000000) {
+ pw.println("s")
+ }
pw.close()
- file.deleteOnExit()
- val taskBytesRead = new ArrayBuffer[Long]()
- sc.addSparkListener(new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
- }
- })
- sc.textFile("file://" + file.getAbsolutePath, 2).count()
+ // Path to tmpFile
+ tmpFilePath = "file://" + tmpFile.getAbsolutePath
+ }
- // Wait for task end events to come in
- sc.listenerBus.waitUntilEmpty(500)
- assert(taskBytesRead.length == 2)
- assert(taskBytesRead.sum >= file.length())
+ override def afterAll() {
+ super.afterAll()
+ Utils.deleteRecursively(tmpDir)
}
- test("input metrics when reading text file with multiple splits") {
- val file = new File(getClass.getSimpleName + ".txt")
- val pw = new PrintWriter(new FileWriter(file))
- for (i <- 0 until 10000) {
- pw.println("some stuff")
+ test("input metrics for old hadoop with coalesce") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.textFile(tmpFilePath, 4).count()
+ }
+ val bytesRead2 = runAndReturnBytesRead {
+ sc.textFile(tmpFilePath, 4).coalesce(2).count()
+ }
+ assert(bytesRead != 0)
+ assert(bytesRead == bytesRead2)
+ assert(bytesRead2 >= tmpFile.length())
+ }
+
+ test("input metrics with cache and coalesce") {
+ // prime the cache manager
+ val rdd = sc.textFile(tmpFilePath, 4).cache()
+ rdd.collect()
+
+ val bytesRead = runAndReturnBytesRead {
+ rdd.count()
+ }
+ val bytesRead2 = runAndReturnBytesRead {
+ rdd.coalesce(4).count()
+ }
+
+    // For count and coalesce, the same bytes should be read.
+ assert(bytesRead != 0)
+ assert(bytesRead2 == bytesRead)
+ }
+
+ /**
+ * This checks the situation where we have interleaved reads from
+   * different sources. Currently, we only accumulate from the first
+ * read method we find in the task. This test uses cartesian to create
+ * the interleaved reads.
+ *
+ * Once https://issues.apache.org/jira/browse/SPARK-5225 is fixed
+ * this test should break.
+ */
+ test("input metrics with mixed read method") {
+ // prime the cache manager
+ val numPartitions = 2
+ val rdd = sc.parallelize(1 to 100, numPartitions).cache()
+ rdd.collect()
+
+ val rdd2 = sc.textFile(tmpFilePath, numPartitions)
+
+ val bytesRead = runAndReturnBytesRead {
+ rdd.count()
+ }
+ val bytesRead2 = runAndReturnBytesRead {
+ rdd2.count()
+ }
+
+ val cartRead = runAndReturnBytesRead {
+ rdd.cartesian(rdd2).count()
+ }
+
+ assert(cartRead != 0)
+ assert(bytesRead != 0)
+ // We read from the first rdd of the cartesian once per partition.
+ assert(cartRead == bytesRead * numPartitions)
+ }
+
+ test("input metrics for new Hadoop API with coalesce") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+ classOf[Text]).count()
+ }
+ val bytesRead2 = runAndReturnBytesRead {
+ sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+ classOf[Text]).coalesce(5).count()
+ }
+ assert(bytesRead != 0)
+ assert(bytesRead2 == bytesRead)
+ assert(bytesRead >= tmpFile.length())
+ }
+
+ test("input metrics when reading text file") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.textFile(tmpFilePath, 2).count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+
+ test("input metrics with interleaved reads") {
+ val numPartitions = 2
+ val cartVector = 0 to 9
+ val cartFile = new File(tmpDir, getClass.getSimpleName + "_cart.txt")
+ val cartFilePath = "file://" + cartFile.getAbsolutePath
+
+ // write files to disk so we can read them later.
+ sc.parallelize(cartVector).saveAsTextFile(cartFilePath)
+ val aRdd = sc.textFile(cartFilePath, numPartitions)
+
+ val tmpRdd = sc.textFile(tmpFilePath, numPartitions)
+
+    val firstSize = runAndReturnBytesRead {
+ aRdd.count()
+ }
+ val secondSize = runAndReturnBytesRead {
+ tmpRdd.count()
}
- pw.close()
- file.deleteOnExit()
+ val cartesianBytes = runAndReturnBytesRead {
+ aRdd.cartesian(tmpRdd).count()
+ }
+
+ // Computing the amount of bytes read for a cartesian operation is a little involved.
+    // Cartesian interleaves reads between two partitions, e.g. p1 and p2.
+ // Here are the steps:
+ // 1) First it creates an iterator for p1
+ // 2) Creates an iterator for p2
+ // 3) Reads the first element of p1 and then all the elements of p2
+ // 4) proceeds to the next element of p1
+ // 5) Creates a new iterator for p2
+ // 6) rinse and repeat.
+ // As a result we read from the second partition n times where n is the number of keys in
+ // p1. Thus the math below for the test.
+ assert(cartesianBytes != 0)
+ assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize))
+ }
+
+ private def runAndReturnBytesRead(job : => Unit): Long = {
val taskBytesRead = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
}
})
- sc.textFile("file://" + file.getAbsolutePath, 2).count()
- // Wait for task end events to come in
+ job
+
sc.listenerBus.waitUntilEmpty(500)
- assert(taskBytesRead.length == 2)
- assert(taskBytesRead.sum >= file.length())
+ taskBytesRead.sum
}
test("output metrics when writing text file") {
val fs = FileSystem.getLocal(new Configuration())
val outPath = new Path(fs.getWorkingDirectory, "outdir")
- if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) {
+ if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) {
val taskBytesWritten = new ArrayBuffer[Long]()
sc.addSparkListener(new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
@@ -106,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with Matc
}
}
}
+
+ test("input metrics with old CombineFileInputFormat") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable],
+ classOf[Text], 2).count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+
+ test("input metrics with new CombineFileInputFormat") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable],
+ classOf[Text], new Configuration()).count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+}
+
+/**
+ * Hadoop 2 has a version of this, but we can't use it for backwards compatibility
+ */
+class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] {
+ override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter)
+ : OldRecordReader[LongWritable, Text] = {
+ new OldCombineFileRecordReader[LongWritable, Text](conf,
+ split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper]
+ .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]])
+ }
+}
+
+class OldCombineTextRecordReaderWrapper(
+ split: OldCombineFileSplit,
+ conf: Configuration,
+ reporter: Reporter,
+ idx: Integer) extends OldRecordReader[LongWritable, Text] {
+
+ val fileSplit = new OldFileSplit(split.getPath(idx),
+ split.getOffset(idx),
+ split.getLength(idx),
+ split.getLocations())
+
+ val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit,
+ conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader]
+
+ override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value)
+ override def createKey(): LongWritable = delegate.createKey()
+ override def createValue(): Text = delegate.createValue()
+ override def getPos(): Long = delegate.getPos
+ override def close(): Unit = delegate.close()
+ override def getProgress(): Float = delegate.getProgress
+}
+
+/**
+ * Hadoop 2 has a version of this, but we can't use it for backwards compatibility
+ */
+class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable, Text] {
+ def createRecordReader(split: NewInputSplit, context: TaskAttemptContext)
+ : NewRecordReader[LongWritable, Text] = {
+    new NewCombineFileRecordReader[LongWritable, Text](split.asInstanceOf[NewCombineFileSplit],
+ context, classOf[NewCombineTextRecordReaderWrapper])
+ }
}
+
+class NewCombineTextRecordReaderWrapper(
+ split: NewCombineFileSplit,
+ context: TaskAttemptContext,
+ idx: Integer) extends NewRecordReader[LongWritable, Text] {
+
+ val fileSplit = new NewFileSplit(split.getPath(idx),
+ split.getOffset(idx),
+ split.getLength(idx),
+ split.getLocations())
+
+ val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context)
+
+ override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = {
+ delegate.initialize(fileSplit, context)
+ }
+
+ override def nextKeyValue(): Boolean = delegate.nextKeyValue()
+ override def getCurrentKey(): LongWritable = delegate.getCurrentKey
+ override def getCurrentValue(): Text = delegate.getCurrentValue
+ override def getProgress(): Float = delegate.getProgress
+ override def close(): Unit = delegate.close()
+}
\ No newline at end of file
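
The rewritten suite centralizes measurement in `runAndReturnBytesRead`: attach a listener, run the job, drain the listener bus, and sum the per-task metric. Below is a slightly more general sketch of the same helper, parameterized over which metric to sum; it assumes it is compiled somewhere `sc.listenerBus` (which is `private[spark]`) is visible, as the test suite is.

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Run a job, collect one Long metric per finished task, and return the sum.
def runAndSumTaskMetric(sc: SparkContext)(metric: SparkListenerTaskEnd => Long)
    (job: => Unit): Long = {
  val values = new ArrayBuffer[Long]()
  sc.addSparkListener(new SparkListener() {
    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
      values += metric(taskEnd)
    }
  })
  job
  sc.listenerBus.waitUntilEmpty(500)  // wait for task-end events to be delivered
  values.sum
}

// e.g. summing bytes read for a count:
// runAndSumTaskMetric(sc)(_.taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L)) {
//   sc.textFile(path).count()
// }
```
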
diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
index 1b112f1a41ca9..cd193ae4f5238 100644
--- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -76,6 +76,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 66).mkString(","))
assert(slices(2).mkString(",") === (67 to 100).mkString(","))
+ assert(slices(2).isInstanceOf[Range.Inclusive])
}
test("empty data") {
@@ -227,4 +228,28 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
}
+
+ test("inclusive ranges with Int.MaxValue and Int.MinValue") {
+ val data1 = 1 to Int.MaxValue
+ val slices1 = ParallelCollectionRDD.slice(data1, 3)
+ assert(slices1.size === 3)
+ assert(slices1.map(_.size).sum === Int.MaxValue)
+ assert(slices1(2).isInstanceOf[Range.Inclusive])
+ val data2 = -2 to Int.MinValue by -1
+ val slices2 = ParallelCollectionRDD.slice(data2, 3)
+ assert(slices2.size == 3)
+ assert(slices2.map(_.size).sum === Int.MaxValue)
+ assert(slices2(2).isInstanceOf[Range.Inclusive])
+ }
+
+ test("empty ranges with Int.MaxValue and Int.MinValue") {
+ val data1 = Int.MaxValue until Int.MaxValue
+ val slices1 = ParallelCollectionRDD.slice(data1, 5)
+ assert(slices1.size === 5)
+ for (i <- 0 until 5) assert(slices1(i).size === 0)
+ val data2 = Int.MaxValue until Int.MaxValue
+ val slices2 = ParallelCollectionRDD.slice(data2, 5)
+ assert(slices2.size === 5)
+ for (i <- 0 until 5) assert(slices2(i).size === 0)
+ }
}
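
The new tests rely on slicing an inclusive range positionally so that the last slice keeps its endpoint (and stays a `Range.Inclusive`) even at `Int.MaxValue`. A sketch of the boundary arithmetic behind such slicing, not `ParallelCollectionRDD.slice` itself; note the use of `Long` so `i * length` cannot overflow:

```scala
// Compute [start, end) index pairs for splitting `length` elements into `numSlices` pieces.
def slicePositions(length: Long, numSlices: Int): Seq[(Long, Long)] =
  (0 until numSlices).map { i =>
    val start = (i * length) / numSlices
    val end = ((i + 1) * length) / numSlices
    (start, end)
  }
```
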
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 0deb9b18b8688..e33b4bbbb8e4c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -52,6 +52,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
+ assert(!nums.isEmpty())
assert(nums.max() === 4)
assert(nums.min() === 1)
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
@@ -545,6 +546,14 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}
+ test("isEmpty") {
+ assert(sc.emptyRDD.isEmpty())
+ assert(sc.parallelize(Seq[Int]()).isEmpty())
+ assert(!sc.parallelize(Seq(1)).isEmpty())
+ assert(sc.parallelize(Seq(1,2,3), 3).filter(_ < 0).isEmpty())
+ assert(!sc.parallelize(Seq(1,2,3), 3).filter(_ > 1).isEmpty())
+ }
+
test("sample preserves partitioner") {
val partitioner = new HashPartitioner(2)
val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner)
@@ -918,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}
+
+ test("nested RDDs are not supported (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
+ nestedRDD.count()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("actions cannot be performed inside of transformations (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ rdd.map(x => x * rdd2.count).collect()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
+ val existingRDD = sc.parallelize(1 to 100)
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ existingRDD.count()
+ }
+ assert(thrown.getMessage.contains("shutdown"))
+ }
+
+ test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
+ sc.stop()
+ def assertFails(block: => Any): Unit = {
+ val thrown = intercept[IllegalStateException] {
+ block
+ }
+ assert(thrown.getMessage.contains("stopped"))
+ }
+ assertFails { sc.parallelize(1 to 100) }
+ assertFails { sc.textFile("/nonexistent-path") }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d30eb10bbe947..eb116213f69fc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
import scala.language.reflectiveCalls
+import scala.util.control.NonFatal
-import akka.actor._
-import akka.testkit.{ImplicitSender, TestKit, TestActorRef}
import org.scalatest.{BeforeAndAfter, FunSuiteLike}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -33,10 +32,16 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.CallSite
import org.apache.spark.executor.TaskMetrics
-class BuggyDAGEventProcessActor extends Actor {
- val state = 0
- def receive = {
- case _ => throw new SparkException("error")
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+ extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+ override def post(event: DAGSchedulerEvent): Unit = {
+ try {
+ // Forward event to `onReceive` directly to avoid processing event asynchronously.
+ onReceive(event)
+ } catch {
+ case NonFatal(e) => onError(e)
+ }
}
}
@@ -65,8 +70,7 @@ class MyRDD(
class DAGSchedulerSuiteDummyException extends Exception
-class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with FunSuiteLike
- with ImplicitSender with BeforeAndAfter with LocalSparkContext with Timeouts {
+class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSparkContext with Timeouts {
val conf = new SparkConf
/** Set of TaskSets the DAGScheduler has requested executed. */
@@ -113,7 +117,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
- var dagEventProcessTestActor: TestActorRef[DAGSchedulerEventProcessActor] = null
+ var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
/**
* Set of cache locations to return from our mock BlockManagerMaster.
@@ -167,13 +171,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runLocallyWithinThread(job)
}
}
- dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
- Props(classOf[DAGSchedulerEventProcessActor], scheduler))(system)
+ dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
}
override def afterAll() {
super.afterAll()
- TestKit.shutdownActorSystem(system)
}
/**
@@ -190,7 +192,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
* DAGScheduler event loop.
*/
private def runEvent(event: DAGSchedulerEvent) {
- dagEventProcessTestActor.receive(event)
+ dagEventProcessLoopTester.post(event)
}
/**
@@ -397,8 +399,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
runLocallyWithinThread(job)
}
}
- dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
- Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system)
+ dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler)
val jobId = submit(new MyRDD(sc, 1, Nil), Array(0))
cancel(jobId)
// Because the job wasn't actually cancelled, we shouldn't have received a failure message.
@@ -726,18 +727,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assert(sc.parallelize(1 to 10, 2).first() === 1)
}
- test("DAGSchedulerActorSupervisor closes the SparkContext when EventProcessActor crashes") {
- val actorSystem = ActorSystem("test")
- val supervisor = actorSystem.actorOf(
- Props(classOf[DAGSchedulerActorSupervisor], scheduler), "dagSupervisor")
- supervisor ! Props[BuggyDAGEventProcessActor]
- val child = expectMsgType[ActorRef]
- watch(child)
- child ! "hi"
- expectMsgPF(){ case Terminated(child) => () }
- assert(scheduler.sc.dagScheduler === null)
- }
-
test("accumulator not calculated for resubmitted result stage") {
     // just to register the accumulator
val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
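
The test now drives the scheduler through a synchronous stand-in for the new event loop: `post` forwards straight to `onReceive` on the caller's thread, so each event is fully processed before the test continues. The same idea expressed against the `EventLoop` base class (assuming it is visible, since it is `private[spark]`; `SyncEventLoop` is an illustrative name):

```scala
import scala.util.control.NonFatal

// Test double: handle events inline instead of enqueueing them for the event thread.
abstract class SyncEventLoop[E](name: String) extends EventLoop[E](name) {
  override def post(event: E): Unit = {
    try {
      onReceive(event)
    } catch {
      case NonFatal(e) => onError(e)
    }
  }
}
```
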
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 1de7e130039a5..437d8693c0b1f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -160,7 +160,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin
*/
private def testApplicationEventLogging(compressionCodec: Option[String] = None) {
val conf = getLoggingConf(testDirPath, compressionCodec)
- val sc = new SparkContext("local", "test", conf)
+ val sc = new SparkContext("local-cluster[2,2,512]", "test", conf)
assert(sc.eventLogger.isDefined)
val eventLogger = sc.eventLogger.get
val expectedLogDir = testDir.toURI().toString()
@@ -184,6 +184,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin
val eventSet = mutable.Set(
SparkListenerApplicationStart,
SparkListenerBlockManagerAdded,
+ SparkListenerExecutorAdded,
SparkListenerEnvironmentUpdate,
SparkListenerJobStart,
SparkListenerJobEnd,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 24f41bf8cccda..0fb1bdd30d975 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -34,6 +34,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
+ val jobCompletionTime = 1421191296660L
+
before {
sc = new SparkContext("local", "SparkListenerSuite")
}
@@ -44,7 +46,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
bus.addListener(counter)
// Listener bus hasn't started yet, so posting events should not increment counter
- (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) }
+ (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
assert(counter.count === 0)
// Starting listener bus should flush all buffered events
@@ -54,7 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
// After listener bus has stopped, posting events should not increment counter
bus.stop()
- (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) }
+ (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
assert(counter.count === 5)
// Listener bus must not be started twice
@@ -99,7 +101,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
bus.addListener(blockingListener)
bus.start()
- bus.post(SparkListenerJobEnd(0, JobSucceeded))
+ bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
listenerStarted.acquire()
// Listener should be blocked after start
@@ -345,7 +347,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
bus.start()
// Post events to all listeners, and wait until the queue is drained
- (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) }
+ (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
// The exception should be caught, and the event should be propagated to other listeners
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
new file mode 100644
index 0000000000000..623a687c359a2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.scheduler.cluster.ExecutorInfo
+import org.apache.spark.{SparkContext, LocalSparkContext}
+
+import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll}
+
+import scala.collection.mutable
+
+/**
+ * Unit tests for SparkListener that require a local cluster.
+ */
+class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext
+ with BeforeAndAfter with BeforeAndAfterAll {
+
+ /** Length of time to wait while draining listener events. */
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ before {
+ sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite")
+ }
+
+ test("SparkListener sends executor added message") {
+ val listener = new SaveExecutorInfo
+ sc.addSparkListener(listener)
+
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(_.toString)
+ rdd2.setName("Target RDD")
+ rdd2.count()
+
+ assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ assert(listener.addedExecutorInfo.size == 2)
+ assert(listener.addedExecutorInfo("0").totalCores == 1)
+ assert(listener.addedExecutorInfo("1").totalCores == 1)
+ }
+
+ private class SaveExecutorInfo extends SparkListener {
+ val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()
+
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+ addedExecutorInfo(executor.executorId) = executor.executorInfo
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
index 48f5e40f506d9..073814c127edc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
@@ -17,23 +17,58 @@
package org.apache.spark.scheduler.mesos
+import org.apache.spark.executor.MesosExecutorBackend
import org.scalatest.FunSuite
-import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext}
-import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl}
+import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext}
+import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus,
+ TaskDescription, WorkerOffer, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend}
import org.apache.mesos.SchedulerDriver
-import org.apache.mesos.Protos._
-import org.scalatest.mock.EasyMockSugar
+import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _}
import org.apache.mesos.Protos.Value.Scalar
import org.easymock.{Capture, EasyMock}
import java.nio.ByteBuffer
import java.util.Collections
import java.util
+import org.scalatest.mock.EasyMockSugar
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar {
+ test("check spark-class location correctly") {
+ val conf = new SparkConf
conf.set("spark.mesos.executor.home", "/mesos-home")
+
+ val listenerBus = EasyMock.createMock(classOf[LiveListenerBus])
+ listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2)))
+ EasyMock.replay(listenerBus)
+
+ val sc = EasyMock.createMock(classOf[SparkContext])
+ EasyMock.expect(sc.getSparkHome()).andReturn(Option("/spark-home")).anyTimes()
+ EasyMock.expect(sc.conf).andReturn(conf).anyTimes()
+ EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes()
+ EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes()
+ EasyMock.expect(sc.listenerBus).andReturn(listenerBus)
+ EasyMock.replay(sc)
+ val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl])
+ EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes()
+ EasyMock.replay(taskScheduler)
+
+ val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master")
+
+ // uri is null.
+ val executorInfo = mesosSchedulerBackend.createExecutorInfo("test-id")
+ assert(executorInfo.getCommand.getValue === s" /mesos-home/bin/spark-class ${classOf[MesosExecutorBackend].getName}")
+
+ // uri exists.
+ conf.set("spark.executor.uri", "hdfs:///test-app-1.0.0.tgz")
+ val executorInfo1 = mesosSchedulerBackend.createExecutorInfo("test-id")
+ assert(executorInfo1.getCommand.getValue === s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}")
+ }
+
test("mesos resource offers result in launching tasks") {
def createOffer(id: Int, mem: Int, cpu: Int) = {
val builder = Offer.newBuilder()
@@ -52,11 +87,16 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea
val driver = EasyMock.createMock(classOf[SchedulerDriver])
val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl])
+ val listenerBus = EasyMock.createMock(classOf[LiveListenerBus])
+ listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2)))
+ EasyMock.replay(listenerBus)
+
val sc = EasyMock.createMock(classOf[SparkContext])
EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes()
EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes()
EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes()
EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes()
+ EasyMock.expect(sc.listenerBus).andReturn(listenerBus)
EasyMock.replay(sc)
val minMem = MemoryUtils.calculateTotalMemory(sc).toInt
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala
new file mode 100644
index 0000000000000..86a42a7398e4d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.mesos
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData
+
+class MesosTaskLaunchDataSuite extends FunSuite {
+ test("serialize and deserialize data must be same") {
+ val serializedTask = ByteBuffer.allocate(40)
+ Range(100, 110).foreach(i => serializedTask.putInt(i))
+ serializedTask.rewind
+ val attemptNumber = 100
+ val byteString = MesosTaskLaunchData(serializedTask, attemptNumber).toByteString
+ serializedTask.rewind
+ val mesosTaskLaunchData = MesosTaskLaunchData.fromByteString(byteString)
+ assert(mesosTaskLaunchData.attemptNumber == attemptNumber)
+ assert(mesosTaskLaunchData.serializedTask.equals(serializedTask))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
index dae7bf0e336de..8cf951adb354b 100644
--- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
@@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
// spark.local.dir only contains invalid directories, but that's not a problem since
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 12af60caf7d54..68074ae32a672 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -28,6 +28,8 @@ import org.apache.spark.util.Utils
class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers {
+ val jobSubmissionTime = 1421191042750L
+ val jobCompletionTime = 1421191296660L
private def createStageStartEvent(stageId: Int) = {
val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "")
@@ -46,12 +48,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val stageInfos = stageIds.map { stageId =>
new StageInfo(stageId, 0, stageId.toString, 0, null, "")
}
- SparkListenerJobStart(jobId, stageInfos)
+ SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos)
}
private def createJobEndEvent(jobId: Int, failed: Boolean = false) = {
val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded
- SparkListenerJobEnd(jobId, result)
+ SparkListenerJobEnd(jobId, jobCompletionTime, result)
}
private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) {
@@ -138,7 +140,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
assert(listener.stageIdToData.size === 0)
// finish this task, should get updated shuffleRead
- shuffleReadMetrics.remoteBytesRead = 1000
+ shuffleReadMetrics.incRemoteBytesRead(1000)
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
@@ -224,18 +226,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val shuffleWriteMetrics = new ShuffleWriteMetrics()
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
- shuffleReadMetrics.remoteBytesRead = base + 1
- shuffleReadMetrics.remoteBlocksFetched = base + 2
- shuffleWriteMetrics.shuffleBytesWritten = base + 3
- taskMetrics.executorRunTime = base + 4
- taskMetrics.diskBytesSpilled = base + 5
- taskMetrics.memoryBytesSpilled = base + 6
+ shuffleReadMetrics.incRemoteBytesRead(base + 1)
+ shuffleReadMetrics.incRemoteBlocksFetched(base + 2)
+ shuffleWriteMetrics.incShuffleBytesWritten(base + 3)
+ taskMetrics.setExecutorRunTime(base + 4)
+ taskMetrics.incDiskBytesSpilled(base + 5)
+ taskMetrics.incMemoryBytesSpilled(base + 6)
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- taskMetrics.inputMetrics = Some(inputMetrics)
- inputMetrics.bytesRead = base + 7
+ taskMetrics.setInputMetrics(Some(inputMetrics))
+ inputMetrics.addBytesRead(base + 7)
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
taskMetrics.outputMetrics = Some(outputMetrics)
- outputMetrics.bytesWritten = base + 8
+ outputMetrics.setBytesWritten(base + 8)
taskMetrics
}
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
new file mode 100644
index 0000000000000..1026cb2aa7cae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.CountDownLatch
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.FunSuite
+
+class EventLoopSuite extends FunSuite with Timeouts {
+
+ test("EventLoop") {
+ val buffer = new mutable.ArrayBuffer[Int] with mutable.SynchronizedBuffer[Int]
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ buffer += event
+ }
+
+ override def onError(e: Throwable): Unit = {}
+ }
+ eventLoop.start()
+ (1 to 100).foreach(eventLoop.post)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert((1 to 100) === buffer.toSeq)
+ }
+ eventLoop.stop()
+ }
+
+ test("EventLoop: start and stop") {
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {}
+
+ override def onError(e: Throwable): Unit = {}
+ }
+ assert(false === eventLoop.isActive)
+ eventLoop.start()
+ assert(true === eventLoop.isActive)
+ eventLoop.stop()
+ assert(false === eventLoop.isActive)
+ }
+
+ test("EventLoop: onError") {
+ val e = new RuntimeException("Oops")
+ @volatile var receivedError: Throwable = null
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ throw e
+ }
+
+ override def onError(e: Throwable): Unit = {
+ receivedError = e
+ }
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(e === receivedError)
+ }
+ eventLoop.stop()
+ }
+
+ test("EventLoop: error thrown from onError should not crash the event thread") {
+ val e = new RuntimeException("Oops")
+ @volatile var receivedError: Throwable = null
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ throw e
+ }
+
+ override def onError(e: Throwable): Unit = {
+ receivedError = e
+ throw new RuntimeException("Oops")
+ }
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(e === receivedError)
+ assert(eventLoop.isActive)
+ }
+ eventLoop.stop()
+ }
+
+ test("EventLoop: calling stop multiple times should only call onStop once") {
+ var onStopTimes = 0
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ override def onStop(): Unit = {
+ onStopTimes += 1
+ }
+ }
+
+ eventLoop.start()
+
+ eventLoop.stop()
+ eventLoop.stop()
+ eventLoop.stop()
+
+ assert(1 === onStopTimes)
+ }
+
+ test("EventLoop: post event in multiple threads") {
+ @volatile var receivedEventsCount = 0
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ receivedEventsCount += 1
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ }
+ eventLoop.start()
+
+ val threadNum = 5
+ val eventsFromEachThread = 100
+ (1 to threadNum).foreach { _ =>
+ new Thread() {
+ override def run(): Unit = {
+ (1 to eventsFromEachThread).foreach(eventLoop.post)
+ }
+ }.start()
+ }
+
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(threadNum * eventsFromEachThread === receivedEventsCount)
+ }
+ eventLoop.stop()
+ }
+
+ test("EventLoop: onReceive swallows InterruptException") {
+ val onReceiveLatch = new CountDownLatch(1)
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ onReceiveLatch.countDown()
+ try {
+ Thread.sleep(5000)
+ } catch {
+ case ie: InterruptedException => // swallow
+ }
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ failAfter(5 seconds) {
+ // Wait until we enter `onReceive`
+ onReceiveLatch.await()
+ eventLoop.stop()
+ }
+ assert(false === eventLoop.isActive)
+ }
+
+ test("EventLoop: stop in eventThread") {
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ stop()
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(!eventLoop.isActive)
+ }
+ }
+}
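The suite above pins down the `EventLoop` contract fairly tightly: FIFO delivery on a single event thread, errors routed to `onError` without killing the thread, `onStop` invoked at most once, and `stop()` working even when called from `onReceive`. The sketch below is a minimal event loop honoring that contract, written only to make the tested semantics concrete; it is not the `org.apache.spark.util.EventLoop` implementation, and the class name and internals are illustrative.

{% highlight scala %}
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
import java.util.concurrent.atomic.AtomicBoolean

// A minimal event loop matching the behavior the suite exercises: posted events
// are drained in order by one daemon thread, errors go to onError, and onStop
// runs exactly once.
abstract class SimpleEventLoop[E](name: String) {
  private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()
  private val stopped = new AtomicBoolean(false)

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped.get) {
          val event = eventQueue.take()
          try {
            onReceive(event)
          } catch {
            case t: Throwable =>
              // A failing onError must not kill the event thread.
              try onError(t) catch { case _: Throwable => () }
          }
        }
      } catch {
        case _: InterruptedException => // stop() interrupts a blocked take()
      }
    }
  }

  def start(): Unit = eventThread.start()

  def stop(): Unit = {
    if (stopped.compareAndSet(false, true)) {   // onStop runs at most once
      eventThread.interrupt()
      if (Thread.currentThread() ne eventThread) {
        eventThread.join()
      }
      onStop()
    }
  }

  def post(event: E): Unit = eventQueue.put(event)

  def isActive: Boolean = eventThread.isAlive

  protected def onReceive(event: E): Unit
  protected def onError(e: Throwable): Unit
  protected def onStop(): Unit = {}
}
{% endhighlight %}

A `LinkedBlockingDeque` plus a single daemon thread is enough to reproduce every behavior the suite checks, including the "stop in eventThread" and multi-threaded post cases.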
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 63c2559c5c5f5..0357fc6ce2780 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
import java.util.Properties
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.shuffle.MetadataFetchFailedException
import scala.collection.Map
@@ -33,6 +34,9 @@ import org.apache.spark.storage._
class JsonProtocolSuite extends FunSuite {
+ val jobSubmissionTime = 1421191042750L
+ val jobCompletionTime = 1421191296660L
+
test("SparkListenerEvent") {
val stageSubmitted =
SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties)
@@ -53,9 +57,9 @@ class JsonProtocolSuite extends FunSuite {
val stageIds = Seq[Int](1, 2, 3, 4)
val stageInfos = stageIds.map(x =>
makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L))
- SparkListenerJobStart(10, stageInfos, properties)
+ SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties)
}
- val jobEnd = SparkListenerJobEnd(20, JobSucceeded)
+ val jobEnd = SparkListenerJobEnd(20, jobCompletionTime, JobSucceeded)
val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]](
"JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")),
"Spark Properties" -> Seq(("Job throughput", "80000 jobs/s, regardless of job type")),
@@ -69,6 +73,9 @@ class JsonProtocolSuite extends FunSuite {
val unpersistRdd = SparkListenerUnpersistRDD(12345)
val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield")
val applicationEnd = SparkListenerApplicationEnd(42L)
+ val executorAdded = SparkListenerExecutorAdded("exec1",
+ new ExecutorInfo("Hostee.awesome.com", 11))
+ val executorRemoved = SparkListenerExecutorRemoved("exec2")
testEvent(stageSubmitted, stageSubmittedJsonString)
testEvent(stageCompleted, stageCompletedJsonString)
@@ -85,6 +92,8 @@ class JsonProtocolSuite extends FunSuite {
testEvent(unpersistRdd, unpersistRDDJsonString)
testEvent(applicationStart, applicationStartJsonString)
testEvent(applicationEnd, applicationEndJsonString)
+ testEvent(executorAdded, executorAddedJsonString)
+ testEvent(executorRemoved, executorRemovedJsonString)
}
test("Dependent Classes") {
@@ -94,6 +103,7 @@ class JsonProtocolSuite extends FunSuite {
testTaskMetrics(makeTaskMetrics(
33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false))
testBlockManagerId(BlockManagerId("Hong", "Kong", 500))
+ testExecutorInfo(new ExecutorInfo("host", 43))
// StorageLevel
testStorageLevel(StorageLevel.NONE)
@@ -240,13 +250,31 @@ class JsonProtocolSuite extends FunSuite {
val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500))
val dummyStageInfos =
stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown"))
- val jobStart = SparkListenerJobStart(10, stageInfos, properties)
+ val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties)
val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"})
val expectedJobStart =
- SparkListenerJobStart(10, dummyStageInfos, properties)
+ SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties)
assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent))
}
+ test("SparkListenerJobStart and SparkListenerJobEnd backward compatibility") {
+ // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property.
+ // Also, SparkListenerJobEnd did not have a "Completion Time" property.
+ val stageIds = Seq[Int](1, 2, 3, 4)
+ val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50))
+ val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties)
+ val oldStartEvent = JsonProtocol.jobStartToJson(jobStart)
+ .removeField({ _._1 == "Submission Time"})
+ val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties)
+ assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent))
+
+ val jobEnd = SparkListenerJobEnd(11, jobCompletionTime, JobSucceeded)
+ val oldEndEvent = JsonProtocol.jobEndToJson(jobEnd)
+ .removeField({ _._1 == "Completion Time"})
+ val expectedJobEnd = SparkListenerJobEnd(11, -1, JobSucceeded)
+ assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent))
+ }
+
/** -------------------------- *
| Helper test running methods |
* --------------------------- */
@@ -303,6 +331,10 @@ class JsonProtocolSuite extends FunSuite {
assert(blockId === newBlockId)
}
+ private def testExecutorInfo(info: ExecutorInfo) {
+ val newInfo = JsonProtocol.executorInfoFromJson(JsonProtocol.executorInfoToJson(info))
+ assertEquals(info, newInfo)
+ }
/** -------------------------------- *
| Util methods for comparing events |
@@ -335,6 +367,11 @@ class JsonProtocolSuite extends FunSuite {
assertEquals(e1.jobResult, e2.jobResult)
case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
assertEquals(e1.environmentDetails, e2.environmentDetails)
+ case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) =>
+ assert(e1.executorId == e2.executorId)
+ assertEquals(e1.executorInfo, e2.executorInfo)
+ case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) =>
+ assert(e1.executorId == e2.executorId)
case (e1, e2) =>
assert(e1 === e2)
case _ => fail("Events don't match in types!")
@@ -387,6 +424,11 @@ class JsonProtocolSuite extends FunSuite {
assert(info1.accumulables === info2.accumulables)
}
+ private def assertEquals(info1: ExecutorInfo, info2: ExecutorInfo) {
+ assert(info1.executorHost == info2.executorHost)
+ assert(info1.totalCores == info2.totalCores)
+ }
+
private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
assert(metrics1.hostname === metrics2.hostname)
assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
@@ -599,34 +641,34 @@ class JsonProtocolSuite extends FunSuite {
hasHadoopInput: Boolean,
hasOutput: Boolean) = {
val t = new TaskMetrics
- t.hostname = "localhost"
- t.executorDeserializeTime = a
- t.executorRunTime = b
- t.resultSize = c
- t.jvmGCTime = d
- t.resultSerializationTime = a + b
- t.memoryBytesSpilled = a + c
+ t.setHostname("localhost")
+ t.setExecutorDeserializeTime(a)
+ t.setExecutorRunTime(b)
+ t.setResultSize(c)
+ t.setJvmGCTime(d)
+ t.setResultSerializationTime(a + b)
+ t.incMemoryBytesSpilled(a + c)
if (hasHadoopInput) {
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- inputMetrics.bytesRead = d + e + f
- t.inputMetrics = Some(inputMetrics)
+ inputMetrics.addBytesRead(d + e + f)
+ t.setInputMetrics(Some(inputMetrics))
} else {
val sr = new ShuffleReadMetrics
- sr.remoteBytesRead = b + d
- sr.localBlocksFetched = e
- sr.fetchWaitTime = a + d
- sr.remoteBlocksFetched = f
+ sr.incRemoteBytesRead(b + d)
+ sr.incLocalBlocksFetched(e)
+ sr.incFetchWaitTime(a + d)
+ sr.incRemoteBlocksFetched(f)
t.setShuffleReadMetrics(Some(sr))
}
if (hasOutput) {
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
- outputMetrics.bytesWritten = a + b + c
+ outputMetrics.setBytesWritten(a + b + c)
t.outputMetrics = Some(outputMetrics)
} else {
val sw = new ShuffleWriteMetrics
- sw.shuffleBytesWritten = a + b + c
- sw.shuffleWriteTime = b + c + d
+ sw.incShuffleBytesWritten(a + b + c)
+ sw.incShuffleWriteTime(b + c + d)
t.shuffleWriteMetrics = Some(sw)
}
// Make at most 6 blocks
@@ -1054,6 +1096,7 @@ class JsonProtocolSuite extends FunSuite {
|{
| "Event": "SparkListenerJobStart",
| "Job ID": 10,
+ | "Submission Time": 1421191042750,
| "Stage Infos": [
| {
| "Stage ID": 1,
@@ -1328,6 +1371,7 @@ class JsonProtocolSuite extends FunSuite {
|{
| "Event": "SparkListenerJobEnd",
| "Job ID": 20,
+ | "Completion Time": 1421191296660,
| "Job Result": {
| "Result": "JobSucceeded"
| }
@@ -1407,4 +1451,24 @@ class JsonProtocolSuite extends FunSuite {
| "Timestamp": 42
|}
"""
+
+ private val executorAddedJsonString =
+ """
+ |{
+ | "Event": "SparkListenerExecutorAdded",
+ | "Executor ID": "exec1",
+ | "Executor Info": {
+ | "Host": "Hostee.awesome.com",
+ | "Total Cores": 11
+ | }
+ |}
+ """
+
+ private val executorRemovedJsonString =
+ """
+ |{
+ | "Event": "SparkListenerExecutorRemoved",
+ | "Executor ID": "exec2"
+ |}
+ """
}
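As a side note on the `testExecutorInfo` round trip above, the helpers below sketch what a to/from JSON pair matching the asserted `"Host"` / `"Total Cores"` shape could look like, using json4s. This is an illustration only, not the `JsonProtocol.executorInfoToJson` / `executorInfoFromJson` code under test, and the object name is made up.

{% highlight scala %}
import org.json4s._
import org.json4s.JsonDSL._

import org.apache.spark.scheduler.cluster.ExecutorInfo

// Illustrative round-trip helpers for ExecutorInfo, mirroring the JSON shape
// pinned down by executorAddedJsonString above.
object ExecutorInfoJsonSketch {
  private implicit val formats = DefaultFormats

  def toJson(info: ExecutorInfo): JValue =
    ("Host" -> info.executorHost) ~ ("Total Cores" -> info.totalCores)

  def fromJson(json: JValue): ExecutorInfo =
    new ExecutorInfo((json \ "Host").extract[String], (json \ "Total Cores").extract[Int])
}
{% endhighlight %}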
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index b1b8cb44e098b..b2a7e092a0291 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then
for file in $(find . -type f)
do
echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file;
- gpg --print-md MD5 $file > $file.md5;
- gpg --print-md SHA1 $file > $file.sha1
+ if [ $(command -v md5) ]; then
+ # Available on OS X; -q to keep only hash
+ md5 -q $file > $file.md5
+ else
+ # Available on Linux; cut to keep only hash
+ md5sum $file | cut -f1 -d' ' > $file.md5
+ fi
+ shasum -a 1 $file | cut -f1 -d' ' > $file.sha1
done
nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
diff --git a/docs/configuration.md b/docs/configuration.md
index 673cdb371a512..7c5b6d011cfd3 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -102,11 +102,10 @@ of the most common options to set are:
- spark.executor.memory
- 512m
+ spark.driver.cores
+ 1
- Amount of memory to use per executor process, in the same format as JVM memory strings
- (e.g. 512m, 2g).
+ Number of cores to use for the driver process, only in cluster mode.
@@ -117,6 +116,14 @@ of the most common options to set are:
(e.g. 512m, 2g).
+ spark.executor.memory
+ 512m
+ Amount of memory to use per executor process, in the same format as JVM memory strings
+ (e.g. 512m, 2g).
spark.driver.maxResultSize
1g
@@ -190,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful
#### Runtime Environment
Property Name
Default
Meaning
+ spark.driver.extraJavaOptions
+ (none)
+ A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+ spark.driver.extraClassPath
+ (none)
+ Extra classpath entries to append to the classpath of the driver.
+ spark.driver.extraLibraryPath
+ (none)
+ Set a special library path to use when launching the driver JVM.
spark.executor.extraJavaOptions
(none)
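To make the table entries above concrete, here is a small sketch of setting a couple of these properties programmatically. The property names come from the table; the values are purely illustrative, and driver-side `spark.driver.*` options are normally set in `spark-defaults.conf` or on the `spark-submit` command line rather than in application code, since the driver JVM is already running by then.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative values only; the property names are taken from the table above.
val conf = new SparkConf()
  .setAppName("config-example")
  .set("spark.executor.memory", "512m")
  .set("spark.executor.extraJavaOptions", "-XX:+PrintGCDetails")
val sc = new SparkContext(conf)
{% endhighlight %}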
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 1c2e27341473b..be178d7689fdd 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -3,13 +3,16 @@ layout: global
title: Spark ML Programming Guide
---
-Spark ML is Spark's new machine learning package. It is currently an alpha component but is potentially a successor to [MLlib](mllib-guide.html). The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines.
+`spark.ml` is a new package introduced in Spark 1.2, which aims to provide a uniform set of
+high-level APIs that help users create and tune practical machine learning pipelines.
+It is currently an alpha component, and we would like to hear back from the community about
+how it fits real-world use cases and how it could be improved.
-MLlib vs. Spark ML:
-
-* Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`. Since Spark ML is an alpha component, its API may change in future releases.
-* Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. See below for more details.
-* Spark ML only has Scala and Java APIs, whereas MLlib also has a Python API.
+Note that we will keep supporting and adding features to `spark.mllib` along with the
+development of `spark.ml`.
+Users can continue to use `spark.mllib` features and can expect more features to come.
+Developers should contribute new algorithms to `spark.mllib` and can optionally contribute
+to `spark.ml`.
**Table of Contents**
@@ -686,17 +689,3 @@ Spark ML currently depends on MLlib and has the same dependencies.
Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info.
Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies.
-
-# Developers
-
-**Development plan**
-
-If all goes well, `spark.ml` will become the primary ML package at the time of the Spark 1.3 release. Initially, simple wrappers will be used to port algorithms to `spark.ml`, but eventually, code will be moved to `spark.ml` and `spark.mllib` will be deprecated.
-
-**Advice to developers**
-
-During the next development cycle, new algorithms should be contributed to `spark.mllib`, but we welcome patches sent to either package. If an algorithm is best expressed using the new API (e.g., feature transformers), we may ask for developers to use the new `spark.ml` API.
-Wrappers for old and new algorithms can be contributed to `spark.ml`.
-
-Users will be able to use algorithms from either of the two packages. The main difficulty will be the differences in APIs between the two packages.
-
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 2094963392295..ef18cec9371d6 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva
recommendation by measuring the Mean Squared Error of rating prediction.
{% highlight python %}
-from pyspark.mllib.recommendation import ALS
-from numpy import array
+from pyspark.mllib.recommendation import ALS, Rating
# Load and parse the data
data = sc.textFile("data/mllib/als/test.data")
-ratings = data.map(lambda line: array([float(x) for x in line.split(',')]))
+ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2])))
# Build the recommendation model using Alternating Least Squares
rank = 10
@@ -205,10 +204,10 @@ numIterations = 20
model = ALS.train(ratings, rank, numIterations)
# Evaluate the model on training data
-testdata = ratings.map(lambda p: (int(p[0]), int(p[1])))
+testdata = ratings.map(lambda p: (p[0], p[1]))
predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
-MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count()
+MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
print("Mean Squared Error = " + str(MSE))
{% endhighlight %}
@@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results.
{% highlight python %}
# Build the recommendation model using Alternating Least Squares based on implicit ratings
-model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
+model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01)
{% endhighlight %}
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index efd7dda310712..39c64d06926bf 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -35,16 +35,20 @@ MLlib is under active development.
The APIs marked `Experimental`/`DeveloperApi` may change in future releases,
and the migration guide below will explain all changes between releases.
-# spark.ml: The New ML Package
+# spark.ml: high-level APIs for ML pipelines
-Spark 1.2 includes a new machine learning package called `spark.ml`, currently an alpha component but potentially a successor to `spark.mllib`. The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines.
+Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of
+high-level APIs that help users create and tune practical machine learning pipelines.
+It is currently an alpha component, and we would like to hear back from the community about
+how it fits real-world use cases and how it could be improved.
-See the **[spark.ml programming guide](ml-guide.html)** for more information on this package.
-
-Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`.
+Note that we will keep supporting and adding features to `spark.mllib` along with the
+development of `spark.ml`.
+Users can continue to use `spark.mllib` features and can expect more features to come.
+Developers should contribute new algorithms to `spark.mllib` and can optionally contribute
+to `spark.ml`.
-Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`.
-See the `spark.ml` programming guide linked above for more details.
+See the **[spark.ml programming guide](ml-guide.html)** for more information on this package.
# Dependencies
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 5e0d5c15d7069..2443fc29b4706 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -913,7 +913,7 @@ for details.
cogroup(otherDataset, [numTasks])
- When called on datasets of type (K, V) and (K, W), returns a dataset of (K, Iterable<V>, Iterable<W>) tuples. This operation is also called groupWith.
+ When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith.
cartesian(otherDataset)
@@ -1316,7 +1316,35 @@ For accumulator updates performed inside actions only, Spark guarantees t
will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware
of that each task's update may be applied more than once if tasks or job stages are re-executed.
+Accumulators do not change the lazy evaluation model of Spark. If they are updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The code fragments below demonstrate this property:
+
+
+
+{% highlight scala %}
+val acc = sc.accumulator(0)
+data.map { x => acc += x; f(x) }
+// Here, acc is still 0 because no actions have caused the `map` to be computed.
+{% endhighlight %}
+
+
+
+{% highlight java %}
+Accumulator<Integer> accum = sc.accumulator(0);
+data.map(x -> { accum.add(x); return f(x); });
+// Here, accum is still 0 because no actions have caused the `map` to be computed.
+{% endhighlight %}
+
+
+
+{% highlight python %}
+accum = sc.accumulator(0)
+def g(x):
+    accum.add(x)
+    return f(x)
+data.map(g)
+# Here, accum is still 0 because no actions have caused the `map` to be computed.
+{% endhighlight %}
+
+
+
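A brief follow-up sketch (Scala, reusing the `acc`, `data`, and `f` placeholders from the tabs above): once an action runs, the buffered updates become visible through the accumulator's value.

{% highlight scala %}
// Continuing the Scala fragment above: an action forces the `map` to execute,
// and only then does the accumulator reflect the updates.
val mapped = data.map { x => acc += x; f(x) }
mapped.count()      // an action triggers the computation
println(acc.value)  // now reflects the accumulated updates
{% endhighlight %}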
# Deploying to a Cluster
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 4f273098c5db3..68ab127bcf087 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -29,6 +29,23 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
In cluster mode, use spark.driver.memory instead.
+ spark.driver.cores
+ 1
+ Number of cores used by the driver in YARN cluster mode.
+ Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM.
+ In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead.
+ spark.yarn.am.cores
+ 1
+ Number of cores to use for the YARN Application Master in client mode.
+ In cluster mode, use spark.driver.cores instead.
All data types of Spark SQL are located in the package of
-`org.apache.spark.sql.api.java`. To access or create a data type,
+`org.apache.spark.sql.types`. To access or create a data type,
please use factory methods provided in
-`org.apache.spark.sql.api.java.DataType`.
+`org.apache.spark.sql.types.DataTypes`.
@@ -1346,109 +1346,110 @@ please use factory methods provided in
- DataType.createArrayType(elementType)
+ DataTypes.createArrayType(elementType) Note: The value of containsNull will be true
- DataType.createArrayType(elementType, containsNull).
+ DataTypes.createArrayType(elementType, containsNull).
MapType
java.util.Map
- DataType.createMapType(keyType, valueType)
+ DataTypes.createMapType(keyType, valueType) Note: The value of valueContainsNull will be true.
- DataType.createMapType(keyType, valueType, valueContainsNull)
+ DataTypes.createMapType(keyType, valueType, valueContainsNull)
StructType
org.apache.spark.sql.api.java.Row
- DataType.createStructType(fields)
+ DataTypes.createStructType(fields) Note: fields is a List or an array of StructFields.
Also, two fields with the same name are not allowed.
@@ -1458,7 +1459,7 @@ please use factory methods provided in
The value type in Java of the data type of this field
(For example, int for a StructField with the data type IntegerType)
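As a quick illustration of the `DataTypes` factory methods listed in the table above (a sketch only; the field names and nullability flags are invented for the example):

{% highlight scala %}
import org.apache.spark.sql.types.DataTypes

// Field names and nullability flags are made up for this illustration.
val arrayOfInts  = DataTypes.createArrayType(DataTypes.IntegerType)   // containsNull defaults to true
val stringToLong = DataTypes.createMapType(DataTypes.StringType, DataTypes.LongType)
val schema = DataTypes.createStructType(Array(
  DataTypes.createStructField("name", DataTypes.StringType, false),
  DataTypes.createStructField("scores", arrayOfInts, true)))
{% endhighlight %}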
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 0e38fe2144e9f..77c0abbbacbd0 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide
streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 3bd1deaccfafe..14a87f8436984 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell).
Alternatively, if your application is submitted from a machine far from the worker machines (e.g.
locally on your laptop), it is common to use `cluster` mode to minimize network latency between
-the drivers and the executors. Note that `cluster` mode is currently not supported for standalone
-clusters, Mesos clusters, or Python applications.
+the drivers and the executors. Note that `cluster` mode is currently not supported for
+Mesos clusters or Python applications.
For Python applications, simply pass a `.py` file in the place of `` instead of a JAR,
and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`.
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
index 2adc63f7ff30e..387c0e421334b 100644
--- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
@@ -76,7 +76,7 @@ object KafkaWordCountProducer {
val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
- // Zookeper connection properties
+ // Zookeeper connection properties
val props = new Properties()
props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index f4b4f8d8c7b2f..0fbee6e433608 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -33,9 +33,9 @@
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
/**
* A simple example demonstrating model selection using CrossValidator.
@@ -55,7 +55,7 @@ public class JavaCrossValidatorExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
- JavaSQLContext jsql = new JavaSQLContext(jsc);
+ SQLContext jsql = new SQLContext(jsc);
// Prepare training documents, which are labeled.
List<LabeledDocument> localTraining = Lists.newArrayList(
@@ -71,8 +71,7 @@ public static void main(String[] args) {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -113,11 +112,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test).registerAsTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ cvModel.transform(test).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ ", prediction=" + r.get(3));
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index e25b271777ed4..eaaa344be49c8 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -28,9 +28,9 @@
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
/**
* A simple example demonstrating ways to specify parameters for Estimators and Transformers.
@@ -44,17 +44,17 @@ public class JavaSimpleParamsExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
- JavaSQLContext jsql = new JavaSQLContext(jsc);
+ SQLContext jsql = new SQLContext(jsc);
// Prepare training data.
// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans
- // into SchemaRDDs, where it uses the bean metadata to infer the schema.
+ // into DataFrames, where it uses the bean metadata to infer the schema.
List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -94,14 +94,14 @@ public static void main(String[] args) {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
// column since we renamed the lr.scoreCol parameter previously.
- model2.transform(test).registerAsTable("results");
- JavaSchemaRDD results =
+ model2.transform(test).registerTempTable("results");
+ DataFrame results =
jsql.sql("SELECT features, label, probability, prediction FROM results");
for (Row r: results.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 54f18014e4b2f..82d665a3e1386 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -21,6 +21,7 @@
import com.google.common.collect.Lists;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
@@ -28,10 +29,9 @@
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
-import org.apache.spark.SparkConf;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
/**
* A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
@@ -46,7 +46,7 @@ public class JavaSimpleTextClassificationPipeline {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
JavaSparkContext jsc = new JavaSparkContext(conf);
- JavaSQLContext jsql = new JavaSQLContext(jsc);
+ SQLContext jsql = new SQLContext(jsc);
// Prepare training documents, which are labeled.
List<LabeledDocument> localTraining = Lists.newArrayList(
@@ -54,8 +54,7 @@ public static void main(String[] args) {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -80,11 +79,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
- model.transform(test).registerAsTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ model.transform(test).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ ", prediction=" + r.get(3));
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index 01c77bd44337e..8defb769ffaaf 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -26,9 +26,9 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
public class JavaSparkSQL {
public static class Person implements Serializable {
@@ -55,7 +55,7 @@ public void setAge(int age) {
public static void main(String[] args) throws Exception {
SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
- JavaSQLContext sqlCtx = new JavaSQLContext(ctx);
+ SQLContext sqlCtx = new SQLContext(ctx);
System.out.println("=== Data source: RDD ===");
// Load a text file and convert each line to a Java Bean.
@@ -74,15 +74,15 @@ public Person call(String line) {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class);
+ DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
- JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
- // The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+ // The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
- List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
+ List<String> teenagerNames = teenagers.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
return "Name: " + row.getString(0);
@@ -93,19 +93,19 @@ public String call(Row row) {
}
System.out.println("=== Data source: Parquet File ===");
- // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information.
+ // DataFrames can be saved as parquet files, maintaining the schema information.
schemaPeople.saveAsParquetFile("people.parquet");
// Read in the parquet file created above.
// Parquet files are self-describing so the schema is preserved.
- // The result of loading a parquet file is also a JavaSchemaRDD.
- JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet");
+ // The result of loading a parquet file is also a DataFrame.
+ DataFrame parquetFile = sqlCtx.parquetFile("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
- JavaSchemaRDD teenagers2 =
+ DataFrame teenagers2 =
sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
- teenagerNames = teenagers2.map(new Function<Row, String>() {
+ teenagerNames = teenagers2.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
return "Name: " + row.getString(0);
@@ -119,8 +119,8 @@ public String call(Row row) {
// A JSON dataset is pointed by path.
// The path can be either a single text file or a directory storing text files.
String path = "examples/src/main/resources/people.json";
- // Create a JavaSchemaRDD from the file(s) pointed by path
- JavaSchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path);
+ // Create a DataFrame from the file(s) pointed to by path
+ DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path);
// Because the schema of a JSON dataset is automatically inferred, to write queries,
// it is better to take a look at what is the schema.
@@ -130,15 +130,15 @@ public String call(Row row) {
// |-- age: IntegerType
// |-- name: StringType
- // Register this JavaSchemaRDD as a table.
+ // Register this DataFrame as a table.
peopleFromJsonFile.registerTempTable("people");
// SQL statements can be run by using the sql methods provided by sqlCtx.
- JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
- // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations.
+ // The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
- teenagerNames = teenagers3.map(new Function<Row, String>() {
+ teenagerNames = teenagers3.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) { return "Name: " + row.getString(0); }
}).collect();
@@ -146,14 +146,14 @@ public String call(Row row) {
System.out.println(name);
}
- // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by
+ // Alternatively, a DataFrame can be created for a JSON dataset represented by
// a RDD[String] storing one JSON object per string.
List<String> jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD<String> anotherPeopleRDD = ctx.parallelize(jsonData);
- JavaSchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD);
+ DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd());
- // Take a look at the schema of this new JavaSchemaRDD.
+ // Take a look at the schema of this new DataFrame.
peopleFromJsonRDD.printSchema();
// The schema of anotherPeople is ...
// root
@@ -164,8 +164,8 @@ public String call(Row row) {
peopleFromJsonRDD.registerTempTable("people2");
- JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
- List<String> nameAndCity = peopleWithCity.map(new Function<Row, String>() {
+ DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
+ List<String> nameAndCity = peopleWithCity.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
return "Name: " + row.getString(0) + ", City: " + row.getString(1);
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
index 540dae785f6ea..b5a70db2b9a3c 100644
--- a/examples/src/main/python/mllib/dataset_example.py
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -16,7 +16,7 @@
#
"""
-An example of how to use SchemaRDD as a dataset for ML. Run with::
+An example of how to use DataFrame as a dataset for ML. Run with::
bin/spark-submit examples/src/main/python/mllib/dataset_example.py
"""
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index d2c5ca48c6cb8..7f5c68e3d0fe2 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -30,18 +30,18 @@
some_rdd = sc.parallelize([Row(name="John", age=19),
Row(name="Smith", age=23),
Row(name="Sarah", age=18)])
- # Infer schema from the first row, create a SchemaRDD and print the schema
- some_schemardd = sqlContext.inferSchema(some_rdd)
- some_schemardd.printSchema()
+ # Infer schema from the first row, create a DataFrame and print the schema
+ some_df = sqlContext.inferSchema(some_rdd)
+ some_df.printSchema()
# Another RDD is created from a list of tuples
another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)])
# Schema with two fields - person_name and person_age
schema = StructType([StructField("person_name", StringType(), False),
StructField("person_age", IntegerType(), False)])
- # Create a SchemaRDD by applying the schema to the RDD and print the schema
- another_schemardd = sqlContext.applySchema(another_rdd, schema)
- another_schemardd.printSchema()
+ # Create a DataFrame by applying the schema to the RDD and print the schema
+ another_df = sqlContext.applySchema(another_rdd, schema)
+ another_df.printSchema()
# root
# |-- age: integer (nullable = true)
# |-- name: string (nullable = true)
@@ -49,7 +49,7 @@
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
- # Create a SchemaRDD from the file(s) pointed to by path
+ # Create a DataFrame from the file(s) pointed to by path
people = sqlContext.jsonFile(path)
# root
# |-- person_name: string (nullable = false)
@@ -61,7 +61,7 @@
# |-- age: IntegerType
# |-- name: StringType
- # Register this SchemaRDD as a table.
+ # Register this DataFrame as a table.
people.registerAsTable("people")
# SQL statements can be run by using the sql methods provided by sqlContext
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index d8c7ef38ee46d..283bb80f1c788 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -18,7 +18,6 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
@@ -101,7 +100,7 @@ object CrossValidatorExample {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
- .select('id, 'text, 'score, 'prediction)
+ .select("id", "text", "score", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
new file mode 100644
index 0000000000000..b7885829459a3
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.recommendation.ALS
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
+ * Run with
+ * {{{
+ * bin/run-example ml.MovieLensALS
+ * }}}
+ */
+object MovieLensALS {
+
+ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+
+ object Rating {
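+ /** Parses a ratings line of the form `userId::movieId::rating::timestamp` into a [[Rating]]. */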
+ def parseRating(str: String): Rating = {
+ val fields = str.split("::")
+ assert(fields.size == 4)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+ }
+ }
+
+ case class Movie(movieId: Int, title: String, genres: Seq[String])
+
+ object Movie {
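+ /** Parses a movies line of the form `movieId::title::genre1|genre2|...` into a [[Movie]]. */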
+ def parseMovie(str: String): Movie = {
+ val fields = str.split("::")
+ assert(fields.size == 3)
+ Movie(fields(0).toInt, fields(1), fields(2).split("\\|"))
+ }
+ }
+
+ case class Params(
+ ratings: String = null,
+ movies: String = null,
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ rank: Int = 10,
+ numBlocks: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("MovieLensALS") {
+ head("MovieLensALS: an example app for ALS on MovieLens data.")
+ opt[String]("ratings")
+ .required()
+ .text("path to a MovieLens dataset of ratings")
+ .action((x, c) => c.copy(ratings = x))
+ opt[String]("movies")
+ .required()
+ .text("path to a MovieLens dataset of movies")
+ .action((x, c) => c.copy(movies = x))
+ opt[Int]("rank")
+ .text(s"rank, default: ${defaultParams.rank}}")
+ .action((x, c) => c.copy(rank = x))
+ opt[Int]("maxIter")
+ .text(s"max number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Int]("numBlocks")
+ .text(s"number of blocks, default: ${defaultParams.numBlocks}")
+ .action((x, c) => c.copy(numBlocks = x))
+ note(
+ """
+ |Example command line to run this app:
+ |
+ | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
+ | examples/target/scala-*/spark-examples-*.jar \
+ | --rank 10 --maxIter 15 --regParam 0.1 \
+ | --movies path/to/movielens/movies.dat \
+ | --ratings path/to/movielens/ratings.dat
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
+
+ val numRatings = ratings.count()
+ val numUsers = ratings.map(_.userId).distinct().count()
+ val numMovies = ratings.map(_.movieId).distinct().count()
+
+ println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
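+ // Hold out 20% of the ratings as a test set, with a fixed seed so the split is reproducible.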
+ val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
+ val training = splits(0).cache()
+ val test = splits(1).cache()
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ println(s"Training: $numTraining, test: $numTest.")
+
+ ratings.unpersist(blocking = false)
+
+ val als = new ALS()
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRank(params.rank)
+ .setMaxIter(params.maxIter)
+ .setRegParam(params.regParam)
+ .setNumBlocks(params.numBlocks)
+
+ val model = als.fit(training)
+
+ val predictions = model.transform(test).cache()
+
+ // Evaluate the model.
+ // TODO: Create an evaluator to compute RMSE.
+ val mse = predictions.select("rating", "prediction").rdd
+ .flatMap { case Row(rating: Float, prediction: Float) =>
+ val err = rating.toDouble - prediction
+ val err2 = err * err
+ if (err2.isNaN) {
+ None
+ } else {
+ Some(err2)
+ }
+ }.mean()
+ val rmse = math.sqrt(mse)
+ println(s"Test RMSE = $rmse.")
+
+ // Inspect false positives.
+ predictions.registerTempTable("prediction")
+ sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
+ sqlContext.sql(
+ """
+ |SELECT userId, prediction.movieId, title, rating, prediction
+ | FROM prediction JOIN movie ON prediction.movieId = movie.movieId
+ | WHERE rating <= 1 AND prediction >= 4
+ | LIMIT 100
+ """.stripMargin)
+ .collect()
+ .foreach(println)
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index e8a2adff929cb..95cc9801eaeb9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -18,7 +18,6 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -42,7 +41,7 @@ object SimpleParamsExample {
// Prepare training data.
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
- // into SchemaRDDs, where it uses the bean metadata to infer the schema.
+ // into DataFrames, where it uses the bean metadata to infer the schema.
val training = sparkContext.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -92,7 +91,7 @@ object SimpleParamsExample {
// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
// column since we renamed the lr.scoreCol parameter previously.
model2.transform(test)
- .select('features, 'label, 'probability, 'prediction)
+ .select("features", "label", "probability", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index b9a6ef0229def..065db62b0f5ed 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -20,7 +20,6 @@ package org.apache.spark.examples.ml
import scala.beans.BeanInfo
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
@@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline {
// Make predictions on test documents.
model.transform(test)
- .select('id, 'text, 'score, 'prediction)
+ .select("id", "text", "score", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
index f8d83f4ec7327..f229a58985a3e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+import org.apache.spark.sql.{Row, SQLContext, DataFrame}
/**
- * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with
* {{{
* ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
* }}}
@@ -47,7 +47,7 @@ object DatasetExample {
val defaultParams = Params()
val parser = new OptionParser[Params]("DatasetExample") {
- head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ head("Dataset: an example app using DataFrame as a Dataset for ML.")
opt[String]("input")
.text(s"input path to dataset")
.action((x, c) => c.copy(input = x))
@@ -80,20 +80,20 @@ object DatasetExample {
}
println(s"Loaded ${origData.count()} instances from file: ${params.input}")
- // Convert input data to SchemaRDD explicitly.
- val schemaRDD: SchemaRDD = origData
- println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
- println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+ // Convert input data to DataFrame explicitly.
+ val df: DataFrame = origData.toDF
+ println(s"Inferred schema:\n${df.schema.prettyJson}")
+ println(s"Converted to DataFrame with ${df.count()} records")
- // Select columns, using implicit conversion to SchemaRDD.
- val labelsSchemaRDD: SchemaRDD = origData.select('label)
- val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ // Select columns, using implicit conversion to DataFrames.
+ val labelsDf: DataFrame = origData.select("label")
+ val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
val numLabels = labels.count()
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")
- val featuresSchemaRDD: SchemaRDD = origData.select('features)
- val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featuresDf: DataFrame = origData.select("features")
+ val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
@@ -103,13 +103,13 @@ object DatasetExample {
tmpDir.deleteOnExit()
val outputDir = new File(tmpDir, "dataset").toString
println(s"Saving to $outputDir as Parquet file.")
- schemaRDD.saveAsParquetFile(outputDir)
+ df.saveAsParquetFile(outputDir)
println(s"Loading Parquet file with UDT from $outputDir.")
val newDataset = sqlContext.parquetFile(outputDir)
println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
- val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v }
val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
index 948c350953e27..de58be38c7bfb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -54,7 +54,7 @@ object DenseGmmEM {
for (i <- 0 until clusters.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
- (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+ (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}
println("Cluster labels (first <= 100):")
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index 2e98b2dc30b80..a5d7f262581f5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -19,6 +19,8 @@ package org.apache.spark.examples.sql
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.dsl.literals._
// One method for defining the schema of an RDD is to make a case class with the desired column
// names and types.
@@ -54,7 +56,7 @@ object RDDRelation {
rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
// Queries can also be written using a LINQ-like Scala DSL.
- rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println)
+ rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
// Write out an RDD as a parquet file.
rdd.saveAsParquetFile("pair.parquet")
@@ -63,7 +65,7 @@ object RDDRelation {
val parquetFile = sqlContext.parquetFile("pair.parquet")
// Queries can be run using the DSL on parquet files just like the original RDD.
- parquetFile.where('key === 1).select('value as 'a).collect().foreach(println)
+ parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println)
// These files can also be registered as tables.
parquetFile.registerTempTable("parquetFile")
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
index 897c7ee12a436..f1550ac2e18ad 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl
import scala.reflect.{classTag, ClassTag}
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
* partitioner that allows co-partitioning with `partitionsRDD`.
*/
override val partitioner =
- partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+ partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size)))
override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
index 8a13c74221546..2d6a825b61726 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -133,6 +133,12 @@ object GraphGenerators {
// This ensures that the 4 quadrants are the same size at all recursion levels
val numVertices = math.round(
math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt
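+ // For numVertices = 2^k the bound below is 2^(2(k - 1)), i.e. (numVertices / 2)^2.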
+ val numEdgesUpperBound =
+ math.pow(2.0, 2 * ((math.log(numVertices) / math.log(2.0)) - 1)).toInt
+ if (numEdgesUpperBound < numEdges) {
+ throw new IllegalArgumentException(
+ s"numEdges must be <= $numEdgesUpperBound but was $numEdges")
+ }
var edges: Set[Edge[Int]] = Set()
while (edges.size < numEdges) {
if (edges.size % 100 == 0) {
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 9da0064104fb6..ed9876b8dc21c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
+ test("non-default number of edge partitions") {
+ val n = 10
+ val defaultParallelism = 3
+ val numEdgePartitions = 4
+ assert(defaultParallelism != numEdgePartitions)
+ val conf = new org.apache.spark.SparkConf()
+ .set("spark.default.parallelism", defaultParallelism.toString)
+ val sc = new SparkContext("local", "test", conf)
+ try {
+ val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)),
+ numEdgePartitions)
+ val graph = Graph.fromEdgeTuples(edges, 1)
+ val neighborAttrSums = graph.mapReduceTriplets[Int](
+ et => Iterator((et.dstId, et.srcAttr)), _ + _)
+ assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n)))
+ } finally {
+ sc.stop()
+ }
+ }
+
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
index 3abefbe52fa8a..8d9c8ddccbb3c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala
@@ -110,4 +110,14 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext {
}
}
+ test("SPARK-5064 GraphGenerators.rmatGraph numEdges upper bound") {
+ withSpark { sc =>
+ val g1 = GraphGenerators.rmatGraph(sc, 4, 4)
+ assert(g1.edges.count() === 4)
+ intercept[IllegalArgumentException] {
+ val g2 = GraphGenerators.rmatGraph(sc, 4, 8)
+ }
+ }
+ }
+
}
diff --git a/make-distribution.sh b/make-distribution.sh
index 4e2f400be3053..0adca7851819b 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -115,7 +115,7 @@ if which git &>/dev/null; then
unset GITREV
fi
-if ! which $MVN &>/dev/null; then
+if ! which "$MVN" &>/dev/null; then
echo -e "Could not locate Maven command: '$MVN'."
echo -e "Specify the Maven command with the --mvn flag"
exit -1;
@@ -171,13 +171,16 @@ cd "$SPARK_HOME"
export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
-BUILD_COMMAND="$MVN clean package -DskipTests $@"
+# Store the command as an array because $MVN variable might have spaces in it.
+# Normal quoting tricks don't work.
+# See: http://mywiki.wooledge.org/BashFAQ/050
+BUILD_COMMAND=("$MVN" clean package -DskipTests $@)
# Actually build the jar
echo -e "\nBuilding with..."
-echo -e "\$ $BUILD_COMMAND\n"
+echo -e "\$ ${BUILD_COMMAND[@]}\n"
-${BUILD_COMMAND}
+"${BUILD_COMMAND[@]}"
# Make directories
rm -rf "$DISTDIR"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index fdbee743e8177..bc3defe968afd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -18,12 +18,10 @@
package org.apache.spark.ml
import scala.annotation.varargs
-import scala.collection.JavaConverters._
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -40,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @return fitted model
*/
@varargs
- def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
val map = new ParamMap().put(paramPairs: _*)
fit(dataset, map)
}
@@ -52,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMap parameter map
* @return fitted model
*/
- def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+ def fit(dataset: DataFrame, paramMap: ParamMap): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -63,43 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMaps an array of parameter maps
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
-
- // Java-friendly versions of fit.
-
- /**
- * Fits a single model to the input data with optional parameters.
- *
- * @param dataset input dataset
- * @param paramPairs optional list of param pairs (overwrite embedded params)
- * @return fitted model
- */
- @varargs
- def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
- fit(dataset.schemaRDD, paramPairs: _*)
- }
-
- /**
- * Fits a single model to the input data with provided parameter map.
- *
- * @param dataset input dataset
- * @param paramMap parameter map
- * @return fitted model
- */
- def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
- fit(dataset.schemaRDD, paramMap)
- }
-
- /**
- * Fits multiple models to the input data with multiple sets of parameters.
- *
- * @param dataset input dataset
- * @param paramMaps an array of parameter maps
- * @return fitted models, matching the input parameter maps
- */
- def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
- fit(dataset.schemaRDD, paramMaps).asJava
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
index db563dd550e56..d2ca2e6871e6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+ def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index ad6fed178fae9..fe39cd1bc0bd2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] {
* @param paramMap parameter map
* @return fitted pipeline
*/
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val theStages = map(stages)
@@ -162,7 +162,7 @@ class PipelineModel private[ml] (
}
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
transformSchema(dataset.schema, map, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1331b9124045c..b233bff08305c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,10 +22,9 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/**
@@ -42,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params {
* @return transformed dataset
*/
@varargs
- def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
paramPairs.foreach(map.put(_))
transform(dataset, map)
@@ -54,30 +53,7 @@ abstract class Transformer extends PipelineStage with Params {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
-
- // Java-friendly versions of transform.
-
- /**
- * Transforms the dataset with optional parameters.
- * @param dataset input datset
- * @param paramPairs optional list of param pairs, overwrite embedded params
- * @return transformed dataset
- */
- @varargs
- def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
- transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
- }
-
- /**
- * Transforms the dataset with provided parameter map as additional parameters.
- * @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
- * @return transformed dataset
- */
- def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
- transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
- }
+ def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
}
/**
@@ -119,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
StructType(outputFields)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
- dataset.select(Star(None), udf as map(outputCol))
+ dataset.select($"*", callUDF(
+ this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8c570812f8316..eeb6301c3f64a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ val instances = dataset.select(map(labelCol), map(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.persist(StorageLevel.MEMORY_AND_DISK)
@@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] (
validateAndTransformSchema(schema, paramMap, fitting = false)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val score: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
@@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] (
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
- dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
- .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
+ .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 12473cb2b5719..1979ab9eb6516 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.sql.{Row, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
@@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
require(labelType == DoubleType,
s"Label column ${map(labelCol)} must be double type but found $labelType")
- import dataset.sqlContext._
- val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
.map { case Row(score: Double, label: Double) =>
(score, label)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 72825f6e02182..e7bdb070c8193 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val input = dataset.select(map(inputCol).attr)
- .map { case Row(v: Vector) =>
- v
- }
+ val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(this, map, scaler)
Params.inheritValues(map, this, model)
@@ -83,14 +79,13 @@ class StandardScalerModel private[ml] (
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
- dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
new file mode 100644
index 0000000000000..f6437c7fbc8ed
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -0,0 +1,970 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.{util => ju}
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.netlib.util.intW
+
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Common params for ALS.
+ */
+private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+ with HasPredictionCol {
+
+ /** Param for rank of the matrix factorization. */
+ val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+ def getRank: Int = get(rank)
+
+ /** Param for number of user blocks. */
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+ def getNumUserBlocks: Int = get(numUserBlocks)
+
+ /** Param for number of item blocks. */
+ val numItemBlocks =
+ new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+ def getNumItemBlocks: Int = get(numItemBlocks)
+
+ /** Param to decide whether to use implicit preference. */
+ val implicitPrefs =
+ new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+ def getImplicitPrefs: Boolean = get(implicitPrefs)
+
+ /** Param for the alpha parameter in the implicit preference formulation. */
+ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+ def getAlpha: Double = get(alpha)
+
+ /** Param for the column name for user ids. */
+ val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+ def getUserCol: String = get(userCol)
+
+ /** Param for the column name for item ids. */
+ val itemCol =
+ new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+ def getItemCol: String = get(itemCol)
+
+ /** Param for the column name for ratings. */
+ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+ def getRatingCol: String = get(ratingCol)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @param paramMap extra params
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ assert(schema(map(userCol)).dataType == IntegerType)
+ assert(schema(map(itemCol)).dataType == IntegerType)
+ val ratingType = schema(map(ratingCol)).dataType
+ assert(ratingType == FloatType || ratingType == DoubleType)
+ val predictionColName = map(predictionCol)
+ assert(!schema.fieldNames.contains(predictionColName),
+ s"Prediction column $predictionColName already exists.")
+ val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false)
+ StructType(newFields)
+ }
+}
+
+/**
+ * Model fitted by ALS.
+ */
+class ALSModel private[ml] (
+ override val parent: ALS,
+ override val fittingParamMap: ParamMap,
+ k: Int,
+ userFactors: RDD[(Int, Array[Float])],
+ itemFactors: RDD[(Int, Array[Float])])
+ extends Model[ALSModel] with ALSParams {
+
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ import dataset.sqlContext._
+ import org.apache.spark.ml.recommendation.ALSModel.Factor
+ val map = this.paramMap ++ paramMap
+ // TODO: Add DSL to simplify the code here.
+ val instanceTable = s"instance_$uid"
+ val userTable = s"user_$uid"
+ val itemTable = s"item_$uid"
+ val instances = dataset.as(instanceTable)
+ val users = userFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(userTable)
+ val items = itemFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(itemTable)
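+ // Predicted rating is the dot product of the user and item factors; it is NaN when either
+ // side is missing, since the left outer joins below yield null features for unseen ids.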
+ val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
+ if (userFeatures != null && itemFeatures != null) {
+ blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+ val inputColumns = dataset.schema.fieldNames
+ val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+ .as(map(predictionCol))
+ val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
+ instances
+ .join(users, Column(map(userCol)) === $"$userTable.id", "left")
+ .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+ .select(outputColumns: _*)
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+private object ALSModel {
+ /** Case class to convert factors to DataFrames */
+ private case class Factor(id: Int, features: Seq[Float])
+}
+
+/**
+ * Alternating Least Squares (ALS) matrix factorization.
+ *
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
+ * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
+ * of factors (referred to as "users" and "products") into blocks and reduces communication by only
+ * sending one copy of each user vector to each product block on each iteration, and only for the
+ * product blocks that need that user's feature vector. This is achieved by pre-computing some
+ * information about the ratings matrix to determine the "out-links" of each user (which blocks of
+ * products it will contribute to) and "in-link" information for each product (which of the feature
+ * vectors it receives from each user block it will depend on). This allows us to send only an
+ * array of feature vectors between each user block and product block, and have the product block
+ * find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
+ * indicated user preferences rather than explicit ratings given to items.
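+ *
+ * A minimal usage sketch (illustrative only), assuming a DataFrame `ratings` with integer
+ * "user" and "item" columns and a float or double "rating" column (the default column names):
+ * {{{
+ *   val als = new ALS().setRank(10).setMaxIter(10).setRegParam(0.1)
+ *   val model = als.fit(ratings)
+ *   val predictions = model.transform(ratings)
+ * }}}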
+ */
+class ALS extends Estimator[ALSModel] with ALSParams {
+
+ import org.apache.spark.ml.recommendation.ALS.Rating
+
+ def setRank(value: Int): this.type = set(rank, value)
+ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
+ def setAlpha(value: Double): this.type = set(alpha, value)
+ def setUserCol(value: String): this.type = set(userCol, value)
+ def setItemCol(value: String): this.type = set(itemCol, value)
+ def setRatingCol(value: String): this.type = set(ratingCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** Sets both numUserBlocks and numItemBlocks to the specific value. */
+ def setNumBlocks(value: Int): this.type = {
+ setNumUserBlocks(value)
+ setNumItemBlocks(value)
+ this
+ }
+
+ setMaxIter(20)
+ setRegParam(1.0)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
+ val map = this.paramMap ++ paramMap
+ val ratings = dataset
+ .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+ .map { row =>
+ new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
+ numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
+ maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
+ alpha = map(alpha))
+ val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+private[recommendation] object ALS extends Logging {
+
+ /** Rating class for better code readability. */
+ private[recommendation] case class Rating(user: Int, item: Int, rating: Float)
+
+ /** Cholesky solver for least square problems. */
+ private[recommendation] class CholeskySolver {
+
+ private val upper = "U"
+ private val info = new intW(0)
+
+ /**
+ * Solves a least squares problem with L2 regularization:
+ *
+ * min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+ *
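+ * This is solved via the normal equations: the scaled regularization term lambda * n is added
+ * to the diagonal of A^T^ A and the system (A^T^ A + lambda * n * I) x = A^T^ b is solved by
+ * LAPACK's dppsv on the packed upper triangle stored in `ne`.
+ *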
+ * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
+ * @param lambda regularization constant, which will be scaled by n
+ * @return the solution x
+ */
+ def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ val k = ne.k
+ // Add scaled lambda to the diagonals of AtA.
+ val scaledlambda = lambda * ne.n
+ var i = 0
+ var j = 2
+ while (i < ne.triK) {
+ ne.ata(i) += scaledlambda
+ i += j
+ j += 1
+ }
+ lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info)
+ val code = info.`val`
+ assert(code == 0, s"lapack.dppsv returned $code.")
+ val x = new Array[Float](k)
+ i = 0
+ while (i < k) {
+ x(i) = ne.atb(i).toFloat
+ i += 1
+ }
+ ne.reset()
+ x
+ }
+ }
+
+ /** Representing a normal equation (ALS' subproblem). */
+ private[recommendation] class NormalEquation(val k: Int) extends Serializable {
+
+ /** Number of entries in the upper triangular part of a k-by-k matrix. */
+ val triK = k * (k + 1) / 2
+ /** A^T^ * A */
+ val ata = new Array[Double](triK)
+ /** A^T^ * b */
+ val atb = new Array[Double](k)
+ /** Number of observations. */
+ var n = 0
+
+ private val da = new Array[Double](k)
+ private val upper = "U"
+
+ private def copyToDouble(a: Array[Float]): Unit = {
+ var i = 0
+ while (i < k) {
+ da(i) = a(i)
+ i += 1
+ }
+ }
+
+ /** Adds an observation. */
+ def add(a: Array[Float], b: Float): this.type = {
+ require(a.size == k)
+ copyToDouble(a)
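+ // Rank-1 update of the packed upper triangle, ata += da * da^T, followed by atb += b * da.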
+ blas.dspr(upper, k, 1.0, da, 1, ata)
+ blas.daxpy(k, b.toDouble, da, 1, atb, 1)
+ n += 1
+ this
+ }
+
+ /**
+ * Adds an observation with implicit feedback. Note that this does not increment the counter.
+ */
+ def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+ require(a.size == k)
+ // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
+ // so that it is never negative.
+ val confidence = 1.0 + alpha * math.abs(b)
+ copyToDouble(a)
+ blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
+ // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
+ if (b > 0) {
+ blas.daxpy(k, confidence, da, 1, atb, 1)
+ }
+ this
+ }
+
+ /** Merges another normal equation object. */
+ def merge(other: NormalEquation): this.type = {
+ require(other.k == k)
+ blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
+ blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+ n += other.n
+ this
+ }
+
+ /** Resets everything to zero, which should be called after each solve. */
+ def reset(): Unit = {
+ ju.Arrays.fill(ata, 0.0)
+ ju.Arrays.fill(atb, 0.0)
+ n = 0
+ }
+ }
+
+ /**
+ * Implementation of the ALS algorithm.
+ */
+ private def train(
+ ratings: RDD[Rating],
+ rank: Int = 10,
+ numUserBlocks: Int = 10,
+ numItemBlocks: Int = 10,
+ maxIter: Int = 10,
+ regParam: Double = 1.0,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = {
+ val userPart = new HashPartitioner(numUserBlocks)
+ val itemPart = new HashPartitioner(numItemBlocks)
+ val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
+ val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
+ val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
+ val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+ // materialize blockRatings and user blocks
+ userOutBlocks.count()
+ val swappedBlockRatings = blockRatings.map {
+ case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
+ ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
+ }
+ val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+ // materialize item blocks
+ itemOutBlocks.count()
+ var userFactors = initialize(userInBlocks, rank)
+ var itemFactors = initialize(itemInBlocks, rank)
+ if (implicitPrefs) {
+ for (iter <- 1 to maxIter) {
+ userFactors.setName(s"userFactors-$iter").persist()
+ val previousItemFactors = itemFactors
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder, implicitPrefs, alpha)
+ previousItemFactors.unpersist()
+ itemFactors.setName(s"itemFactors-$iter").persist()
+ val previousUserFactors = userFactors
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder, implicitPrefs, alpha)
+ previousUserFactors.unpersist()
+ }
+ } else {
+ for (iter <- 0 until maxIter) {
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder)
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder)
+ }
+ }
+ val userIdAndFactors = userInBlocks
+ .mapValues(_.srcIds)
+ .join(userFactors)
+ .values
+ .setName("userFactors")
+ .cache()
+ userIdAndFactors.count()
+ itemFactors.unpersist()
+ val itemIdAndFactors = itemInBlocks
+ .mapValues(_.srcIds)
+ .join(itemFactors)
+ .values
+ .setName("itemFactors")
+ .cache()
+ itemIdAndFactors.count()
+ userInBlocks.unpersist()
+ userOutBlocks.unpersist()
+ itemInBlocks.unpersist()
+ itemOutBlocks.unpersist()
+ blockRatings.unpersist()
+ val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ (userOutput, itemOutput)
+ }
+
+ /**
+ * Factor block that stores factors (Array[Float]) in an Array.
+ */
+ private type FactorBlock = Array[Array[Float]]
+
+ /**
+ * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to
+ * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the
+ * src factors in this block to send to dst block 0.
+ */
+ private type OutBlock = Array[Array[Int]]
+
+ /**
+ * In-link block for computing src (user/item) factors. This includes the original src IDs
+ * of the elements within this block as well as encoded dst (item/user) indices and corresponding
+ * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original
+ * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices.
+ * For example, if we have an in-link record
+ *
+ * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0},
+ *
+ * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which
+ * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3).
+ *
+ * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can
+ * compute src factors one after another using only one normal equation instance.
+ *
+ * @param srcIds src ids (ordered)
+ * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
+ * ratings are associated with srcIds(i).
+ * @param dstEncodedIndices encoded dst indices
+ * @param ratings ratings
+ *
+ * @see [[LocalIndexEncoder]]
+ */
+ private[recommendation] case class InBlock(
+ srcIds: Array[Int],
+ dstPtrs: Array[Int],
+ dstEncodedIndices: Array[Int],
+ ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = ratings.size
+
+ require(dstEncodedIndices.size == size)
+ require(dstPtrs.size == srcIds.size + 1)
+ }
+
+ /**
+ * Initializes factors randomly given the in-link blocks.
+ *
+ * @param inBlocks in-link blocks
+ * @param rank rank
+ * @return initialized factor blocks
+ */
+ private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
+ // Choose a unit vector uniformly at random from the unit sphere, but from the
+ // "first quadrant" where all elements are nonnegative. This can be done by choosing
+ // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+ // This appears to create factorizations that have a slightly better reconstruction
+ // (<1%) compared to picking elements uniformly at random in [0,1].
+ inBlocks.map { case (srcBlockId, inBlock) =>
+ val random = new XORShiftRandom(srcBlockId)
+ val factors = Array.fill(inBlock.srcIds.size) {
+ val factor = Array.fill(rank)(random.nextGaussian().toFloat)
+ val nrm = blas.snrm2(rank, factor, 1)
+ blas.sscal(rank, 1.0f / nrm, factor, 1)
+ factor
+ }
+ (srcBlockId, factors)
+ }
+ }
+
+ /**
+ * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
+ */
+ private[recommendation]
+ case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = srcIds.size
+ require(dstIds.size == size)
+ require(ratings.size == size)
+ }
+
+ /**
+ * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
+ */
+ private[recommendation] class RatingBlockBuilder extends Serializable {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstIds = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+ var size = 0
+
+ /** Adds a rating. */
+ def add(r: Rating): this.type = {
+ size += 1
+ srcIds += r.user
+ dstIds += r.item
+ ratings += r.rating
+ this
+ }
+
+ /** Merges another [[RatingBlockBuilder]]. */
+ def merge(other: RatingBlock): this.type = {
+ size += other.srcIds.size
+ srcIds ++= other.srcIds
+ dstIds ++= other.dstIds
+ ratings ++= other.ratings
+ this
+ }
+
+ /** Builds a [[RatingBlock]]. */
+ def build(): RatingBlock = {
+ RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
+ }
+ }
+
+ /**
+ * Partitions raw ratings into blocks.
+ *
+ * @param ratings raw ratings
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ *
+ * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
+ */
+ private def partitionRatings(
+ ratings: RDD[Rating],
+ srcPart: Partitioner,
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
+
+ /* The implementation produces the same result as the following but generates fewer objects.
+
+ ratings.map { r =>
+ ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
+ }.aggregateByKey(new RatingBlockBuilder)(
+ seqOp = (b, r) => b.add(r),
+ combOp = (b0, b1) => b0.merge(b1.build()))
+ .mapValues(_.build())
+ */
+
+ val numPartitions = srcPart.numPartitions * dstPart.numPartitions
+ ratings.mapPartitions { iter =>
+ val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
+ iter.flatMap { r =>
+ val srcBlockId = srcPart.getPartition(r.user)
+ val dstBlockId = dstPart.getPartition(r.item)
+ val idx = srcBlockId + srcPart.numPartitions * dstBlockId
+ val builder = builders(idx)
+ builder.add(r)
+ if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k
+ builders(idx) = new RatingBlockBuilder
+ Iterator.single(((srcBlockId, dstBlockId), builder.build()))
+ } else {
+ Iterator.empty
+ }
+ } ++ {
+ builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+ val srcBlockId = idx % srcPart.numPartitions
+ val dstBlockId = idx / srcPart.numPartitions
+ ((srcBlockId, dstBlockId), block.build())
+ }
+ }
+ }.groupByKey().mapValues { blocks =>
+ val builder = new RatingBlockBuilder
+ blocks.foreach(builder.merge)
+ builder.build()
+ }.setName("ratingBlocks")
+ }
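The flat builder index above packs the (srcBlockId, dstBlockId) pair in column-major order, which is what the final zipWithIndex pass inverts. A tiny standalone sketch of the round trip, with illustrative values only:

    // Pack and unpack the flat builder index, assuming 3 src partitions.
    val numSrcBlocks = 3
    val (srcBlockId, dstBlockId) = (2, 1)
    val idx = srcBlockId + numSrcBlocks * dstBlockId  // 2 + 3 * 1 = 5
    assert(idx % numSrcBlocks == srcBlockId)          // recover src block id
    assert(idx / numSrcBlocks == dstBlockId)          // recover dst block id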
+
+ /**
+ * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+ * @param encoder encoder for dst indices
+ */
+ private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+
+ /**
+ * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
+ *
+ * @param dstBlockId dst block ID
+ * @param srcIds original src IDs
+ * @param dstLocalIndices dst local indices
+ * @param ratings ratings
+ */
+ def add(
+ dstBlockId: Int,
+ srcIds: Array[Int],
+ dstLocalIndices: Array[Int],
+ ratings: Array[Float]): this.type = {
+ val sz = srcIds.size
+ require(dstLocalIndices.size == sz)
+ require(ratings.size == sz)
+ this.srcIds ++= srcIds
+ this.ratings ++= ratings
+ var j = 0
+ while (j < sz) {
+ this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
+ j += 1
+ }
+ this
+ }
+
+ /** Builds an [[UncompressedInBlock]]. */
+ def build(): UncompressedInBlock = {
+ new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
+ }
+ }
+
+ /**
+ * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
+ */
+ private[recommendation] class UncompressedInBlock(
+ val srcIds: Array[Int],
+ val dstEncodedIndices: Array[Int],
+ val ratings: Array[Float]) {
+
+ /** Size of the block. */
+ def size: Int = srcIds.size
+
+ /**
+ * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
+ * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
+ * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
+ */
+ def compress(): InBlock = {
+ val sz = size
+ assert(sz > 0, "Empty in-link block should not exist.")
+ sort()
+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
+ val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
+ var preSrcId = srcIds(0)
+ uniqueSrcIdsBuilder += preSrcId
+ var curCount = 1
+ var i = 1
+ var j = 0
+ while (i < sz) {
+ val srcId = srcIds(i)
+ if (srcId != preSrcId) {
+ uniqueSrcIdsBuilder += srcId
+ dstCountsBuilder += curCount
+ preSrcId = srcId
+ j += 1
+ curCount = 0
+ }
+ curCount += 1
+ i += 1
+ }
+ dstCountsBuilder += curCount
+ val uniqueSrcIds = uniqueSrcIdsBuilder.result()
+ val numUniqueSrcIds = uniqueSrcIds.size
+ val dstCounts = dstCountsBuilder.result()
+ val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
+ var sum = 0
+ i = 0
+ while (i < numUniqueSrcIds) {
+ sum += dstCounts(i)
+ i += 1
+ dstPtrs(i) = sum
+ }
+ InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
+ }
+
+ private def sort(): Unit = {
+ val sz = size
+ // Since there might be interleaved log messages, we insert a unique id for easy pairing.
+ val sortId = Utils.random.nextInt()
+ logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
+ val start = System.nanoTime()
+ val sorter = new Sorter(new UncompressedInBlockSort)
+ sorter.sort(this, 0, size, Ordering[IntWrapper])
+ val duration = (System.nanoTime() - start) / 1e9
+ logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
+ }
+ }
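As a worked illustration of compress() on hypothetical data: after sorting by src id, run lengths of equal ids become dstPtrs via a prefix sum, exactly as in a COO-to-CSC conversion.

    // Hypothetical sorted src ids and the resulting CSC-style pointers.
    val sortedSrcIds = Array(1, 1, 1, 4, 7, 7)
    val uniqueSrcIds = sortedSrcIds.distinct                          // Array(1, 4, 7)
    val counts = uniqueSrcIds.map(id => sortedSrcIds.count(_ == id))  // Array(3, 1, 2)
    val dstPtrs = counts.scanLeft(0)(_ + _)                           // Array(0, 3, 4, 6)
    // Entries of src id 7 occupy positions [dstPtrs(2), dstPtrs(3)) = [4, 6).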
+
+ /**
+ * A wrapper that holds a primitive integer key.
+ *
+ * @see [[UncompressedInBlockSort]]
+ */
+ private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
+ override def compare(that: IntWrapper): Int = {
+ key.compare(that.key)
+ }
+ }
+
+ /**
+ * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
+ */
+ private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] {
+
+ override def newKey(): IntWrapper = new IntWrapper()
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int,
+ reuse: IntWrapper): IntWrapper = {
+ if (reuse == null) {
+ new IntWrapper(data.srcIds(pos))
+ } else {
+ reuse.key = data.srcIds(pos)
+ reuse
+ }
+ }
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int): IntWrapper = {
+ getKey(data, pos, null)
+ }
+
+ private def swapElements[@specialized(Int, Float) T](
+ data: Array[T],
+ pos0: Int,
+ pos1: Int): Unit = {
+ val tmp = data(pos0)
+ data(pos0) = data(pos1)
+ data(pos1) = tmp
+ }
+
+ override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = {
+ swapElements(data.srcIds, pos0, pos1)
+ swapElements(data.dstEncodedIndices, pos0, pos1)
+ swapElements(data.ratings, pos0, pos1)
+ }
+
+ override def copyRange(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int,
+ length: Int): Unit = {
+ System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
+ System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
+ System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
+ }
+
+ override def allocate(length: Int): UncompressedInBlock = {
+ new UncompressedInBlock(
+ new Array[Int](length), new Array[Int](length), new Array[Float](length))
+ }
+
+ override def copyElement(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int): Unit = {
+ dst.srcIds(dstPos) = src.srcIds(srcPos)
+ dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
+ dst.ratings(dstPos) = src.ratings(srcPos)
+ }
+ }
+
+ /**
+ * Creates in-blocks and out-blocks from rating blocks.
+ * @param prefix prefix for in/out-block names
+ * @param ratingBlocks rating blocks
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ * @return (in-blocks, out-blocks)
+ */
+ private def makeBlocks(
+ prefix: String,
+ ratingBlocks: RDD[((Int, Int), RatingBlock)],
+ srcPart: Partitioner,
+ dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = {
+ val inBlocks = ratingBlocks.map {
+ case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
+ // The implementation is a faster version of
+ // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
+ val start = System.nanoTime()
+ val dstIdSet = new OpenHashSet[Int](1 << 20)
+ dstIds.foreach(dstIdSet.add)
+ val sortedDstIds = new Array[Int](dstIdSet.size)
+ var i = 0
+ var pos = dstIdSet.nextPos(0)
+ while (pos != -1) {
+ sortedDstIds(i) = dstIdSet.getValue(pos)
+ pos = dstIdSet.nextPos(pos + 1)
+ i += 1
+ }
+ assert(i == dstIdSet.size)
+ ju.Arrays.sort(sortedDstIds)
+ val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
+ i = 0
+ while (i < sortedDstIds.size) {
+ dstIdToLocalIndex.update(sortedDstIds(i), i)
+ i += 1
+ }
+ logDebug(
+ "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
+ val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
+ (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
+ }.groupByKey(new HashPartitioner(srcPart.numPartitions))
+ .mapValues { iter =>
+ val builder =
+ new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions))
+ iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
+ builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
+ }
+ builder.build().compress()
+ }.setName(prefix + "InBlocks").cache()
+ val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
+ val encoder = new LocalIndexEncoder(dstPart.numPartitions)
+ val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
+ var i = 0
+ val seen = new Array[Boolean](dstPart.numPartitions)
+ while (i < srcIds.size) {
+ var j = dstPtrs(i)
+ ju.Arrays.fill(seen, false)
+ while (j < dstPtrs(i + 1)) {
+ val dstBlockId = encoder.blockId(dstEncodedIndices(j))
+ if (!seen(dstBlockId)) {
+ activeIds(dstBlockId) += i // add the local index in this out-block
+ seen(dstBlockId) = true
+ }
+ j += 1
+ }
+ i += 1
+ }
+ activeIds.map { x =>
+ x.result()
+ }
+ }.setName(prefix + "OutBlocks").cache()
+ (inBlocks, outBlocks)
+ }
+
+ /**
+ * Computes dst factors by constructing and solving least squares problems.
+ *
+ * @param srcFactorBlocks src factors
+ * @param srcOutBlocks src out-blocks
+ * @param dstInBlocks dst in-blocks
+ * @param rank rank
+ * @param regParam regularization constant
+ * @param srcEncoder encoder for src local indices
+ * @param implicitPrefs whether to use implicit preference
+ * @param alpha the alpha constant in the implicit preference formulation
+ *
+ * @return dst factors
+ */
+ private def computeFactors(
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock)],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: LocalIndexEncoder,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
+ val numSrcBlocks = srcFactorBlocks.partitions.size
+ val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
+ val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
+ case (srcBlockId, (srcOutBlock, srcFactors)) =>
+ srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+ (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
+ }
+ }
+ val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+ dstInBlocks.join(merged).mapValues {
+ case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
+ val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
+ srcFactors.foreach { case (srcBlockId, factors) =>
+ sortedSrcFactors(srcBlockId) = factors
+ }
+ val dstFactors = new Array[Array[Float]](dstIds.size)
+ var j = 0
+ val ls = new NormalEquation(rank)
+ val solver = new CholeskySolver // TODO: add NNLS solver
+ while (j < dstIds.size) {
+ ls.reset()
+ if (implicitPrefs) {
+ ls.merge(YtY.get)
+ }
+ var i = srcPtrs(j)
+ while (i < srcPtrs(j + 1)) {
+ val encoded = srcEncodedIndices(i)
+ val blockId = srcEncoder.blockId(encoded)
+ val localIndex = srcEncoder.localIndex(encoded)
+ val srcFactor = sortedSrcFactors(blockId)(localIndex)
+ val rating = ratings(i)
+ if (implicitPrefs) {
+ ls.addImplicit(srcFactor, rating, alpha)
+ } else {
+ ls.add(srcFactor, rating)
+ }
+ i += 1
+ }
+ dstFactors(j) = solver.solve(ls, regParam)
+ j += 1
+ }
+ dstFactors
+ }
+ }
+
+ /**
+ * Computes the Gramian matrix of user or item factors, which is only used in implicit preference.
+ * Caching of the input factors is handled in [[ALS#train]].
+ */
+ private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
+ factorBlocks.values.aggregate(new NormalEquation(rank))(
+ seqOp = (ne, factors) => {
+ factors.foreach(ne.add(_, 0.0f))
+ ne
+ },
+ combOp = (ne1, ne2) => ne1.merge(ne2))
+ }
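For reference, a hedged sketch of the system each solve above assembles (the exact handling of the regularization constant lives in NormalEquation and CholeskySolver, which are outside this hunk): for a dst factor x with observed src factors a_i and ratings r_i, the accumulated normal equation is roughly

    \Bigl(\sum_{i} a_i a_i^\top + \lambda I\Bigr)\, x \;=\; \sum_{i} r_i\, a_i

For implicit feedback, the precomputed YtY Gramian is merged in first, so each subproblem only adds the confidence-weighted corrections for its observed entries.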
+
+ /**
+ * Encoder for storing (blockId, localIndex) into a single integer.
+ *
+ * We use the leading bits (including the sign bit) to store the block id and the rest to store
+ * the local index. This is based on the assumption that users/items are approximately evenly
+ * partitioned. With this assumption, we should be able to encode two billion distinct values.
+ *
+ * @param numBlocks number of blocks
+ */
+ private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable {
+
+ require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")
+
+ private[this] final val numLocalIndexBits =
+ math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
+ private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1
+
+ /** Encodes a (blockId, localIndex) into a single integer. */
+ def encode(blockId: Int, localIndex: Int): Int = {
+ require(blockId < numBlocks)
+ require((localIndex & ~localIndexMask) == 0)
+ (blockId << numLocalIndexBits) | localIndex
+ }
+
+ /** Gets the block id from an encoded index. */
+ @inline
+ def blockId(encoded: Int): Int = {
+ encoded >>> numLocalIndexBits
+ }
+
+ /** Gets the local index from an encoded index. */
+ @inline
+ def localIndex(encoded: Int): Int = {
+ encoded & localIndexMask
+ }
+ }
+}
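A small round-trip sketch of the encoder (illustration only; the class is package-private and lives inside the enclosing ALS object): with 10 blocks, the leading 4 bits hold the block id and the remaining 28 bits hold the local index.

    val encoder = new LocalIndexEncoder(10)
    val encoded = encoder.encode(blockId = 3, localIndex = 42)
    assert(encoder.blockId(encoded) == 3)
    assert(encoder.localIndex(encoded) == 42)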
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 08fe99176424a..5d51c51346665 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
transformSchema(dataset.schema, paramMap, logging = true)
@@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val epm = map(estimatorParamMaps)
val numModels = epm.size
val metrics = new Array[Double](epm.size)
- val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.applySchema(training, schema).cache()
val validationDataset = sqlCtx.applySchema(validation, schema).cache()
@@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
bestModel.transform(dataset, paramMap)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 555da8c7e7ab3..430d763ef7ca7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -266,12 +266,16 @@ class PythonMLLibAPI extends Serializable {
k: Int,
maxIterations: Int,
runs: Int,
- initializationMode: String): KMeansModel = {
+ initializationMode: String,
+ seed: java.lang.Long): KMeansModel = {
val kMeansAlg = new KMeans()
.setK(k)
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
+
+ if (seed != null) kMeansAlg.setSeed(seed)
+
try {
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
} finally {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index d8e134619411b..899fe5e9e9cf2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -134,9 +134,7 @@ class GaussianMixtureEM private (
// diagonal covariance matrices using component variances
// derived from the samples
val (weights, gaussians) = initialModel match {
- case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
- new MultivariateGaussian(mu, sigma)
- })
+ case Some(gmm) => (gmm.weights, gmm.gaussians)
case None => {
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
@@ -176,10 +174,7 @@ class GaussianMixtureEM private (
iter += 1
}
- // Need to convert the breeze matrices to MLlib matrices
- val means = Array.tabulate(k) { i => gaussians(i).mu }
- val sigmas = Array.tabulate(k) { i => gaussians(i).sigma }
- new GaussianMixtureModel(weights, means, sigmas)
+ new GaussianMixtureModel(weights, gaussians)
}
/** Average of dense breeze vectors */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 416cad080c408..1a2178ee7f711 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -36,12 +36,13 @@ import org.apache.spark.mllib.util.MLUtils
* covariance matrix for Gaussian i
*/
class GaussianMixtureModel(
- val weight: Array[Double],
- val mu: Array[Vector],
- val sigma: Array[Matrix]) extends Serializable {
+ val weights: Array[Double],
+ val gaussians: Array[MultivariateGaussian]) extends Serializable {
+
+ require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
/** Number of gaussians in mixture */
- def k: Int = weight.length
+ def k: Int = weights.length
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
@@ -55,14 +56,10 @@ class GaussianMixtureModel(
*/
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
- val dists = sc.broadcast {
- (0 until k).map { i =>
- new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
- }.toArray
- }
- val weights = sc.broadcast(weight)
+ val bcDists = sc.broadcast(gaussians)
+ val bcWeights = sc.broadcast(weights)
points.map { x =>
- computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+ computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k)
}
}
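A usage sketch of the reworked constructor (hypothetical parameters; assumes the public MultivariateGaussian(mu: Vector, sigma: Matrix) constructor): the model now carries one distribution object per component instead of parallel mu/sigma arrays.

    import org.apache.spark.mllib.clustering.GaussianMixtureModel
    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

    val gmm = new GaussianMixtureModel(
      weights = Array(0.4, 0.6),
      gaussians = Array(
        new MultivariateGaussian(Vectors.dense(0.0, 0.0), Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))),
        new MultivariateGaussian(Vectors.dense(5.0, 5.0), Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)))))
    assert(gmm.k == 2)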
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 54c301d3e9e14..11633e8242313 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -43,13 +43,14 @@ class KMeans private (
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
- private var epsilon: Double) extends Serializable with Logging {
+ private var epsilon: Double,
+ private var seed: Long) extends Serializable with Logging {
/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
- * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}.
+ * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
*/
- def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
+ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
/** Set the number of clusters to create (k). Default: 2. */
def setK(k: Int): this.type = {
@@ -112,6 +113,12 @@ class KMeans private (
this
}
+ /** Set the random seed for cluster initialization. */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
@@ -255,7 +262,7 @@ class KMeans private (
private def initRandom(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom(this.seed).nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
}.toArray)
@@ -272,45 +279,81 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
- // Initialize each run's center to a random point
- val seed = new XORShiftRandom().nextInt()
+ // Initialize empty centers and point costs.
+ val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+ var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+ // Initialize each run's first center to a random point.
+ val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
- val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+ val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+ /** Merges new centers to centers. */
+ def mergeNewCenters(): Unit = {
+ var r = 0
+ while (r < runs) {
+ centers(r) ++= newCenters(r)
+ newCenters(r).clear()
+ r += 1
+ }
+ }
// On each step, sample 2 * k points on average for each run with probability proportional
- // to their squared distance from that run's current centers
+ // to their squared distance from that run's centers. Note that only distances between points
+ // and new centers are computed in each iteration.
var step = 0
while (step < initializationSteps) {
- val bcCenters = data.context.broadcast(centers)
- val sumCosts = data.flatMap { point =>
- (0 until runs).map { r =>
- (r, KMeans.pointCost(bcCenters.value(r), point))
- }
- }.reduceByKey(_ + _).collectAsMap()
- val chosen = data.mapPartitionsWithIndex { (index, points) =>
+ val bcNewCenters = data.context.broadcast(newCenters)
+ val preCosts = costs
+ costs = data.zip(preCosts).map { case (point, cost) =>
+ Vectors.dense(
+ Array.tabulate(runs) { r =>
+ math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+ })
+ }.cache()
+ val sumCosts = costs
+ .aggregate(Vectors.zeros(runs))(
+ seqOp = (s, v) => {
+ // s += v
+ axpy(1.0, v, s)
+ s
+ },
+ combOp = (s0, s1) => {
+ // s0 += s1
+ axpy(1.0, s1, s0)
+ s0
+ }
+ )
+ preCosts.unpersist(blocking = false)
+ val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
- points.flatMap { p =>
- (0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
- }.map((_, p))
+ pointsWithCosts.flatMap { case (p, c) =>
+ val rs = (0 until runs).filter { r =>
+ rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
+ }
+ if (rs.length > 0) Some(p, rs) else None
}
}.collect()
- chosen.foreach { case (r, p) =>
- centers(r) += p.toDense
+ mergeNewCenters()
+ chosen.foreach { case (p, rs) =>
+ rs.foreach(newCenters(_) += p.toDense)
}
step += 1
}
+ mergeNewCenters()
+ costs.unpersist(blocking = false)
+
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
- (0 until runs).map { r =>
+ Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
- val finalCenters = (0 until runs).map { r =>
+ val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
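The key change in this block is the bookkeeping: per-point costs are kept as a running minimum, so each step only measures distances to the centers added in the previous step, and a point is retained with probability proportional to its current cost (about 2 * k picks per run in expectation). A minimal sketch of that rule, equivalent in effect to the check above:

    // Running-minimum cost update: only distances to the *new* centers are computed.
    def updatedCost(oldCost: Double, distToNewCenters: Double): Double =
      math.min(oldCost, distToNewCenters)

    // Selection probability for a point in one run (capping at 1.0 changes nothing,
    // since any value >= 1.0 always passes the rand.nextDouble() test).
    def keepProbability(cost: Double, sumCosts: Double, k: Int): Double =
      math.min(1.0, 2.0 * k * cost / sumCosts)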
@@ -333,7 +376,32 @@ object KMeans {
/**
* Trains a k-means model using the given set of parameters.
*
- * @param data training points stored as `RDD[Array[Double]]`
+ * @param data training points stored as `RDD[Vector]`
+ * @param k number of clusters
+ * @param maxIterations max number of iterations
+ * @param runs number of parallel runs, defaults to 1. The best model is returned.
+ * @param initializationMode initialization mode, either "random" or "k-means||" (default).
+ * @param seed random seed value for cluster initialization
+ */
+ def train(
+ data: RDD[Vector],
+ k: Int,
+ maxIterations: Int,
+ runs: Int,
+ initializationMode: String,
+ seed: Long): KMeansModel = {
+ new KMeans().setK(k)
+ .setMaxIterations(maxIterations)
+ .setRuns(runs)
+ .setInitializationMode(initializationMode)
+ .setSeed(seed)
+ .run(data)
+ }
+
+ /**
+ * Trains a k-means model using the given set of parameters.
+ *
+ * @param data training points stored as `RDD[Vector]`
* @param k number of clusters
* @param maxIterations max number of iterations
* @param runs number of parallel runs, defaults to 1. The best model is returned.
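A usage sketch of the seeded entry points added here (assumes an existing SparkContext named sc): the same seed yields the same initialization, whether passed through the builder or the static train overload.

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))

    val model1 = new KMeans().setK(2).setMaxIterations(10).setSeed(42L).run(data)
    val model2 = KMeans.train(data, k = 2, maxIterations = 10, runs = 1,
      initializationMode = "k-means||", seed = 42L)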
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 3414daccd7ca4..34e0392f1b21a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -257,80 +257,58 @@ private[spark] object BLAS extends Serializable with Logging {
/**
* C := alpha * A * B + beta * C
- * @param transA whether to use the transpose of matrix A (true), or A itself (false).
- * @param transB whether to use the transpose of matrix B (true), or B itself (false).
* @param alpha a scalar to scale the multiplication A * B.
* @param A the matrix A that will be left multiplied to B. Size of m x k.
* @param B the matrix B that will be left multiplied by A. Size of k x n.
* @param beta a scalar that can be used to scale matrix C.
- * @param C the resulting matrix C. Size of m x n.
+ * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false.
*/
def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: Matrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
+ require(!C.isTransposed,
+ "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
if (alpha == 0.0) {
logDebug("gemm: alpha is equal to 0. Returning C.")
} else {
A match {
case sparse: SparseMatrix =>
- gemm(transA, transB, alpha, sparse, B, beta, C)
+ gemm(alpha, sparse, B, beta, C)
case dense: DenseMatrix =>
- gemm(transA, transB, alpha, dense, B, beta, C)
+ gemm(alpha, dense, B, beta, C)
case _ =>
throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
}
}
}
- /**
- * C := alpha * A * B + beta * C
- *
- * @param alpha a scalar to scale the multiplication A * B.
- * @param A the matrix A that will be left multiplied to B. Size of m x k.
- * @param B the matrix B that will be left multiplied by A. Size of k x n.
- * @param beta a scalar that can be used to scale matrix C.
- * @param C the resulting matrix C. Size of m x n.
- */
- def gemm(
- alpha: Double,
- A: Matrix,
- B: DenseMatrix,
- beta: Double,
- C: DenseMatrix): Unit = {
- gemm(false, false, alpha, A, B, beta, C)
- }
-
/**
* C := alpha * A * B + beta * C
* For `DenseMatrix` A.
*/
private def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: DenseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
- val mA: Int = if (!transA) A.numRows else A.numCols
- val nB: Int = if (!transB) B.numCols else B.numRows
- val kA: Int = if (!transA) A.numCols else A.numRows
- val kB: Int = if (!transB) B.numRows else B.numCols
- val tAstr = if (!transA) "N" else "T"
- val tBstr = if (!transB) "N" else "T"
-
- require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
- require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
- require(nB == C.numCols,
- s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")
-
- nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows,
- beta, C.values, C.numRows)
+ val tAstr = if (A.isTransposed) "T" else "N"
+ val tBstr = if (B.isTransposed) "T" else "N"
+ val lda = if (!A.isTransposed) A.numRows else A.numCols
+ val ldb = if (!B.isTransposed) B.numRows else B.numCols
+
+ require(A.numCols == B.numRows,
+ s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}")
+ require(A.numRows == C.numRows,
+ s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}")
+ require(B.numCols == C.numCols,
+ s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}")
+
+ nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda,
+ B.values, ldb, beta, C.values, C.numRows)
}
/**
@@ -338,17 +316,15 @@ private[spark] object BLAS extends Serializable with Logging {
* For `SparseMatrix` A.
*/
private def gemm(
- transA: Boolean,
- transB: Boolean,
alpha: Double,
A: SparseMatrix,
B: DenseMatrix,
beta: Double,
C: DenseMatrix): Unit = {
- val mA: Int = if (!transA) A.numRows else A.numCols
- val nB: Int = if (!transB) B.numCols else B.numRows
- val kA: Int = if (!transA) A.numCols else A.numRows
- val kB: Int = if (!transB) B.numRows else B.numCols
+ val mA: Int = A.numRows
+ val nB: Int = B.numCols
+ val kA: Int = A.numCols
+ val kB: Int = B.numRows
require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
@@ -358,23 +334,23 @@ private[spark] object BLAS extends Serializable with Logging {
val Avals = A.values
val Bvals = B.values
val Cvals = C.values
- val Arows = if (!transA) A.rowIndices else A.colPtrs
- val Acols = if (!transA) A.colPtrs else A.rowIndices
+ val ArowIndices = A.rowIndices
+ val AcolPtrs = A.colPtrs
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
- if (transA){
+ if (A.isTransposed){
var colCounterForB = 0
- if (!transB) { // Expensive to put the check inside the loop
+ if (!B.isTransposed) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var rowCounterForA = 0
val Cstart = colCounterForB * mA
val Bstart = colCounterForB * kA
while (rowCounterForA < mA) {
- var i = Arows(rowCounterForA)
- val indEnd = Arows(rowCounterForA + 1)
+ var i = AcolPtrs(rowCounterForA)
+ val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * Bvals(Bstart + Acols(i))
+ sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
@@ -385,19 +361,19 @@ private[spark] object BLAS extends Serializable with Logging {
}
} else {
while (colCounterForB < nB) {
- var rowCounter = 0
+ var rowCounterForA = 0
val Cstart = colCounterForB * mA
- while (rowCounter < mA) {
- var i = Arows(rowCounter)
- val indEnd = Arows(rowCounter + 1)
+ while (rowCounterForA < mA) {
+ var i = AcolPtrs(rowCounterForA)
+ val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * B(colCounterForB, Acols(i))
+ sum += Avals(i) * B(ArowIndices(i), colCounterForB)
i += 1
}
- val Cindex = Cstart + rowCounter
+ val Cindex = Cstart + rowCounterForA
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
- rowCounter += 1
+ rowCounterForA += 1
}
colCounterForB += 1
}
@@ -410,17 +386,17 @@ private[spark] object BLAS extends Serializable with Logging {
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
// B, and added to C.
var colCounterForB = 0 // the column to be updated in C
- if (!transB) { // Expensive to put the check inside the loop
+ if (!B.isTransposed) { // Expensive to put the check inside the loop
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Bstart = colCounterForB * kB
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
- var i = Acols(colCounterForA)
- val indEnd = Acols(colCounterForA + 1)
+ var i = AcolPtrs(colCounterForA)
+ val indEnd = AcolPtrs(colCounterForA + 1)
val Bval = Bvals(Bstart + colCounterForA) * alpha
while (i < indEnd) {
- Cvals(Cstart + Arows(i)) += Avals(i) * Bval
+ Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -432,11 +408,11 @@ private[spark] object BLAS extends Serializable with Logging {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA) {
- var i = Acols(colCounterForA)
- val indEnd = Acols(colCounterForA + 1)
- val Bval = B(colCounterForB, colCounterForA) * alpha
+ var i = AcolPtrs(colCounterForA)
+ val indEnd = AcolPtrs(colCounterForA + 1)
+ val Bval = B(colCounterForA, colCounterForB) * alpha
while (i < indEnd) {
- Cvals(Cstart + Arows(i)) += Avals(i) * Bval
+ Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
@@ -449,7 +425,6 @@ private[spark] object BLAS extends Serializable with Logging {
/**
* y := alpha * A * x + beta * y
- * @param trans whether to use the transpose of matrix A (true), or A itself (false).
* @param alpha a scalar to scale the multiplication A * x.
* @param A the matrix A that will be left multiplied to x. Size of m x n.
* @param x the vector x that will be left multiplied by A. Size of n x 1.
@@ -457,65 +432,43 @@ private[spark] object BLAS extends Serializable with Logging {
* @param y the resulting vector y. Size of m x 1.
*/
def gemv(
- trans: Boolean,
alpha: Double,
A: Matrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
-
- val mA: Int = if (!trans) A.numRows else A.numCols
- val nx: Int = x.size
- val nA: Int = if (!trans) A.numCols else A.numRows
-
- require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx")
- require(mA == y.size,
- s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}")
+ require(A.numCols == x.size,
+ s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
+ require(A.numRows == y.size,
+ s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}")
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
- gemv(trans, alpha, sparse, x, beta, y)
+ gemv(alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
- gemv(trans, alpha, dense, x, beta, y)
+ gemv(alpha, dense, x, beta, y)
case _ =>
throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
}
}
}
- /**
- * y := alpha * A * x + beta * y
- *
- * @param alpha a scalar to scale the multiplication A * x.
- * @param A the matrix A that will be left multiplied to x. Size of m x n.
- * @param x the vector x that will be left multiplied by A. Size of n x 1.
- * @param beta a scalar that can be used to scale vector y.
- * @param y the resulting vector y. Size of m x 1.
- */
- def gemv(
- alpha: Double,
- A: Matrix,
- x: DenseVector,
- beta: Double,
- y: DenseVector): Unit = {
- gemv(false, alpha, A, x, beta, y)
- }
-
/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A.
*/
private def gemv(
- trans: Boolean,
alpha: Double,
A: DenseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
- val tStrA = if (!trans) "N" else "T"
- nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta,
+ val tStrA = if (A.isTransposed) "T" else "N"
+ val mA = if (!A.isTransposed) A.numRows else A.numCols
+ val nA = if (!A.isTransposed) A.numCols else A.numRows
+ nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}
@@ -524,24 +477,21 @@ private[spark] object BLAS extends Serializable with Logging {
* For `SparseMatrix` A.
*/
private def gemv(
- trans: Boolean,
alpha: Double,
A: SparseMatrix,
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
-
val xValues = x.values
val yValues = y.values
-
- val mA: Int = if (!trans) A.numRows else A.numCols
- val nA: Int = if (!trans) A.numCols else A.numRows
+ val mA: Int = A.numRows
+ val nA: Int = A.numCols
val Avals = A.values
- val Arows = if (!trans) A.rowIndices else A.colPtrs
- val Acols = if (!trans) A.colPtrs else A.rowIndices
+ val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
+ val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
- if (trans) {
+ if (A.isTransposed) {
var rowCounter = 0
while (rowCounter < mA) {
var i = Arows(rowCounter)
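With the boolean transA/transB flags gone, transposition is expressed on the matrix itself. A sketch of the new calling convention (BLAS is private[spark], so this only compiles inside Spark and is illustrative only):

    // C := A^T * B without copying A's values: transpose just flips the flag.
    val A = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) // 2 x 3
    val B = DenseMatrix.zeros(2, 4)
    val C = DenseMatrix.zeros(3, 4)
    BLAS.gemm(1.0, A.transpose, B, 0.0, C)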
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 5a7281ec6dc3c..ad7e86827b368 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable {
/** Number of columns. */
def numCols: Int
+ /** Flag that keeps track of whether the matrix is transposed. False by default. */
+ val isTransposed: Boolean = false
+
/** Converts to a dense array in column major. */
- def toArray: Array[Double]
+ def toArray: Array[Double] = {
+ val newArray = new Array[Double](numRows * numCols)
+ foreachActive { (i, j, v) =>
+ newArray(j * numRows + i) = v
+ }
+ newArray
+ }
/** Converts to a breeze matrix. */
private[mllib] def toBreeze: BM[Double]
@@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable {
/** Get a deep copy of the matrix. */
def copy: Matrix
+ /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */
+ def transpose: Matrix
+
/** Convenience method for `Matrix`-`DenseMatrix` multiplication. */
def multiply(y: DenseMatrix): DenseMatrix = {
- val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix]
- BLAS.gemm(false, false, 1.0, this, y, 0.0, C)
+ val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols)
+ BLAS.gemm(1.0, this, y, 0.0, C)
C
}
@@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable {
output
}
- /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
- private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
- val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
- BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
- C
- }
-
- /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
- private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
- val output = new DenseVector(new Array[Double](numCols))
- BLAS.gemv(true, 1.0, this, y, 0.0, output)
- output
- }
-
/** A human readable representation of the matrix */
override def toString: String = toBreeze.toString()
@@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable {
* backing array. For example, an operation such as addition or subtraction will only be
* performed on the non-zero values in a `SparseMatrix`. */
private[mllib] def update(f: Double => Double): Matrix
+
+ /**
+ * Applies a function `f` to all the active elements of dense and sparse matrices. The ordering
+ * of the elements is not defined.
+ *
+ * @param f the function takes three parameters where the first two parameters are the row
+ * and column indices respectively with the type `Int`, and the final parameter is the
+ * corresponding value in the matrix with type `Double`.
+ */
+ private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}
/**
@@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable {
* @param numRows number of rows
* @param numCols number of columns
* @param values matrix entries in column major
+ * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
+ * row major.
*/
-class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix {
+class DenseMatrix(
+ val numRows: Int,
+ val numCols: Int,
+ val values: Array[Double],
+ override val isTransposed: Boolean) extends Matrix {
require(values.length == numRows * numCols, "The number of values supplied doesn't match the " +
s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}")
- override def toArray: Array[Double] = values
+ /**
+ * Column-major dense matrix.
+ * The entry values are stored in a single array of doubles with columns listed in sequence.
+ * For example, the following matrix
+ * {{{
+ * 1.0 2.0
+ * 3.0 4.0
+ * 5.0 6.0
+ * }}}
+ * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param values matrix entries in column major
+ */
+ def this(numRows: Int, numCols: Int, values: Array[Double]) =
+ this(numRows, numCols, values, false)
override def equals(o: Any) = o match {
case m: DenseMatrix =>
@@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
case _ => false
}
- private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values)
+ private[mllib] def toBreeze: BM[Double] = {
+ if (!isTransposed) {
+ new BDM[Double](numRows, numCols, values)
+ } else {
+ val breezeMatrix = new BDM[Double](numCols, numRows, values)
+ breezeMatrix.t
+ }
+ }
private[mllib] def apply(i: Int): Double = values(i)
private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j))
- private[mllib] def index(i: Int, j: Int): Int = i + numRows * j
+ private[mllib] def index(i: Int, j: Int): Int = {
+ if (!isTransposed) i + numRows * j else j + numCols * i
+ }
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
values(index(i, j)) = v
@@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
this
}
- /** Generate a `SparseMatrix` from the given `DenseMatrix`. */
+ override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed)
+
+ private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = {
+ if (!isTransposed) {
+ // outer loop over columns
+ var j = 0
+ while (j < numCols) {
+ var i = 0
+ val indStart = j * numRows
+ while (i < numRows) {
+ f(i, j, values(indStart + i))
+ i += 1
+ }
+ j += 1
+ }
+ } else {
+ // outer loop over rows
+ var i = 0
+ while (i < numRows) {
+ var j = 0
+ val indStart = i * numCols
+ while (j < numCols) {
+ f(i, j, values(indStart + j))
+ j += 1
+ }
+ i += 1
+ }
+ }
+ }
+
+ /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
+ * set to false. */
def toSparse(): SparseMatrix = {
val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble
val colPtrs: Array[Int] = new Array[Int](numCols + 1)
@@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
var j = 0
while (j < numCols) {
var i = 0
- val indStart = j * numRows
while (i < numRows) {
- val v = values(indStart + i)
+ val v = values(index(i, j))
if (v != 0.0) {
rowIndices += i
spVals += v
@@ -271,49 +340,73 @@ object DenseMatrix {
* @param rowIndices the row index of the entry. They must be in strictly increasing order for each
* column
* @param values non-zero matrix entries in column major
+ * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered
+ * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
+ * and `rowIndices` behave as colIndices, and `values` are stored in row major.
*/
class SparseMatrix(
val numRows: Int,
val numCols: Int,
val colPtrs: Array[Int],
val rowIndices: Array[Int],
- val values: Array[Double]) extends Matrix {
+ val values: Array[Double],
+ override val isTransposed: Boolean) extends Matrix {
require(values.length == rowIndices.length, "The number of row indices and values don't match! " +
s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}")
- require(colPtrs.length == numCols + 1, "The length of the column indices should be the " +
- s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " +
- s"numCols: $numCols")
+ // The second condition is for the case when the matrix is transposed
+ require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " +
+ "column indices should be the number of columns + 1. Currently, colPointers.length: " +
+ s"${colPtrs.length}, numCols: $numCols")
require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " +
s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}")
- override def toArray: Array[Double] = {
- val arr = new Array[Double](numRows * numCols)
- var j = 0
- while (j < numCols) {
- var i = colPtrs(j)
- val indEnd = colPtrs(j + 1)
- val offset = j * numRows
- while (i < indEnd) {
- val rowIndex = rowIndices(i)
- arr(offset + rowIndex) = values(i)
- i += 1
- }
- j += 1
- }
- arr
+ /**
+ * Column-major sparse matrix.
+ * The entry values are stored in Compressed Sparse Column (CSC) format.
+ * For example, the following matrix
+ * {{{
+ * 1.0 0.0 4.0
+ * 0.0 3.0 5.0
+ * 2.0 0.0 6.0
+ * }}}
+ * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`,
+ * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`.
+ *
+ * @param numRows number of rows
+ * @param numCols number of columns
+ * @param colPtrs the index corresponding to the start of a new column
+ * @param rowIndices the row index of the entry. They must be in strictly increasing
+ * order for each column
+ * @param values non-zero matrix entries in column major
+ */
+ def this(
+ numRows: Int,
+ numCols: Int,
+ colPtrs: Array[Int],
+ rowIndices: Array[Int],
+ values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
+
+ private[mllib] def toBreeze: BM[Double] = {
+ if (!isTransposed) {
+ new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
+ } else {
+ val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices)
+ breezeMatrix.t
+ }
}
- private[mllib] def toBreeze: BM[Double] =
- new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
-
private[mllib] def apply(i: Int, j: Int): Double = {
val ind = index(i, j)
if (ind < 0) 0.0 else values(ind)
}
private[mllib] def index(i: Int, j: Int): Int = {
- Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i)
+ if (!isTransposed) {
+ Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i)
+ } else {
+ Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j)
+ }
}
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
@@ -322,7 +415,7 @@ class SparseMatrix(
throw new NoSuchElementException("The given row and column indices correspond to a zero " +
"value. Only non-zero elements in Sparse Matrices can be updated.")
} else {
- values(index(i, j)) = v
+ values(ind) = v
}
}
@@ -341,7 +434,38 @@ class SparseMatrix(
this
}
- /** Generate a `DenseMatrix` from the given `SparseMatrix`. */
+ override def transpose: Matrix =
+ new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed)
+
+ private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = {
+ if (!isTransposed) {
+ var j = 0
+ while (j < numCols) {
+ var idx = colPtrs(j)
+ val idxEnd = colPtrs(j + 1)
+ while (idx < idxEnd) {
+ f(rowIndices(idx), j, values(idx))
+ idx += 1
+ }
+ j += 1
+ }
+ } else {
+ var i = 0
+ while (i < numRows) {
+ var idx = colPtrs(i)
+ val idxEnd = colPtrs(i + 1)
+ while (idx < idxEnd) {
+ val j = rowIndices(idx)
+ f(i, j, values(idx))
+ idx += 1
+ }
+ i += 1
+ }
+ }
+ }
+
+ /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed
+ * set to false. */
def toDense(): DenseMatrix = {
new DenseMatrix(numRows, numCols, toArray)
}
@@ -557,10 +681,9 @@ object Matrices {
private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = {
breeze match {
case dm: BDM[Double] =>
- require(dm.majorStride == dm.rows,
- "Do not support stride size different from the number of rows.")
- new DenseMatrix(dm.rows, dm.cols, dm.data)
+ new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose)
case sm: BSM[Double] =>
+ // There is no isTranspose flag for sparse matrices in Breeze
new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
case _ =>
throw new UnsupportedOperationException(
@@ -679,46 +802,28 @@ object Matrices {
new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray))
} else {
var startCol = 0
- val entries: Array[(Int, Int, Double)] = matrices.flatMap {
- case spMat: SparseMatrix =>
- var j = 0
- val colPtrs = spMat.colPtrs
- val rowIndices = spMat.rowIndices
- val values = spMat.values
- val data = new Array[(Int, Int, Double)](values.length)
- val nCols = spMat.numCols
- while (j < nCols) {
- var idx = colPtrs(j)
- while (idx < colPtrs(j + 1)) {
- val i = rowIndices(idx)
- val v = values(idx)
- data(idx) = (i, j + startCol, v)
- idx += 1
+ val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat =>
+ val nCols = mat.numCols
+ mat match {
+ case spMat: SparseMatrix =>
+ val data = new Array[(Int, Int, Double)](spMat.values.length)
+ var cnt = 0
+ spMat.foreachActive { (i, j, v) =>
+ data(cnt) = (i, j + startCol, v)
+ cnt += 1
}
- j += 1
- }
- startCol += nCols
- data
- case dnMat: DenseMatrix =>
- val data = new ArrayBuffer[(Int, Int, Double)]()
- var j = 0
- val nCols = dnMat.numCols
- val nRows = dnMat.numRows
- val values = dnMat.values
- while (j < nCols) {
- var i = 0
- val indStart = j * nRows
- while (i < nRows) {
- val v = values(indStart + i)
+ startCol += nCols
+ data
+ case dnMat: DenseMatrix =>
+ val data = new ArrayBuffer[(Int, Int, Double)]()
+ dnMat.foreachActive { (i, j, v) =>
if (v != 0.0) {
data.append((i, j + startCol, v))
}
- i += 1
}
- j += 1
- }
- startCol += nCols
- data
+ startCol += nCols
+ data
+ }
}
SparseMatrix.fromCOO(numRows, numCols, entries)
}
@@ -744,14 +849,12 @@ object Matrices {
require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " +
"don't match!")
mat match {
- case sparse: SparseMatrix =>
- hasSparse = true
- case dense: DenseMatrix =>
+ case sparse: SparseMatrix => hasSparse = true
+ case dense: DenseMatrix => // empty on purpose
case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " +
s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}")
}
numRows += mat.numRows
-
}
if (!hasSparse) {
val allValues = new Array[Double](numRows * numCols)
@@ -759,61 +862,37 @@ object Matrices {
matrices.foreach { mat =>
var j = 0
val nRows = mat.numRows
- val values = mat.toArray
- while (j < numCols) {
- var i = 0
+ mat.foreachActive { (i, j, v) =>
val indStart = j * numRows + startRow
- val subMatStart = j * nRows
- while (i < nRows) {
- allValues(indStart + i) = values(subMatStart + i)
- i += 1
- }
- j += 1
+ allValues(indStart + i) = v
}
startRow += nRows
}
new DenseMatrix(numRows, numCols, allValues)
} else {
var startRow = 0
- val entries: Array[(Int, Int, Double)] = matrices.flatMap {
- case spMat: SparseMatrix =>
- var j = 0
- val colPtrs = spMat.colPtrs
- val rowIndices = spMat.rowIndices
- val values = spMat.values
- val data = new Array[(Int, Int, Double)](values.length)
- while (j < numCols) {
- var idx = colPtrs(j)
- while (idx < colPtrs(j + 1)) {
- val i = rowIndices(idx)
- val v = values(idx)
- data(idx) = (i + startRow, j, v)
- idx += 1
+ val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat =>
+ val nRows = mat.numRows
+ mat match {
+ case spMat: SparseMatrix =>
+ val data = new Array[(Int, Int, Double)](spMat.values.length)
+ var cnt = 0
+ spMat.foreachActive { (i, j, v) =>
+ data(cnt) = (i + startRow, j, v)
+ cnt += 1
}
- j += 1
- }
- startRow += spMat.numRows
- data
- case dnMat: DenseMatrix =>
- val data = new ArrayBuffer[(Int, Int, Double)]()
- var j = 0
- val nCols = dnMat.numCols
- val nRows = dnMat.numRows
- val values = dnMat.values
- while (j < nCols) {
- var i = 0
- val indStart = j * nRows
- while (i < nRows) {
- val v = values(indStart + i)
+ startRow += nRows
+ data
+ case dnMat: DenseMatrix =>
+ val data = new ArrayBuffer[(Int, Int, Double)]()
+ dnMat.foreachActive { (i, j, v) =>
if (v != 0.0) {
data.append((i + startRow, j, v))
}
- i += 1
}
- j += 1
- }
- startRow += nRows
- data
+ startRow += nRows
+ data
+ }
}
SparseMatrix.fromCOO(numRows, numCols, entries)
}
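A short sketch of the new traversal and transpose behavior (foreachActive is private[spark], so the last two lines are illustrative rather than user-facing API):

    import org.apache.spark.mllib.linalg.DenseMatrix

    val m = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    val mt = m.transpose                        // 3 x 2 view over the same values array
    assert(mt.numRows == 3 && mt.numCols == 2)

    var sum = 0.0
    mt.foreachActive { (i, j, v) => sum += v }  // visits every entry of a dense matrix
    assert(sum == 21.0)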
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index adbd8266ed6fa..2834ea75ceb8f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -50,13 +50,35 @@ sealed trait Vector extends Serializable {
override def equals(other: Any): Boolean = {
other match {
- case v: Vector =>
- util.Arrays.equals(this.toArray, v.toArray)
+ case v2: Vector => {
+ if (this.size != v2.size) return false
+ (this, v2) match {
+ case (s1: SparseVector, s2: SparseVector) =>
+ Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
+ case (s1: SparseVector, d1: DenseVector) =>
+ Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values)
+ case (d1: DenseVector, s1: SparseVector) =>
+ Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
+ case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
+ }
+ }
case _ => false
}
}
- override def hashCode(): Int = util.Arrays.hashCode(this.toArray)
+ override def hashCode(): Int = {
+ var result: Int = size + 31
+ this.foreachActive { case (index, value) =>
+ // ignore explicit 0 for comparison between sparse and dense
+ if (value != 0) {
+ result = 31 * result + index
+ // refer to {@link java.util.Arrays.equals} for hash algorithm
+ val bits = java.lang.Double.doubleToLongBits(value)
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ }
+ }
+ return result
+ }
/**
* Converts the instance to a breeze vector.
@@ -311,7 +333,7 @@ object Vectors {
math.pow(sum, 1.0 / p)
}
}
-
+
/**
* Returns the squared distance between two Vectors.
* @param v1 first Vector.
@@ -319,8 +341,9 @@ object Vectors {
* @return squared distance between two Vectors.
*/
def sqdist(v1: Vector, v2: Vector): Double = {
+ require(v1.size == v2.size, "vector dimension mismatch")
var squaredDistance = 0.0
- (v1, v2) match {
+ (v1, v2) match {
case (v1: SparseVector, v2: SparseVector) =>
val v1Values = v1.values
val v1Indices = v1.indices
@@ -328,12 +351,12 @@ object Vectors {
val v2Indices = v2.indices
val nnzv1 = v1Indices.size
val nnzv2 = v2Indices.size
-
+
var kv1 = 0
var kv2 = 0
while (kv1 < nnzv1 || kv2 < nnzv2) {
var score = 0.0
-
+
if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
score = v1Values(kv1)
kv1 += 1
@@ -348,18 +371,23 @@ object Vectors {
squaredDistance += score * score
}
- case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
+ case (v1: SparseVector, v2: DenseVector) =>
squaredDistance = sqdist(v1, v2)
- case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
+ case (v1: DenseVector, v2: SparseVector) =>
squaredDistance = sqdist(v2, v1)
- // When a SparseVector is approximately dense, we treat it as a DenseVector
- case (v1, v2) =>
- squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) =>
- val score = elems._1 - elems._2
- distance + score * score
+ case (DenseVector(vv1), DenseVector(vv2)) =>
+ var kv = 0
+ val sz = vv1.size
+ while (kv < sz) {
+ val score = vv1(kv) - vv2(kv)
+ squaredDistance += score * score
+ kv += 1
}
+ case _ =>
+ throw new IllegalArgumentException("Unsupported vector types: " + v1.getClass +
+ " and " + v2.getClass)
}
squaredDistance
}
@@ -375,7 +403,7 @@ object Vectors {
val nnzv1 = indices.size
val nnzv2 = v2.size
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
-
+
while (kv2 < nnzv2) {
var score = 0.0
if (kv2 != iv1) {
@@ -392,6 +420,33 @@ object Vectors {
}
squaredDistance
}
+
+ /**
+ * Check equality between sparse/dense vectors
+ */
+ private[mllib] def equals(
+ v1Indices: IndexedSeq[Int],
+ v1Values: Array[Double],
+ v2Indices: IndexedSeq[Int],
+ v2Values: Array[Double]): Boolean = {
+ val v1Size = v1Values.size
+ val v2Size = v2Values.size
+ var k1 = 0
+ var k2 = 0
+ var allEqual = true
+ while (allEqual) {
+ while (k1 < v1Size && v1Values(k1) == 0) k1 += 1
+ while (k2 < v2Size && v2Values(k2) == 0) k2 += 1
+
+ if (k1 >= v1Size || k2 >= v2Size) {
+ return k1 >= v1Size && k2 >= v2Size // check end alignment
+ }
+ allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2)
+ k1 += 1
+ k2 += 1
+ }
+ allEqual
+ }
}
/**
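The `equals`/`hashCode` pair introduced above is designed so that sparse and dense encodings of the same vector compare equal and hash identically, with explicitly stored zeros ignored. A short sketch of the intended contract, mirroring the test added to VectorsSuite later in this patch:

```scala
import org.apache.spark.mllib.linalg.Vectors

val dense  = Vectors.dense(0.0, 0.9, 0.0, 0.8, 0.0)
val sparse = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))

// Same size and same non-zero entries, so equality and hashing agree even
// though one representation stores zeros explicitly and the other does not.
assert(dense == sparse)
assert(dense.## == sparse.##)
```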
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
index 06d8915f3bfa1..b60559c853a50 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala
@@ -69,6 +69,11 @@ class CoordinateMatrix(
nRows
}
+ /** Transposes this CoordinateMatrix. */
+ def transpose(): CoordinateMatrix = {
+ new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows())
+ }
+
/** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */
def toIndexedRowMatrix(): IndexedRowMatrix = {
val nl = numCols()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 181f507516485..c518271f04729 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -75,6 +75,23 @@ class IndexedRowMatrix(
new RowMatrix(rows.map(_.vector), 0L, nCols)
}
+ /**
+ * Converts this matrix to a
+ * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]].
+ */
+ def toCoordinateMatrix(): CoordinateMatrix = {
+ val entries = rows.flatMap { row =>
+ val rowIndex = row.index
+ row.vector match {
+ case SparseVector(size, indices, values) =>
+ Iterator.tabulate(indices.size)(i => MatrixEntry(rowIndex, indices(i), values(i)))
+ case DenseVector(values) =>
+ Iterator.tabulate(values.size)(i => MatrixEntry(rowIndex, i, values(i)))
+ }
+ }
+ new CoordinateMatrix(entries, numRows(), numCols())
+ }
+
/**
* Computes the singular value decomposition of this IndexedRowMatrix.
* Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'.
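Together, the two additions above give a simple path between the distributed formats: `toCoordinateMatrix` flattens every indexed row into `MatrixEntry`s, and `CoordinateMatrix.transpose` just swaps `i` and `j` on each entry. A usage sketch, assuming a `SparkContext` named `sc` is in scope:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

val rows = sc.parallelize(Seq(
  IndexedRow(0L, Vectors.dense(1.0, 0.0)),
  IndexedRow(2L, Vectors.sparse(2, Array(1), Array(3.0)))))
val mat = new IndexedRowMatrix(rows)   // 3 x 2 (row index 1 is implicitly empty)

val coord  = mat.toCoordinateMatrix()  // entry-level view of the same data
val coordT = coord.transpose()         // 2 x 3
assert(coordT.numRows() == coord.numCols() && coordT.numCols() == coord.numRows())
```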
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index d5abba6a4b645..02075edbabf85 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -131,8 +131,8 @@ class RowMatrix(
throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols")
}
if (cols > 10000) {
- val mem = cols * cols * 8
- logWarning(s"$cols columns will require at least $mem bytes of memory!")
+ val memMB = (cols.toLong * cols) / 125000
+ logWarning(s"$cols columns will require at least $memMB megabytes of memory!")
}
}
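The warning above now reports megabytes rather than raw bytes: `cols * cols` doubles at 8 bytes each is `cols * cols * 8 / 1e6` MB, folded into the single division by 125000, with `toLong` guarding against Int overflow for column counts near the 65535 limit. A quick check of the arithmetic (not part of the patch):

```scala
// A 20000-column Gramian holds 20000 * 20000 doubles at 8 bytes each,
// i.e. 3.2e9 bytes, which the formula reports as 3200 megabytes.
val cols = 20000
val memMB = (cols.toLong * cols) / 125000
assert(memMB == 3200L)
```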
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index bee951a2e5e26..5f84677be238d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double)
*
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
- * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
* indicated user
* preferences rather than explicit ratings given to items.
*/
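Put differently, implicit-feedback ALS fits binary preferences while the raw ratings only scale the confidence attached to each preference. A small sketch of that mapping, using the `1 + alpha * |r|` confidence weighting assumed by the ALS tests later in this patch:

```scala
// Preference: did the user interact at all?  Confidence: how much do we trust it?
def preference(r: Double): Double = if (r > 0) 1.0 else 0.0
def confidence(r: Double, alpha: Double): Double = 1.0 + alpha * math.abs(r)

// Ratings of 5 and 1 give the same preference (1.0), but the former is
// weighted more heavily during the factorization.
assert(preference(5.0) == preference(1.0))
assert(confidence(5.0, 0.5) > confidence(1.0, 0.5))
```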
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index e9304b5e5c650..482dd4b272d1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -140,6 +140,7 @@ private class RandomForest (
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+ logDebug("subsamplingRate = " + strategy.subsamplingRate)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -155,19 +156,12 @@ private class RandomForest (
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val (subsample, withReplacement) = {
- // TODO: Have a stricter check for RF in the strategy
- val isRandomForest = numTrees > 1
- if (isRandomForest) {
- (1.0, true)
- } else {
- (strategy.subsamplingRate, false)
- }
- }
+ val withReplacement = numTrees > 1
val baggedInput
- = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
- .persist(StorageLevel.MEMORY_AND_DISK)
+ = BaggedPoint.convertToBaggedRDD(treeInput,
+ strategy.subsamplingRate, numTrees,
+ withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
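The net effect of the change above: a single tree subsamples without replacement while a forest bootstraps with replacement, and both now honor `strategy.subsamplingRate` instead of forcing it to 1.0 for forests. A hypothetical helper that captures just that decision:

```scala
// Returns (fraction of the data sampled per tree, sample with replacement?).
def baggingParams(numTrees: Int, subsamplingRate: Double): (Double, Boolean) =
  (subsamplingRate, numTrees > 1)

val (fraction, withReplacement) = baggingParams(numTrees = 100, subsamplingRate = 0.7)
assert(fraction == 0.7 && withReplacement)   // forest: bootstrap samples
assert(!baggingParams(1, 0.7)._2)            // single tree: no replacement
```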
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index cf51d041c65a9..ed8e6a796f8c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -68,6 +68,15 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported: "Classification" or "Regression"
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ defaultParams(Algo.fromString(algo))
+ }
+
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
@@ -75,15 +84,15 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy(algo)
- treeStrategy.maxDepth = 3
+ def defaultParams(algo: Algo): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
- case "Classification" =>
- treeStrategy.numClasses = 2
- new BoostingStrategy(treeStrategy, LogLoss)
- case "Regression" =>
- new BoostingStrategy(treeStrategy, SquaredError)
+ case Algo.Classification =>
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
+ case Algo.Regression =>
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
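With the new overload, callers can pass either the string name or the `Algo` value directly; the string form simply delegates through `Algo.fromString`. A usage sketch (the `treeStrategy` field name on the case class is inferred from the constructor calls above):

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

// Both calls yield the same default configuration for binary classification.
val fromString = BoostingStrategy.defaultParams("Classification")
val fromAlgo   = BoostingStrategy.defaultParams(Algo.Classification)
assert(fromString.treeStrategy.numClasses == fromAlgo.treeStrategy.numClasses)
```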
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d5cd89ab94e81..3308adb6752ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -156,6 +156,9 @@ class Strategy (
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
+ require(subsamplingRate > 0 && subsamplingRate <= 1,
+ s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
+ s"$subsamplingRate")
}
/** Returns a shallow copy of this instance. */
@@ -173,11 +176,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
- def defaultStrategy(algo: String): Strategy = algo match {
- case "Classification" =>
+ def defaultStrategy(algo: String): Strategy = {
+ defaultStrategy(Algo.fromString(algo))
+ }
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo Algo.Classification or Algo.Regression
+ */
+ def defaultStrategy(algo: Algo): Strategy = algo match {
+ case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
- case "Regression" =>
+ case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 0e02345aa3774..b7950e00786ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
val lbl = label.toInt
require(lbl < stats.length,
s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "Entropy does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 7c83cd48e16a0..c946db9c0d1c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
val lbl = label.toInt
require(lbl < stats.length,
s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "GiniImpurity does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 42846677ed285..56a9dbdd58b64 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -26,10 +26,9 @@
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/**
* Test Pipeline construction and fitting in Java.
@@ -37,13 +36,13 @@
public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaPipelineSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
JavaRDD points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = jsql.applySchema(points, LabeledPoint.class);
@@ -66,7 +65,7 @@ public void pipeline() {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 76eb7f00329f2..f4ba23c44563e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,21 +26,20 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
}
@@ -56,8 +55,8 @@ public void logisticRegression() {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test
@@ -68,8 +67,8 @@ public void logisticRegressionWithSetters() {
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index a266ebd2071a1..074b58c07df7a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,21 +30,20 @@
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 4515084bc7ae9..2f175fb117941 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
class PipelineSuite extends FunSuite {
@@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite {
val estimator2 = mock[Estimator[MyModel]]
val model2 = mock[MyModel]
val transformer3 = mock[Transformer]
- val dataset0 = mock[SchemaRDD]
- val dataset1 = mock[SchemaRDD]
- val dataset2 = mock[SchemaRDD]
- val dataset3 = mock[SchemaRDD]
- val dataset4 = mock[SchemaRDD]
+ val dataset0 = mock[DataFrame]
+ val dataset1 = mock[DataFrame]
+ val dataset2 = mock[DataFrame]
+ val dataset3 = mock[DataFrame]
+ val dataset4 = mock[DataFrame]
when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
@@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite {
val estimator = mock[Estimator[MyModel]]
val pipeline = new Pipeline()
.setStages(Array(estimator, estimator))
- val dataset = mock[SchemaRDD]
+ val dataset = mock[DataFrame]
intercept[IllegalArgumentException] {
pipeline.fit(dataset)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e8030fef55b1d..1912afce93b18 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -21,12 +21,12 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
test("logistic regression") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
- .select('label, 'prediction)
+ .select("label", "prediction")
.collect()
}
test("logistic regression with setters") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select('label, 'score, 'prediction)
+ .select("label", "score", "prediction")
.collect()
}
test("logistic regression fit and transform with varargs") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select('label, 'probability, 'prediction)
+ .select("label", "probability", "prediction")
.collect()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
new file mode 100644
index 0000000000000..58289acdbc095
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.util.Random
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.scalatest.FunSuite
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+
+class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+
+ private var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("LocalIndexEncoder") {
+ val random = new Random
+ for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
+ val encoder = new LocalIndexEncoder(numBlocks)
+ val maxLocalIndex = Int.MaxValue / numBlocks
+ val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++
+ Seq((0, 0), (numBlocks - 1, maxLocalIndex))
+ tests.foreach { case (blockId, localIndex) =>
+ val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex."
+ val encoded = encoder.encode(blockId, localIndex)
+ assert(encoder.blockId(encoded) === blockId, err)
+ assert(encoder.localIndex(encoded) === localIndex, err)
+ }
+ }
+ }
+
+ test("normal equation construction with explict feedback") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 3.0f)
+ .add(Array(4.0f, 5.0f), 6.0f)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 2)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5")
+ // b = np.matrix("3; 6")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+
+ val ne1 = new NormalEquation(2)
+ .add(Array(7.0f, 8.0f), 9.0f)
+ ne0.merge(ne1)
+ assert(ne0.n === 3)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5; 7 8")
+ // b = np.matrix("3; 6; 9")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f), 2.0f)
+ }
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+ }
+ intercept[IllegalArgumentException] {
+ val ne2 = new NormalEquation(3)
+ ne0.merge(ne2)
+ }
+
+ ne0.reset()
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+ }
+
+ test("normal equation construction with implicit feedback") {
+ val k = 2
+ val alpha = 0.5
+ val ne0 = new NormalEquation(k)
+ .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
+ .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
+ .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 0) // addImplicit doesn't increase the count.
+ // NumPy code that computes the expected values:
+ // alpha = 0.5
+ // A = np.matrix("-5 -4; -2 -1; 1 2")
+ // b = np.matrix("-3; 0; 3")
+ // b1 = b > 0
+ // c = 1.0 + alpha * np.abs(b)
+ // C = np.diag(c.A1)
+ // I = np.eye(3)
+ // ata = A.transpose() * (C - I) * A
+ // atb = A.transpose() * C * b1
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
+ }
+
+ test("CholeskySolver") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 4.0f)
+ .add(Array(1.0f, 3.0f), 9.0f)
+ .add(Array(1.0f, 4.0f), 16.0f)
+ val ne1 = new NormalEquation(k)
+ .merge(ne0)
+
+ val chol = new CholeskySolver
+ val x0 = chol.solve(ne0, 0.0).map(_.toDouble)
+ // NumPy code that computes the expected solution:
+ // A = np.matrix("1 2; 1 3; 1 4")
+ // b = np.matrix("4; 9; 16")
+ // x0 = np.linalg.lstsq(A, b)[0]
+ assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
+
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+
+ val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+ // NumPy code that computes the expected solution, where lambda is scaled by n:
+ // x1 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+ assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
+ }
+
+ test("RatingBlockBuilder") {
+ val emptyBuilder = new RatingBlockBuilder()
+ assert(emptyBuilder.size === 0)
+ val emptyBlock = emptyBuilder.build()
+ assert(emptyBlock.srcIds.isEmpty)
+ assert(emptyBlock.dstIds.isEmpty)
+ assert(emptyBlock.ratings.isEmpty)
+
+ val builder0 = new RatingBlockBuilder()
+ .add(Rating(0, 1, 2.0f))
+ .add(Rating(3, 4, 5.0f))
+ assert(builder0.size === 2)
+ val builder1 = new RatingBlockBuilder()
+ .add(Rating(6, 7, 8.0f))
+ .merge(builder0.build())
+ assert(builder1.size === 3)
+ val block = builder1.build()
+ val ratings = Seq.tabulate(block.size) { i =>
+ (block.srcIds(i), block.dstIds(i), block.ratings(i))
+ }.toSet
+ assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f)))
+ }
+
+ test("UncompressedInBlock") {
+ val encoder = new LocalIndexEncoder(10)
+ val uncompressed = new UncompressedInBlockBuilder(encoder)
+ .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
+ .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
+ .build()
+ assert(uncompressed.size === 5)
+ val records = Seq.tabulate(uncompressed.size) { i =>
+ val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i))
+ }.toSet
+ val expected =
+ Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f))
+ assert(records === expected)
+
+ val compressed = uncompressed.compress()
+ assert(compressed.size === 5)
+ assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3))
+ assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
+ var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
+ var i = 0
+ while (i < compressed.srcIds.size) {
+ var j = compressed.dstPtrs(i)
+ while (j < compressed.dstPtrs(i + 1)) {
+ val dstEncodedIndex = compressed.dstEncodedIndices(j)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j)))
+ j += 1
+ }
+ i += 1
+ }
+ assert(decompressed.toSet === expected)
+ }
+
+ /**
+ * Generates an explicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genExplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates an implicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genImplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ // The assumption of the implicit feedback model is that unobserved ratings are more likely to
+ // be negatives.
+ val positiveFraction = 0.8
+ val negativeFraction = 1.0 - positiveFraction
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ val threshold = if (rating > 0) positiveFraction else negativeFraction
+ val observed = random.nextDouble() < threshold
+ if (observed) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ }
+ logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
+ * @param size number of users/items
+ * @param rank number of features
+ * @param random random number generator
+ * @param a min value of the support (default: -1)
+ * @param b max value of the support (default: 1)
+ * @return a sequence of (ID, factors) pairs
+ */
+ private def genFactors(
+ size: Int,
+ rank: Int,
+ random: Random,
+ a: Float = -1.0f,
+ b: Float = 1.0f): Seq[(Int, Array[Float])] = {
+ require(size > 0 && size < Int.MaxValue / 3)
+ require(b > a)
+ val ids = mutable.Set.empty[Int]
+ while (ids.size < size) {
+ ids += random.nextInt()
+ }
+ val width = b - a
+ ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width)))
+ }
+
+ /**
+ * Test ALS using the given training/test splits and parameters.
+ * @param training training dataset
+ * @param test test dataset
+ * @param rank rank of the matrix factorization
+ * @param maxIter max number of iterations
+ * @param regParam regularization constant
+ * @param implicitPrefs whether to use implicit preference
+ * @param numUserBlocks number of user blocks
+ * @param numItemBlocks number of item blocks
+ * @param targetRMSE target test RMSE
+ */
+ def testALS(
+ training: RDD[Rating],
+ test: RDD[Rating],
+ rank: Int,
+ maxIter: Int,
+ regParam: Double,
+ implicitPrefs: Boolean = false,
+ numUserBlocks: Int = 2,
+ numItemBlocks: Int = 3,
+ targetRMSE: Double = 0.05): Unit = {
+ val sqlContext = this.sqlContext
+ import sqlContext.createSchemaRDD
+ val als = new ALS()
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setImplicitPrefs(implicitPrefs)
+ .setNumUserBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ val alpha = als.getAlpha
+ val model = als.fit(training)
+ val predictions = model.transform(test)
+ .select("rating", "prediction")
+ .map { case Row(rating: Float, prediction: Float) =>
+ (rating.toDouble, prediction.toDouble)
+ }
+ val rmse =
+ if (implicitPrefs) {
+ // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
+ // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE
+ // with the confidence scores as weights.
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
+ val confidence = 1.0 + alpha * math.abs(rating)
+ val rating01 = math.max(math.min(rating, 1.0), 0.0)
+ val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
+ val err = prediction01 - rating01
+ (confidence, confidence * err * err)
+ }.reduce { case ((c0, e0), (c1, e1)) =>
+ (c0 + c1, e0 + e1)
+ }
+ math.sqrt(weightedSumSq / totalWeight)
+ } else {
+ val mse = predictions.map { case (rating, prediction) =>
+ val err = rating - prediction
+ err * err
+ }.mean()
+ math.sqrt(mse)
+ }
+ logInfo(s"Test RMSE is $rmse.")
+ assert(rmse < targetRMSE)
+ }
+
+ test("exact rank-1 matrix") {
+ val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1)
+ testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
+ testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
+ }
+
+ test("approximate rank-1 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
+ testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
+ }
+
+ test("approximate rank-2 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
+ testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
+ }
+
+ test("different block settings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
+ numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
+ }
+ }
+
+ test("more blocks than ratings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
+ numItemBlocks = 5, numUserBlocks = 5)
+ }
+
+ test("implicit feedback") {
+ val (training, test) =
+ genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
+ targetRMSE = 0.3)
+ }
+}
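The implicit-feedback branch of `testALS` above computes a confidence-weighted RMSE: each squared error between the clipped rating and clipped prediction is weighted by `1 + alpha * |rating|`. The same computation on plain collections, as a standalone sketch:

```scala
// Weighted RMSE: sqrt( sum(c_i * err_i^2) / sum(c_i) ) with c_i = 1 + alpha * |rating_i|,
// ratings and predictions both clipped to [0, 1] first.
def weightedRmse(pairs: Seq[(Double, Double)], alpha: Double): Double = {
  val (totalWeight, weightedSumSq) = pairs.map { case (rating, prediction) =>
    val confidence = 1.0 + alpha * math.abs(rating)
    val err = math.max(math.min(prediction, 1.0), 0.0) - math.max(math.min(rating, 1.0), 0.0)
    (confidence, confidence * err * err)
  }.reduce { case ((c0, e0), (c1, e1)) => (c0 + c1, e0 + e1) }
  math.sqrt(weightedSumSq / totalWeight)
}

// A perfect predictor scores 0 regardless of the confidence weights.
assert(weightedRmse(Seq((1.0, 1.0), (0.0, 0.0)), alpha = 1.0) == 0.0)
```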
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 41cc13da4d5b1..74104fa7a681a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 9da5495741a80..198997b5bb2b2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{Vectors, Matrices}
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -39,9 +40,9 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
val seeds = Array(314589, 29032897, 50181, 494821, 4660)
seeds.foreach { seed =>
val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
- assert(gmm.weight(0) ~== Ew absTol 1E-5)
- assert(gmm.mu(0) ~== Emu absTol 1E-5)
- assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
}
@@ -57,8 +58,10 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
- Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
- Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
)
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
@@ -70,11 +73,11 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
.setInitialModel(initialGmm)
.run(data)
- assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
- assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
- assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
- assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
- assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
- assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
+ assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(gmm.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 9ebef8466c831..caee5917000aa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -90,6 +90,27 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(model.clusterCenters.size === 3)
}
+ test("deterministic initialization") {
+ // Create a large-ish set of points for clustering
+ val points = List.tabulate(1000)(n => Vectors.dense(n, n))
+ val rdd = sc.parallelize(points, 3)
+
+ for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+ // Create three deterministic models and compare cluster means
+ val model1 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers1 = model1.clusterCenters
+
+ val model2 = KMeans.train(rdd, k = 10, maxIterations = 2, runs = 1,
+ initializationMode = initMode, seed = 42)
+ val centers2 = model2.clusterCenters
+
+ centers1.zip(centers2).foreach { case (c1, c2) =>
+ assert(c1 ~== c2 absTol 1E-14)
+ }
+ }
+ }
+
test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 771878e925ea7..b0b78acd6df16 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -169,16 +169,17 @@ class BLASSuite extends FunSuite {
}
test("gemm") {
-
val dA =
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0))
val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0))
+ val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0))
+ val BT = B.transpose
- assert(dA multiply B ~== expected absTol 1e-15)
- assert(sA multiply B ~== expected absTol 1e-15)
+ assert(dA.multiply(B) ~== expected absTol 1e-15)
+ assert(sA.multiply(B) ~== expected absTol 1e-15)
val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0))
val C2 = C1.copy
@@ -188,6 +189,10 @@ class BLASSuite extends FunSuite {
val C6 = C1.copy
val C7 = C1.copy
val C8 = C1.copy
+ val C9 = C1.copy
+ val C10 = C1.copy
+ val C11 = C1.copy
+ val C12 = C1.copy
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
@@ -202,26 +207,40 @@ class BLASSuite extends FunSuite {
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
- gemm(true, false, 1.0, dA, B, 2.0, C1)
+ gemm(1.0, dA.transpose, B, 2.0, C1)
}
}
- val dAT =
+ val dATman =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
- val sAT =
+ val sATman =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
- assert(dAT transposeMultiply B ~== expected absTol 1e-15)
- assert(sAT transposeMultiply B ~== expected absTol 1e-15)
-
- gemm(true, false, 1.0, dAT, B, 2.0, C5)
- gemm(true, false, 1.0, sAT, B, 2.0, C6)
- gemm(true, false, 2.0, dAT, B, 2.0, C7)
- gemm(true, false, 2.0, sAT, B, 2.0, C8)
+ val dATT = dATman.transpose
+ val sATT = sATman.transpose
+ val BTT = BTman.transpose.asInstanceOf[DenseMatrix]
+
+ assert(dATT.multiply(B) ~== expected absTol 1e-15)
+ assert(sATT.multiply(B) ~== expected absTol 1e-15)
+ assert(dATT.multiply(BTT) ~== expected absTol 1e-15)
+ assert(sATT.multiply(BTT) ~== expected absTol 1e-15)
+
+ gemm(1.0, dATT, BTT, 2.0, C5)
+ gemm(1.0, sATT, BTT, 2.0, C6)
+ gemm(2.0, dATT, BTT, 2.0, C7)
+ gemm(2.0, sATT, BTT, 2.0, C8)
+ gemm(1.0, dA, BTT, 2.0, C9)
+ gemm(1.0, sA, BTT, 2.0, C10)
+ gemm(2.0, dA, BTT, 2.0, C11)
+ gemm(2.0, sA, BTT, 2.0, C12)
assert(C5 ~== expected2 absTol 1e-15)
assert(C6 ~== expected2 absTol 1e-15)
assert(C7 ~== expected3 absTol 1e-15)
assert(C8 ~== expected3 absTol 1e-15)
+ assert(C9 ~== expected2 absTol 1e-15)
+ assert(C10 ~== expected2 absTol 1e-15)
+ assert(C11 ~== expected3 absTol 1e-15)
+ assert(C12 ~== expected3 absTol 1e-15)
}
test("gemv") {
@@ -233,17 +252,13 @@ class BLASSuite extends FunSuite {
val x = new DenseVector(Array(1.0, 2.0, 3.0))
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
- assert(dA multiply x ~== expected absTol 1e-15)
- assert(sA multiply x ~== expected absTol 1e-15)
+ assert(dA.multiply(x) ~== expected absTol 1e-15)
+ assert(sA.multiply(x) ~== expected absTol 1e-15)
val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
val y3 = y1.copy
val y4 = y1.copy
- val y5 = y1.copy
- val y6 = y1.copy
- val y7 = y1.copy
- val y8 = y1.copy
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
@@ -257,25 +272,18 @@ class BLASSuite extends FunSuite {
assert(y4 ~== expected3 absTol 1e-15)
withClue("columns of A don't match the rows of B") {
intercept[Exception] {
- gemv(true, 1.0, dA, x, 2.0, y1)
+ gemv(1.0, dA.transpose, x, 2.0, y1)
}
}
-
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
- assert(dAT transposeMultiply x ~== expected absTol 1e-15)
- assert(sAT transposeMultiply x ~== expected absTol 1e-15)
-
- gemv(true, 1.0, dAT, x, 2.0, y5)
- gemv(true, 1.0, sAT, x, 2.0, y6)
- gemv(true, 2.0, dAT, x, 2.0, y7)
- gemv(true, 2.0, sAT, x, 2.0, y8)
- assert(y5 ~== expected2 absTol 1e-15)
- assert(y6 ~== expected2 absTol 1e-15)
- assert(y7 ~== expected3 absTol 1e-15)
- assert(y8 ~== expected3 absTol 1e-15)
+ val dATT = dAT.transpose
+ val sATT = sAT.transpose
+
+ assert(dATT.multiply(x) ~== expected absTol 1e-15)
+ assert(sATT.multiply(x) ~== expected absTol 1e-15)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
index 73a6d3a27d868..2031032373971 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala
@@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite {
assert(mat.numRows === breeze.rows)
assert(mat.numCols === breeze.cols)
assert(mat.values.eq(breeze.data), "should not copy data")
+ // transposed matrix
+ val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix]
+ assert(matTransposed.numRows === breeze.cols)
+ assert(matTransposed.numCols === breeze.rows)
+ assert(matTransposed.values.eq(breeze.data), "should not copy data")
}
test("sparse matrix to breeze") {
@@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite {
assert(mat.numRows === breeze.rows)
assert(mat.numCols === breeze.cols)
assert(mat.values.eq(breeze.data), "should not copy data")
+ val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix]
+ assert(matTransposed.numRows === breeze.cols)
+ assert(matTransposed.numCols === breeze.rows)
+ assert(!matTransposed.values.eq(breeze.data), "has to copy data")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index a35d0fe389fdd..b1ebfde0e5e57 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -22,6 +22,9 @@ import java.util.Random
import org.mockito.Mockito.when
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar._
+import scala.collection.mutable.{Map => MutableMap}
+
+import org.apache.spark.mllib.util.TestingUtils._
class MatricesSuite extends FunSuite {
test("dense matrix construction") {
@@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite {
assert(mat.numRows === m)
assert(mat.numCols === n)
assert(mat.values.eq(values), "should not copy data")
- assert(mat.toArray.eq(values), "toArray should not copy data")
}
test("dense matrix construction with wrong dimension") {
@@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite {
assert(deMat1.toArray === deMat2.toArray)
}
+ test("transpose") {
+ val dA =
+ new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
+ val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
+
+ val dAT = dA.transpose.asInstanceOf[DenseMatrix]
+ val sAT = sA.transpose.asInstanceOf[SparseMatrix]
+ val dATexpected =
+ new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
+ val sATexpected =
+ new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
+
+ assert(dAT.toBreeze === dATexpected.toBreeze)
+ assert(sAT.toBreeze === sATexpected.toBreeze)
+ assert(dA(1, 0) === dAT(0, 1))
+ assert(dA(2, 1) === dAT(1, 2))
+ assert(sA(1, 0) === sAT(0, 1))
+ assert(sA(2, 1) === sAT(1, 2))
+
+ assert(!dA.toArray.eq(dAT.toArray), "has to have a new array")
+ assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array")
+
+ assert(dAT.toSparse().toBreeze === sATexpected.toBreeze)
+ assert(sAT.toDense().toBreeze === dATexpected.toBreeze)
+ }
+
+ test("foreachActive") {
+ val m = 3
+ val n = 2
+ val values = Array(1.0, 2.0, 4.0, 5.0)
+ val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0)
+ val colPtrs = Array(0, 2, 4)
+ val rowIndices = Array(0, 1, 1, 2)
+
+ val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values)
+ val dn = new DenseMatrix(m, n, allValues)
+
+ val dnMap = MutableMap[(Int, Int), Double]()
+ dn.foreachActive { (i, j, value) =>
+ dnMap.put((i, j), value)
+ }
+ assert(dnMap.size === 6)
+ assert(dnMap(0, 0) === 1.0)
+ assert(dnMap(1, 0) === 2.0)
+ assert(dnMap(2, 0) === 0.0)
+ assert(dnMap(0, 1) === 0.0)
+ assert(dnMap(1, 1) === 4.0)
+ assert(dnMap(2, 1) === 5.0)
+
+ val spMap = MutableMap[(Int, Int), Double]()
+ sp.foreachActive { (i, j, value) =>
+ spMap.put((i, j), value)
+ }
+ assert(spMap.size === 4)
+ assert(spMap(0, 0) === 1.0)
+ assert(spMap(1, 0) === 2.0)
+ assert(spMap(1, 1) === 4.0)
+ assert(spMap(2, 1) === 5.0)
+ }
+
test("horzcat, vertcat, eye, speye") {
val m = 3
val n = 2
@@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite {
val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0)
val colPtrs = Array(0, 2, 4)
val rowIndices = Array(0, 1, 1, 2)
+ // transposed versions
+ val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0)
+ val colPtrsT = Array(0, 1, 3, 4)
+ val rowIndicesT = Array(0, 0, 1, 1)
val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values)
val deMat1 = new DenseMatrix(m, n, allValues)
+ val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values)
+ val deMat1T = new DenseMatrix(n, m, allValuesT)
+
+ // should equal spMat1 & deMat1 respectively
+ val spMat1TT = spMat1T.transpose
+ val deMat1TT = deMat1T.transpose
+
val deMat2 = Matrices.eye(3)
val spMat2 = Matrices.speye(3)
val deMat3 = Matrices.eye(2)
@@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite {
val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
-
val deHorz2 = Matrices.horzcat(Array[Matrix]())
assert(deHorz1.numRows === 3)
@@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite {
assert(deHorz2.numCols === 0)
assert(deHorz2.toArray.length === 0)
- assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix)
- assert(spHorz2.toBreeze === spHorz3.toBreeze)
+ assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15)
+ assert(spHorz2 ~== spHorz3 absTol 1e-15)
assert(spHorz(0, 0) === 1.0)
assert(spHorz(2, 1) === 5.0)
assert(spHorz(0, 2) === 1.0)
@@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite {
assert(deHorz1(2, 4) === 1.0)
assert(deHorz1(1, 4) === 0.0)
+ // containing transposed matrices
+ val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2))
+ val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2))
+ val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2))
+ val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2))
+
+ assert(deHorz1T ~== deHorz1 absTol 1e-15)
+ assert(spHorzT ~== spHorz absTol 1e-15)
+ assert(spHorz2T ~== spHorz2 absTol 1e-15)
+ assert(spHorz3T ~== spHorz3 absTol 1e-15)
+
intercept[IllegalArgumentException] {
Matrices.horzcat(Array(spMat1, spMat3))
}
@@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite {
assert(deVert2.numCols === 0)
assert(deVert2.toArray.length === 0)
- assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix)
- assert(spVert2.toBreeze === spVert3.toBreeze)
+ assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15)
+ assert(spVert2 ~== spVert3 absTol 1e-15)
assert(spVert(0, 0) === 1.0)
assert(spVert(2, 1) === 5.0)
assert(spVert(3, 0) === 1.0)
@@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite {
assert(deVert1(3, 1) === 0.0)
assert(deVert1(4, 1) === 1.0)
+ // containing transposed matrices
+ val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3))
+ val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3))
+ val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3))
+ val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3))
+
+ assert(deVert1T ~== deVert1 absTol 1e-15)
+ assert(spVertT ~== spVert absTol 1e-15)
+ assert(spVert2T ~== spVert2 absTol 1e-15)
+ assert(spVert3T ~== spVert3 absTol 1e-15)
+
intercept[IllegalArgumentException] {
Matrices.vertcat(Array(spMat1, spMat2))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 85ac8ccebfc59..5def899cea117 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -89,6 +89,24 @@ class VectorsSuite extends FunSuite {
}
}
+ test("vectors equals with explicit 0") {
+ val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0))
+ val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))
+ val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0))
+
+ val vectors = Seq(dv1, sv1, sv2)
+ for (v <- vectors; u <- vectors) {
+ assert(v === u)
+ assert(v.## === u.##)
+ }
+
+ val another = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2))
+ for (v <- vectors) {
+ assert(v != another)
+ assert(v.## != another.##)
+ }
+ }
+
test("indexing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
assert(vec(0) === 1.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index f8709751efce6..80bef814ce50d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -73,6 +73,11 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(mat.toBreeze() === expected)
}
+ test("transpose") {
+ val transposed = mat.transpose()
+ assert(mat.toBreeze().t === transposed.toBreeze())
+ }
+
test("toIndexedRowMatrix") {
val indexedRowMatrix = mat.toIndexedRowMatrix()
val expected = BDM(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 741cd4997b853..b86c2ca5ff136 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -80,6 +80,14 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq)
}
+ test("toCoordinateMatrix") {
+ val idxRowMat = new IndexedRowMatrix(indexedRows)
+ val coordMat = idxRowMat.toCoordinateMatrix()
+ assert(coordMat.numRows() === m)
+ assert(coordMat.numCols() === n)
+ assert(coordMat.toBreeze() === idxRowMat.toBreeze())
+ }
+
test("multiply a local matrix") {
val A = new IndexedRowMatrix(indexedRows)
val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index f3b7bfda788fa..e9fc37e000526 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
* @param implicitPrefs flag to test implicit feedback
- * @param bulkPredict flag to test bulk prediciton
+ * @param bulkPredict flag to test bulk prediction
* @param negativeWeights whether the generated data can contain negative values
* @param numUserBlocks number of user blocks to partition users into
* @param numProductBlocks number of product blocks to partition products into
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
new file mode 100644
index 0000000000000..92b498580af03
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
+ */
+class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+ test("Gini impurity does not support negative labels") {
+ val gini = new GiniAggregator(2)
+ intercept[IllegalArgumentException] {
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+
+ test("Entropy does not support negative labels") {
+ val entropy = new EntropyAggregator(2)
+ intercept[IllegalArgumentException] {
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index f7f0f20c6c125..55e963977b54f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
+
+ test("subsampling rate in RandomForest"){
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
+ useNodeIdCache = true)
+
+ val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ strategy.subsamplingRate = 0.5
+ val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf1.toDebugString != rf2.toDebugString)
+ }
+
}
diff --git a/pom.xml b/pom.xml
index f4466e56c2a53..05cb3797fc55b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -117,7 +117,7 @@
2.0.10.21.0shaded-protobuf
- 1.7.5
+ 1.7.101.2.171.0.42.4.1
@@ -1507,7 +1507,7 @@
mapr31.0.3-mapr-3.0.3
- 2.3.0-mapr-4.0.0-FCS
+ 2.4.1-mapr-14080.94.17-mapr-14053.4.5-mapr-1406
@@ -1516,8 +1516,8 @@
mapr4
- 2.3.0-mapr-4.0.0-FCS
- 2.3.0-mapr-4.0.0-FCS
+ 2.4.1-mapr-1408
+ 2.4.1-mapr-14080.94.17-mapr-1405-4.0.0-FCS3.4.5-mapr-1406
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d3ea594245722..e750fed7448cd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,6 +52,20 @@ object MimaExcludes {
"org.apache.spark.mllib.linalg.Matrices.randn"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrices.rand")
+ ) ++ Seq(
+ // SPARK-5321
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.transpose"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." +
+ "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.foreachActive")
) ++ Seq(
// SPARK-3325
ProblemFilters.exclude[MissingMethodProblem](
@@ -78,6 +92,35 @@ object MimaExcludes {
"org.apache.spark.TaskContext.taskAttemptId"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.TaskContext.attemptNumber")
+ ) ++ Seq(
+ // SPARK-5166 Spark SQL API stabilization
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate")
+ ) ++ Seq(
+ // SPARK-5270
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.isEmpty")
+ ) ++ Seq(
+ // SPARK-5297 Java FileStream do not work with custom key/values
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
+ ) ++ Seq(
+ // SPARK-5315 Spark Streaming Java API returns Scala DStream
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
)
case v if v.startsWith("1.2") =>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b2c546da21c70..ded4b5443a904 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -114,17 +114,6 @@ object SparkBuild extends PomBuild {
override val userPropertiesMap = System.getProperties.toMap
- // Handle case where hadoop.version is set via profile.
- // Needed only because we read back this property in sbt
- // when we create the assembly jar.
- val pom = loadEffectivePom(new File("pom.xml"),
- profiles = profiles,
- userProps = userPropertiesMap)
- if (System.getProperty("hadoop.version") == null) {
- System.setProperty("hadoop.version",
- pom.getProperties.get("hadoop.version").asInstanceOf[String])
- }
-
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -303,16 +292,15 @@ object Assembly {
import sbtassembly.Plugin._
import AssemblyKeys._
+ val hadoopVersion = taskKey[String]("The version of Hadoop that Spark is compiled against.")
+
lazy val settings = assemblySettings ++ Seq(
test in assembly := {},
- jarName in assembly <<= (version, moduleName) map { (v, mName) =>
- if (mName.contains("network-yarn")) {
- // This must match the same name used in maven (see network/yarn/pom.xml)
- "spark-" + v + "-yarn-shuffle.jar"
- } else {
- mName + "-" + v + "-hadoop" + System.getProperty("hadoop.version") + ".jar"
- }
+ hadoopVersion := {
+ sys.props.get("hadoop.version")
+ .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
},
+ jarName in assembly := s"${moduleName.value}-${version.value}-hadoop${hadoopVersion.value}.jar",
mergeStrategy in assembly := {
case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
@@ -323,7 +311,6 @@ object Assembly {
case _ => MergeStrategy.first
}
)
-
}
object Unidoc {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 593d74bca5fff..568e21f3803bf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+ "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+ "not in code that it run on workers. For more information, see SPARK-5063."
+ )
+
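The guard above turns a confusing pickling failure into an explicit error. A minimal sketch of the situation it catches, assuming a live SparkContext `sc` (e.g. from the pyspark shell); the lambda below is purely illustrative:

    rdd = sc.parallelize(range(10))
    try:
        # Referencing `sc` inside a transformation forces the closure serializer
        # to pickle the SparkContext, which now fails fast via __getnewargs__.
        rdd.map(lambda x: sc.parallelize([x]).count()).collect()
    except Exception as e:
        print(e)  # the message points at SPARK-5063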
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
@@ -319,7 +327,7 @@ def f(split, iterator):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = max(1, min(len(c) // numSlices, self._batchSize))
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
tempFile.close()
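A self-contained sketch of the arithmetic in the changed line, assuming (as the change suggests) that `self._batchSize` is 0 when automatic batching is in use; the `or 1024` fallback then supplies a sane upper bound instead of collapsing the `min` to 0:

    def pick_batch_size(num_items, num_slices, configured):
        # Mirrors: max(1, min(len(c) // numSlices, self._batchSize or 1024))
        return max(1, min(num_items // num_slices, configured or 1024))

    print(pick_batch_size(10, 2, 0))      # 5    (fallback cap of 1024, then min)
    print(pick_batch_size(10000, 2, 0))   # 1024 (capped by the fallback)
    print(pick_batch_size(10, 2, 100))    # 5    (per-slice count is below the configured 100)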
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index a975dc19cb78e..a0a028446d5fd 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -111,10 +111,9 @@ def run(self):
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+ # TODO(davies): move into sql
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492eef5bd6a..6b713aa39374e 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ def predict(self, x):
class KMeans(object):
@classmethod
- def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
"""Train a k-means clustering model."""
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
- runs, initializationMode)
+ runs, initializationMode, seed)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
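A short usage sketch of the new `seed` parameter, assuming an existing SparkContext `sc`; with a fixed seed, repeated runs should produce the same centers, which is what the new Python test further below verifies:

    from pyspark.mllib.clustering import KMeans

    data = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
    model_a = KMeans.train(data, 2, maxIterations=10, seed=42)
    model_b = KMeans.train(data, 2, maxIterations=10, seed=42)
    # Same seed, same initialization, same centers.
    assert [list(c) for c in model_a.centers] == [list(c) for c in model_b.centers]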
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 140c22b5fd4e8..f48e3d6dacb4b 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -140,7 +140,7 @@ class ListTests(PySparkTestCase):
as NumPy arrays.
"""
- def test_clustering(self):
+ def test_kmeans(self):
from pyspark.mllib.clustering import KMeans
data = [
[0, 1.1],
@@ -152,6 +152,21 @@ def test_clustering(self):
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
+ def test_kmeans_deterministic(self):
+ from pyspark.mllib.clustering import KMeans
+ X = range(0, 100, 10)
+ Y = range(0, 100, 10)
+ data = [[x, y] for x, y in zip(X, Y)]
+ clusters1 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ clusters2 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ centers1 = clusters1.centers
+ centers2 = clusters2.centers
+ for c1, c2 in zip(centers1, centers2):
+ # TODO: Allow small numeric difference.
+ self.assertTrue(array_equal(c1, c2))
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c1120cf781e5e..f4cfe4845dc20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -141,6 +141,17 @@ def id(self):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""
@@ -1130,6 +1141,18 @@ def first(self):
return rs[0]
raise ValueError("RDD is empty")
+ def isEmpty(self):
+ """
+ Returns true if and only if the RDD contains no elements at all. Note that an RDD
+ may be empty even when it has at least 1 partition.
+
+ >>> sc.parallelize([]).isEmpty()
+ True
+ >>> sc.parallelize([1]).isEmpty()
+ False
+ """
+ return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0
+
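As the docstring notes, emptiness is about elements rather than partitions; a quick sketch, again assuming a live `sc`:

    empty = sc.parallelize([], 4)             # 4 partitions, no elements
    print(empty.getNumPartitions())           # 4
    print(empty.isEmpty())                    # True
    print(sc.parallelize([1]).isEmpty())      # False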
def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
"""
Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index bd08c9a6d20d6..b8bda835174b2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -181,6 +181,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
def _batched(self, iterator):
if self.batchSize == self.UNLIMITED_BATCH_SIZE:
yield list(iterator)
+ elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
+ n = len(iterator)
+ for i in xrange(0, n, self.batchSize):
+ yield iterator[i: i + self.batchSize]
else:
items = []
count = 0
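A standalone sketch of the fast path added above: when the incoming "iterator" is actually a list (it has __len__ and __getslice__), batches are produced by slicing instead of appending elements one by one:

    def batched(seq, batch_size):
        # Slice-based batching, mirroring the new branch in _batched().
        for i in range(0, len(seq), batch_size):
            yield seq[i:i + batch_size]

    print(list(batched([0, 1, 2, 3, 4, 5, 6], 3)))  # [[0, 1, 2], [3, 4, 5], [6]]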
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 014ac1791c849..7d7550c854b2f 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,15 +20,19 @@
- L{SQLContext}
Main entry point for SQL functionality.
- - L{SchemaRDD}
+ - L{DataFrame}
A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
+ addition to normal RDD operations, DataFrames also support SQL.
+ - L{GroupedDataFrame}
+ Aggregation methods, returned by L{DataFrame.groupBy}.
+ - L{Column}
+ Column is a DataFrame with a single column.
- L{Row}
A Row of data returned by a Spark SQL query.
- L{HiveContext}
Main entry point for accessing data stored in Apache Hive..
"""
+import sys
import itertools
import decimal
import datetime
@@ -36,6 +40,9 @@
import warnings
import json
import re
+import random
+import os
+from tempfile import NamedTemporaryFile
from array import array
from operator import itemgetter
from itertools import imap
@@ -43,6 +50,7 @@
from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
+from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
@@ -54,7 +62,8 @@
"StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "SchemaRDD", "Row"]
+ "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+ "SchemaRDD"]
class DataType(object):
@@ -1171,7 +1180,7 @@ def Dict(d):
class Row(tuple):
- """ Row in SchemaRDD """
+ """ Row in DataFrame """
__DATATYPE__ = dataType
__FIELDS__ = tuple(f.name for f in dataType.fields)
__slots__ = ()
@@ -1198,7 +1207,7 @@ class SQLContext(object):
"""Main entry point for Spark SQL functionality.
- A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as
+ A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
tables, execute SQL over tables, cache tables, and read parquet files.
"""
@@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
@@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None):
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> srdd = sqlCtx.inferSchema(allTypes)
- >>> srdd.registerTempTable("allTypes")
+ >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -1281,14 +1290,14 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._gateway._gateway_client)
includes = ListConverter().convert(self._sc._python_includes,
self._sc._gateway._gateway_client)
- self._ssql_ctx.registerPython(name,
- bytearray(pickled_command),
- env,
- includes,
- self._sc.pythonExec,
- broadcast_vars,
- self._sc._javaAccumulator,
- returnType.json())
+ self._ssql_ctx.udf().registerPython(name,
+ bytearray(pickled_command),
+ env,
+ includes,
+ self._sc.pythonExec,
+ broadcast_vars,
+ self._sc._javaAccumulator,
+ returnType.json())
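A hedged usage sketch of the registration path touched here; the `strlen` UDF and the sample rows are made up, and `sc` is assumed to be an existing SparkContext:

    from pyspark.sql import SQLContext, Row, IntegerType

    sqlCtx = SQLContext(sc)
    sqlCtx.registerFunction("strlen", lambda s: len(s), IntegerType())

    df = sqlCtx.inferSchema(sc.parallelize([Row(name="alice"), Row(name="bob")]))
    df.registerTempTable("people")
    print(sqlCtx.sql("SELECT name, strlen(name) FROM people").collect())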
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None):
... [Row(field1=1, field2="row1"),
... Row(field1=2, field2="row2"),
... Row(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
>>> NestedRow = Row("f1", "f2")
>>> nestedRdd1 = sc.parallelize([
... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> srdd = sqlCtx.inferSchema(nestedRdd1)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd1)
+ >>> df.collect()
[Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
>>> nestedRdd2 = sc.parallelize([
... NestedRow([[1, 2], [2, 3]], [1, 2]),
... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd2)
+ >>> df.collect()
[Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
>>> from collections import namedtuple
@@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None):
... [CustomRow(field1=1, field2="row1"),
... CustomRow(field1=2, field2="row2"),
... CustomRow(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
first = rdd.first()
if not first:
@@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema):
>>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
- >>> srdd = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT * from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.applySchema(rdd2, schema)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT * from table1")
+ >>> df2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
>>> from datetime import date, datetime
@@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema):
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
- >>> srdd = sqlCtx.applySchema(rdd, schema)
- >>> results = srdd.map(
+ >>> df = sqlCtx.applySchema(rdd, schema)
+ >>> results = df.map(
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
... x.time, x.map["a"], x.struct.b, x.list, x.null))
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
- >>> srdd.registerTempTable("table2")
+ >>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
@@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema):
>>> abstract = "byte short float time map{} struct(b) list[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
- >>> srdd.collect()
+ >>> df = sqlCtx.applySchema(rdd, typedSchema)
+ >>> df.collect()
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema):
rdd = rdd.map(converter)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ return DataFrame(df, self)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName):
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
"""
- if (rdd.__class__ is SchemaRDD):
- srdd = rdd._jschema_rdd.baseSchemaRDD()
- self._ssql_ctx.registerRDDAsTable(srdd, tableName)
+ if (rdd.__class__ is DataFrame):
+ df = rdd._jdf
+ self._ssql_ctx.registerRDDAsTable(df, tableName)
else:
- raise ValueError("Can only register SchemaRDD as table")
+ raise ValueError("Can only register DataFrame as table")
def parquetFile(self, path):
- """Loads a Parquet file, returning the result as a L{SchemaRDD}.
+ """Loads a Parquet file, returning the result as a L{DataFrame}.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
- return SchemaRDD(jschema_rdd, self)
+ jdf = self._ssql_ctx.parquetFile(path)
+ return DataFrame(jdf, self)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
- L{SchemaRDD}.
+ L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
@@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
+ df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ df = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return DataFrame(df, self)
def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
- """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
+ """Loads an RDD storing one JSON object per string as a L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
@@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
Otherwise, it samples the dataset with ratio `samplingRatio` to
determine the schema.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
@@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
@@ -1615,33 +1624,33 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return DataFrame(df, self)
def sql(self, sqlQuery):
- """Return a L{SchemaRDD} representing the result of the given query.
+ """Return a L{DataFrame} representing the result of the given query.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
+ return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
def table(self, tableName):
- """Returns the specified table as a L{SchemaRDD}.
+ """Returns the specified table as a L{DataFrame}.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.table("table1")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
+ return DataFrame(self._ssql_ctx.table(tableName), self)
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1686,24 +1695,6 @@ def _ssql_ctx(self):
def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())
- def hiveql(self, hqlQuery):
- """
- DEPRECATED: Use sql()
- """
- warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
- "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
- DeprecationWarning)
- return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
-
- def hql(self, hqlQuery):
- """
- DEPRECATED: Use sql()
- """
- warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" +
- "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
- DeprecationWarning)
- return self.hiveql(hqlQuery)
-
class LocalHiveContext(HiveContext):
@@ -1716,12 +1707,6 @@ def _get_hive_ctx(self):
return self._jvm.LocalHiveContext(self._jsc.sc())
-class TestHiveContext(HiveContext):
-
- def _get_hive_ctx(self):
- return self._jvm.TestHiveContext(self._jsc.sc())
-
-
def _create_row(fields, values):
row = Row(*values)
row.__FIELDS__ = fields
@@ -1731,7 +1716,7 @@ def _create_row(fields, values):
class Row(tuple):
"""
- A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
Row can be used to create a row object by using named arguments,
the fields will be sorted by names.
@@ -1823,111 +1808,119 @@ def inherit_doc(cls):
return cls
-@inherit_doc
-class SchemaRDD(RDD):
+class DataFrame(object):
+
+ """A collection of rows that have the same columns.
+
+ A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+ and can be created using various functions in :class:`SQLContext`::
+
+ people = sqlContext.parquetFile("...")
+
+ Once created, it can be manipulated using the various domain-specific-language
+ (DSL) functions defined in: [[DataFrame]], [[Column]].
+
+ To select a column from the data frame, use attribute access::
+
+ ageCol = people.age
+
+ Note that the :class:`Column` type can also be manipulated
+ through its various functions::
+
+ # The following creates a new column that increases everybody's age by 10.
+ people.age + 10
- """An RDD of L{Row} objects that has an associated schema.
- The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
- utilize the relational query api exposed by Spark SQL.
+ A more concrete example::
- For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
- L{SchemaRDD} is not operated on directly, as it's underlying
- implementation is an RDD composed of Java objects. Instead it is
- converted to a PythonRDD in the JVM, on which Python operations can
- be done.
+ # To create DataFrame using SQLContext
+ people = sqlContext.parquetFile("...")
+ department = sqlContext.parquetFile("...")
- This class receives raw tuples from Java but assigns a class to it in
- all its data-collection methods (mapPartitionsWithIndex, collect, take,
- etc) so that PySpark sees them as Row objects with named fields.
+ people.filter(people.age > 30).join(department, people.deptId == department.id) \
+ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
"""
- def __init__(self, jschema_rdd, sql_ctx):
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
self.sql_ctx = sql_ctx
- self._sc = sql_ctx._sc
- clsName = jschema_rdd.getClass().getName()
- assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
- self._jschema_rdd = jschema_rdd
- self._id = None
+ self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
- self.is_checkpointed = False
- self.ctx = self.sql_ctx._sc
- # the _jrdd is created by javaToPython(), serialized by pickle
- self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer())
@property
- def _jrdd(self):
- """Lazy evaluation of PythonRDD object.
+ def rdd(self):
+ """Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row`s. """
+ if not hasattr(self, '_lazy_rdd'):
+ jrdd = self._jdf.javaToPython()
+ rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ schema = self.schema()
- Only done when a user calls methods defined by the
- L{pyspark.rdd.RDD} super class (map, filter, etc.).
- """
- if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
- return self._lazy_jrdd
+ def applySchema(it):
+ cls = _create_cls(schema)
+ return itertools.imap(cls, it)
+
+ self._lazy_rdd = rdd.mapPartitions(applySchema)
- def id(self):
- if self._id is None:
- self._id = self._jrdd.id()
- return self._id
+ return self._lazy_rdd
def limit(self, num):
"""Limit the result count to the number specified.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.limit(2).collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.limit(2).collect()
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
- >>> srdd.limit(0).collect()
+ >>> df.limit(0).collect()
[]
"""
- rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
def toJSON(self, use_unicode=False):
- """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+ """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql( "SELECT * from table1")
- >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql( "SELECT * from table1")
+ >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
True
- >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1")
- >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+ >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+ >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
True
"""
- rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+ rdd = self._jdf.toJSON()
return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a SchemaRDD using the L{SQLContext.parquetFile} method.
+ a DataFrame using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd2.collect()) == sorted(srdd.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df2.collect()) == sorted(df.collect())
True
"""
- self._jschema_rdd.saveAsParquetFile(path)
+ self._jdf.saveAsParquetFile(path)
def registerTempTable(self, name):
"""Registers this RDD as a temporary table using the given name.
The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this SchemaRDD.
+ that was used to create this DataFrame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.registerTempTable("test")
- >>> srdd2 = sqlCtx.sql("select * from test")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.registerTempTable("test")
+ >>> df2 = sqlCtx.sql("select * from test")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- self._jschema_rdd.registerTempTable(name)
+ self._jdf.registerTempTable(name)
def registerAsTable(self, name):
"""DEPRECATED: use registerTempTable() instead"""
@@ -1935,62 +1928,61 @@ def registerAsTable(self, name):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this SchemaRDD into the specified table.
+ """Inserts the contents of this DataFrame into the specified table.
Optionally overwriting any existing data.
"""
- self._jschema_rdd.insertInto(tableName, overwrite)
+ self._jdf.insertInto(tableName, overwrite)
def saveAsTable(self, tableName):
- """Creates a new table with the contents of this SchemaRDD."""
- self._jschema_rdd.saveAsTable(tableName)
+ """Creates a new table with the contents of this DataFrame."""
+ self._jdf.saveAsTable(tableName)
def schema(self):
- """Returns the schema of this SchemaRDD (represented by
+ """Returns the schema of this DataFrame (represented by
a L{StructType})."""
- return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
-
- def schemaString(self):
- """Returns the output schema in the tree format."""
- return self._jschema_rdd.schemaString()
+ return _parse_datatype_json_string(self._jdf.schema().json())
def printSchema(self):
"""Prints out the schema in the tree format."""
- print self.schemaString()
+ print (self._jdf.schema().treeString())
def count(self):
"""Return the number of elements in this RDD.
Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the SchemaRDD,
+ leverages the query optimizer to compute the count on the DataFrame,
which supports features such as filter pushdown.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.count()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.count()
3L
- >>> srdd.count() == srdd.map(lambda x: x).count()
+ >>> df.count() == df.map(lambda x: x).count()
True
"""
- return self._jschema_rdd.count()
+ return self._jdf.count()
def collect(self):
- """Return a list that contains all of the rows in this RDD.
+ """Return a list that contains all of the rows.
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of collect, this implementation
- leverages the query optimizer to perform a collect on the SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+ with SCCallSiteSync(self._sc) as css:
+ bytesInJava = self._jdf.javaToPython().collect().iterator()
cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+ tempFile.close()
+ self._sc._writeToFile(bytesInJava, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+ os.unlink(tempFile.name)
+ return [cls(r) for r in rs]
def take(self, num):
"""Take the first num rows of the RDD.
@@ -1998,130 +1990,555 @@ def take(self, num):
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of take, this implementation
- leverages the query optimizer to perform a collect on a SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.take(2)
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.take(2)
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
"""
return self.limit(num).collect()
- # Convert each object in the RDD to a Row with the right class
- # for this SchemaRDD, so that fields can be accessed as attributes.
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ def map(self, f):
+ """ Return a new RDD by applying a function to each Row, it's a
+ shorthand for df.rdd.map()
"""
- Return a new RDD by applying a function to each partition of this RDD,
- while tracking the index of the original partition.
+ return self.rdd.map(f)
- >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
- >>> def f(splitIndex, iterator): yield splitIndex
- >>> rdd.mapPartitionsWithIndex(f).sum()
- 6
+ def mapPartitions(self, f, preservesPartitioning=False):
"""
- rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
-
- schema = self.schema()
+ Return a new RDD by applying a function to each partition.
- def applySchema(_, it):
- cls = _create_cls(schema)
- return itertools.imap(cls, it)
-
- objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
- return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(iterator): yield 1
+ >>> rdd.mapPartitions(f).sum()
+ 4
+ """
+ return self.rdd.mapPartitions(f, preservesPartitioning)
- # We override the default cache/persist/checkpoint behavior
- # as we want to cache the underlying SchemaRDD object in the JVM,
- # not the PythonRDD checkpointed by the super class
def cache(self):
+ """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """
self.is_cached = True
- self._jschema_rdd.cache()
+ self._jdf.cache()
return self
def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+ """ Set the storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified, it defaults to C{MEMORY_ONLY_SER}.
+ """
self.is_cached = True
- javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
- self._jschema_rdd.persist(javaStorageLevel)
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
return self
def unpersist(self, blocking=True):
+ """ Mark it as non-persistent, and remove all blocks for it from
+ memory and disk.
+ """
self.is_cached = False
- self._jschema_rdd.unpersist(blocking)
+ self._jdf.unpersist(blocking)
return self
- def checkpoint(self):
- self.is_checkpointed = True
- self._jschema_rdd.checkpoint()
+ # def coalesce(self, numPartitions, shuffle=False):
+ # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+ # return DataFrame(rdd, self.sql_ctx)
+
+ def repartition(self, numPartitions):
+ """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+ partitions.
+ """
+ rdd = self._jdf.repartition(numPartitions, None)
+ return DataFrame(rdd, self.sql_ctx)
+
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this DataFrame.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.sample(False, 0.5, 97).count()
+ 2L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+ return DataFrame(rdd, self.sql_ctx)
+
+ # def takeSample(self, withReplacement, num, seed=None):
+ # """Return a fixed-size sampled subset of this DataFrame.
+ #
+ # >>> df = sqlCtx.inferSchema(rdd)
+ # >>> df.takeSample(False, 2, 97)
+ # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ # """
+ # seed = seed if seed is not None else random.randint(0, sys.maxint)
+ # with SCCallSiteSync(self.context) as css:
+ # bytesInJava = self._jdf \
+ # .takeSampleToPython(withReplacement, num, long(seed)) \
+ # .iterator()
+ # cls = _create_cls(self.schema())
+ # return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+ @property
+ def dtypes(self):
+ """Return all column names and their data types as a list.
+ """
+ return [(f.name, str(f.dataType)) for f in self.schema().fields]
- def isCheckpointed(self):
- return self._jschema_rdd.isCheckpointed()
+ @property
+ def columns(self):
+ """ Return all column names as a list.
+ """
+ return [f.name for f in self.schema().fields]
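A quick sketch of the two introspection helpers above, assuming `sqlCtx` and `sc` as in the doctests; the exact type names depend on the inferred schema:

    from pyspark.sql import Row

    df = sqlCtx.inferSchema(sc.parallelize([Row(name="alice", age=1)]))
    print(df.columns)  # column names, e.g. ['age', 'name']
    print(df.dtypes)   # (name, type) pairs, e.g. [('age', 'IntegerType'), ('name', 'StringType')]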
- def getCheckpointFile(self):
- checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isPresent():
- return checkpointFile.get()
+ def show(self):
+ raise NotImplementedError
- def coalesce(self, numPartitions, shuffle=False):
- rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
- return SchemaRDD(rdd, self.sql_ctx)
+ def join(self, other, joinExprs=None, joinType=None):
+ """
+ Join with another DataFrame, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`::
+
+ df1.join(df2, df1.key == df2.key, "outer")
- def distinct(self, numPartitions=None):
- if numPartitions is None:
- rdd = self._jschema_rdd.distinct()
+ :param other: Right side of the join
+ :param joinExprs: Join expression
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
+ `semijoin`.
+ """
+ if joinType is None:
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
+ else:
+ jdf = self._jdf.join(other._jdf, joinExprs)
else:
- rdd = self._jschema_rdd.distinct(numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+ return DataFrame(jdf, self.sql_ctx)
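A sketch of the three call shapes `join` accepts; the `key` column and the tiny input DataFrames are illustrative, with `sqlCtx` and `sc` assumed as above:

    from pyspark.sql import Row

    df1 = sqlCtx.inferSchema(sc.parallelize([Row(key=1, v1="a"), Row(key=2, v1="b")]))
    df2 = sqlCtx.inferSchema(sc.parallelize([Row(key=2, v2="x"), Row(key=3, v2="y")]))

    df1.join(df2).count()                               # no expression: cartesian product (4 rows)
    df1.join(df2, df1.key == df2.key).collect()         # inner join on key (1 row)
    df1.join(df2, df1.key == df2.key, "outer").count()  # full outer join (3 rows)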
- def intersection(self, other):
- if (other.__class__ is SchemaRDD):
- rdd = self._jschema_rdd.intersection(other._jschema_rdd)
- return SchemaRDD(rdd, self.sql_ctx)
+ def sort(self, *cols):
+ """ Return a new [[DataFrame]] sorted by the specified column,
+ in ascending column.
+
+ :param cols: The columns or expressions used for sorting
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ cols = [_create_column_from_name(c) if isinstance(c, basestring) else c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.sort(self._jdf.toColumnArray(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ sortBy = sort
+
+ def head(self, n=None):
+ """ Return the first `n` rows or the first row if n is None. """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def tail(self):
+ raise NotImplementedError
+
+ def __getitem__(self, item):
+ if isinstance(item, basestring):
+ return Column(self._jdf.apply(item))
+
+ # TODO projection
+ raise IndexError
+
+ def __getattr__(self, name):
+ """ Return the column by given name """
+ if isinstance(name, basestring):
+ return Column(self._jdf.apply(name))
+ raise AttributeError
+
+ def As(self, name):
+ """ Alias the current DataFrame """
+ return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+
+ def select(self, *cols):
+ """ Selecting a set of expressions.::
+
+ df.select()
+ df.select('colA', 'colB')
+ df.select(df.colA, df.colB + 1)
+
+ """
+ if not cols:
+ cols = ["*"]
+ if isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only intersect with another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+ return DataFrame(jdf, self.sql_ctx)
- def repartition(self, numPartitions):
- rdd = self._jschema_rdd.repartition(numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ def filter(self, condition):
+ """ Filtering rows using the given condition::
- def subtract(self, other, numPartitions=None):
- if (other.__class__ is SchemaRDD):
- if numPartitions is None:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd)
- else:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd,
- numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ df.filter(df.age > 15)
+ df.where(df.age > 15)
+
+ """
+ return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+
+ where = filter
+
+ def groupBy(self, *cols):
+ """ Group the [[DataFrame]] using the specified columns,
+ so we can run aggregation on them. See :class:`GroupedDataFrame`
+ for all the available aggregate functions::
+
+ df.groupBy(df.department).avg()
+ df.groupBy("department", "gender").agg({
+ "salary": "avg",
+ "age": "max",
+ })
+ """
+ if cols and isinstance(cols[0], basestring):
+ cols = [_create_column_from_name(n) for n in cols]
else:
- raise ValueError("Can only subtract another SchemaRDD")
+ cols = [c._jc for c in cols]
+ jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+ jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+ return GroupedDataFrame(jdf, self.sql_ctx)
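A sketch of the grouped-aggregation flow introduced here; the `department`/`gender`/`salary`/`age` columns are illustrative, with `sqlCtx` and `sc` assumed as before:

    from pyspark.sql import Row

    people = sqlCtx.inferSchema(sc.parallelize([
        Row(department="eng", gender="f", salary=100, age=30),
        Row(department="eng", gender="m", salary=90, age=40),
        Row(department="hr", gender="f", salary=80, age=50)]))

    # Dict form: map column names to aggregate functions.
    people.groupBy("department").agg({"salary": "avg", "age": "max"}).collect()

    # Generated helpers on GroupedDataFrame (count, avg, max, min, sum, mean).
    people.groupBy(people.department).count().collect()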
- def sample(self, withReplacement, fraction, seed=None):
+ def agg(self, *exprs):
+ """ Aggregate on the entire [[DataFrame]] without groups
+ (shorthand for df.groupBy().agg())::
+
+ df.agg({"age": "max", "salary": "avg"})
"""
- Return a sampled subset of this SchemaRDD.
+ return self.groupBy().agg(*exprs)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.sample(False, 0.5, 97).count()
- 2L
+ def unionAll(self, other):
+ """ Return a new DataFrame containing union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL.
"""
- assert fraction >= 0.0, "Negative fraction value: %s" % fraction
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
- return SchemaRDD(rdd, self.sql_ctx)
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
- def takeSample(self, withReplacement, num, seed=None):
- """Return a fixed-size sampled subset of this SchemaRDD.
+ def intersect(self, other):
+ """ Return a new [[DataFrame]] containing rows only in
+ both this frame and another frame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.takeSample(False, 2, 97)
- [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ This is equivalent to `INTERSECT` in SQL.
"""
- seed = seed if seed is not None else random.randint(0, sys.maxint)
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD() \
- .takeSampleToPython(withReplacement, num, long(seed)) \
- .iterator()
- cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+ def Except(self, other):
+ """ Return a new [[DataFrame]] containing rows in this frame
+ but not in another frame.
+
+ This is equivalent to `EXCEPT` in SQL.
+ """
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+ def sample(self, withReplacement, fraction, seed=None):
+ """ Return a new DataFrame by sampling a fraction of rows. """
+ if seed is None:
+ jdf = self._jdf.sample(withReplacement, fraction)
+ else:
+ jdf = self._jdf.sample(withReplacement, fraction, seed)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def addColumn(self, colName, col):
+ """ Return a new [[DataFrame]] by adding a column. """
+ return self.select('*', col.As(colName))
+
+ def removeColumn(self, colName):
+ raise NotImplementedError
+
+
+# Having SchemaRDD for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+ """
+ SchemaRDD is deprecated, please use DataFrame
+ """
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedDataFrame(object):
+
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by DataFrame.groupBy().
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ def agg(self, *exprs):
+ """ Compute aggregates by specifying a map from column name
+ to aggregate methods.
+
+ The available aggregate methods are `avg`, `max`, `min`,
+ `sum`, `count`.
+
+ :param exprs: a list of aggregate columns or a map from column
+ name to aggregate methods.
+ """
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jmap = MapConverter().convert(exprs[0],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(jmap)
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
+ jdf = self._jdf.agg(*exprs)
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """ Count the number of rows for each group. """
+
+ @dfapi
+ def mean(self):
+ """Compute the average value for each numeric columns
+ for each group. This is an alias for `avg`."""
+
+ @dfapi
+ def avg(self):
+ """Compute the average value for each numeric columns
+ for each group."""
+
+ @dfapi
+ def max(self):
+ """Compute the max value for each numeric columns for
+ each group. """
+
+ @dfapi
+ def min(self):
+ """Compute the min value for each numeric column for
+ each group."""
+
+ @dfapi
+ def sum(self):
+ """Compute the sum for each numeric columns for each
+ group."""
+
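As a further illustrative sketch (not part of the patch), grouped aggregation with the dict form and the generated helpers, reusing the `df` from the earlier sketch:

    g = df.groupBy()                                   # no grouping columns: one global group
    g.agg({"key": "max", "value": "count"}).collect()  # dict form: column name -> aggregate method
    g.avg().collect()                                  # @dfapi helper, averages numeric columns
    df.agg({"key": "max"}).collect()                   # DataFrame.agg is shorthand for groupBy().agg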
+
+SCALA_METHOD_MAPPINGS = {
+ '=': '$eq',
+ '>': '$greater',
+ '<': '$less',
+ '+': '$plus',
+ '-': '$minus',
+ '*': '$times',
+ '/': '$div',
+ '!': '$bang',
+ '@': '$at',
+ '#': '$hash',
+ '%': '$percent',
+ '^': '$up',
+ '&': '$amp',
+ '~': '$tilde',
+ '?': '$qmark',
+ '|': '$bar',
+ '\\': '$bslash',
+ ':': '$colon',
+}
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Literal.apply(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Column(name)
+
+
+def _scalaMethod(name):
+ """ Translate operators into methodName in Scala
+
+ For example:
+ >>> _scalaMethod('+')
+ '$plus'
+ >>> _scalaMethod('>=')
+ '$greater$eq'
+ >>> _scalaMethod('cast')
+ 'cast'
+ """
+ return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
+
+
+def _unary_op(name):
+ """ Create a method for given unary operator """
+ def _(self):
+ return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+ return _
+
+
+def _bin_op(name):
+ """ Create a method for given binary operator """
+ def _(self, other):
+ if isinstance(other, Column):
+ jc = other._jc
+ else:
+ jc = _create_column_from_literal(other)
+ return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+ return _
+
+
+def _reverse_op(name):
+ """ Create a method for binary operator (this object is on right side)
+ """
+ def _(self, other):
+ return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
+ self._jdf, self.sql_ctx)
+ return _
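To make the operator plumbing concrete (illustrative, not part of the patch): a Python operator on a `Column` is forwarded to the corresponding method on the underlying JVM column, with `_scalaMethod` translating the symbolic name.

    c = df.key + 1    # __add__  -> _bin_op("+")      -> JVM method "$plus" on df.key._jc
    r = 1 - df.key    # __rsub__ -> _reverse_op("-"): the literal 1 becomes the left operand
    b = df.key >= 3   # __ge__   -> _bin_op(">=")     -> JVM method "$greater$eq"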
+
+
+class Column(DataFrame):
+
+ """
+ A column in a DataFrame.
+
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df["colName"] + 1
+ """
+
+ def __init__(self, jc, jdf=None, sql_ctx=None):
+ self._jc = jc
+ super(Column, self).__init__(jdf, sql_ctx)
+
+ # arithmetic operators
+ __neg__ = _unary_op("unary_-")
+ __add__ = _bin_op("+")
+ __sub__ = _bin_op("-")
+ __mul__ = _bin_op("*")
+ __div__ = _bin_op("/")
+ __mod__ = _bin_op("%")
+ __radd__ = _bin_op("+")
+ __rsub__ = _reverse_op("-")
+ __rmul__ = _bin_op("*")
+ __rdiv__ = _reverse_op("/")
+ __rmod__ = _reverse_op("%")
+ __abs__ = _unary_op("abs")
+ abs = _unary_op("abs")
+ sqrt = _unary_op("sqrt")
+
+ # comparison and logical operators
+ __eq__ = _bin_op("===")
+ __ne__ = _bin_op("!==")
+ __lt__ = _bin_op("<")
+ __le__ = _bin_op("<=")
+ __ge__ = _bin_op(">=")
+ __gt__ = _bin_op(">")
+ # `and`, `or`, `not` cannot be overloaded in Python
+ And = _bin_op('&&')
+ Or = _bin_op('||')
+ Not = _unary_op('unary_!')
+
+ # bitwise operators
+ __and__ = _bin_op("&")
+ __or__ = _bin_op("|")
+ __invert__ = _unary_op("unary_~")
+ __xor__ = _bin_op("^")
+ # __lshift__ = _bin_op("<<")
+ # __rshift__ = _bin_op(">>")
+ __rand__ = _bin_op("&")
+ __ror__ = _bin_op("|")
+ __rxor__ = _bin_op("^")
+ # __rlshift__ = _reverse_op("<<")
+ # __rrshift__ = _reverse_op(">>")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("getItem")
+ # __getattr__ = _bin_op("getField")
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+ upper = _unary_op("upper")
+ lower = _unary_op("lower")
+
+ def substr(self, startPos, pos):
+ if type(startPos) != type(pos):
+ raise TypeError("Cannot mix types of startPos and pos")
+ if isinstance(startPos, (int, long)):
+ jc = self._jc.substr(startPos, pos)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, pos._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc, self._jdf, self.sql_ctx)
+
+ __getslice__ = substr
+
+ # order
+ asc = _unary_op("asc")
+ desc = _unary_op("desc")
+
+ isNull = _unary_op("isNull")
+ isNotNull = _unary_op("isNotNull")
+
+ # `as` is a Python keyword, hence the capitalized method name
+ def As(self, alias):
+ return Column(getattr(self._jc, "as")(alias), self._jdf, self.sql_ctx)
+
+ def cast(self, dataType):
+ if self.sql_ctx is None:
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ else:
+ ssql_ctx = self.sql_ctx._ssql_ctx
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+
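A short illustrative sketch of the user-facing Column helpers defined above (not part of the patch; reuses the `df` from the earlier sketch):

    from pyspark.sql import LongType

    renamed = df.select(df.key.As("id"))             # `as` is a keyword, hence As()
    casted = df.select(df.key.cast(LongType()))      # cast through the JVM DataType parser
    ones = df.where(df.value.like("1%")).collect()   # SQL LIKE on a string column
    nulls = df.where(df.key.isNull()).collect()      # isNull/isNotNull expressions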
+
+def _aggregate_func(name):
+ """ Creat a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ # FIXME: can not access dsl.min/max ...
+ jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+ return Column(jc)
+ return staticmethod(_)
+
+
+class Aggregator(object):
+ """
+ A collection of builtin aggregators
+ """
+ max = _aggregate_func("max")
+ min = _aggregate_func("min")
+ avg = mean = _aggregate_func("mean")
+ sum = _aggregate_func("sum")
+ first = _aggregate_func("first")
+ last = _aggregate_func("last")
+ count = _aggregate_func("count")
def _test():
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b474fcf5bfb7e..e8e207af462de 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -806,6 +806,9 @@ def tearDownClass(cls):
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
+ self.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = self.sc.parallelize(self.testData)
+ self.df = self.sqlCtx.inferSchema(rdd)
def test_udf(self):
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
@@ -821,7 +824,7 @@ def test_udf2(self):
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.inferSchema(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -839,68 +842,51 @@ def test_broadcast_in_udf(self):
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- srdd = self.sqlCtx.jsonRDD(rdd)
- srdd.count()
- srdd.collect()
- srdd.schemaString()
- srdd.schema()
+ df = self.sqlCtx.jsonRDD(rdd)
+ df.count()
+ df.collect()
+ df.schema()
# cache and checkpoint
- self.assertFalse(srdd.is_cached)
- srdd.persist()
- srdd.unpersist()
- srdd.cache()
- self.assertTrue(srdd.is_cached)
- self.assertFalse(srdd.isCheckpointed())
- self.assertEqual(None, srdd.getCheckpointFile())
-
- srdd = srdd.coalesce(2, True)
- srdd = srdd.repartition(3)
- srdd = srdd.distinct()
- srdd.intersection(srdd)
- self.assertEqual(2, srdd.count())
-
- srdd.registerTempTable("temp")
- srdd = self.sqlCtx.sql("select foo from temp")
- srdd.count()
- srdd.collect()
-
- def test_distinct(self):
- rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
- srdd = self.sqlCtx.jsonRDD(rdd)
- self.assertEquals(srdd.getNumPartitions(), 10)
- self.assertEquals(srdd.distinct().count(), 3)
- result = srdd.distinct(5)
- self.assertEquals(result.getNumPartitions(), 5)
- self.assertEquals(result.count(), 3)
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist()
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ df.registerTempTable("temp")
+ df = self.sqlCtx.sql("select foo from temp")
+ df.count()
+ df.collect()
def test_apply_schema_to_row(self):
- srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
- self.assertEqual(srdd.collect(), srdd2.collect())
+ df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
- self.assertEqual(10, srdd3.count())
+ df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- row = srdd.first()
+ df = self.sqlCtx.inferSchema(rdd)
+ row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
self.assertEqual("2", row.d["key"].d)
- l = srdd.map(lambda x: x.l).first()
+ l = df.map(lambda x: x.l).first()
self.assertEqual(1, len(l))
self.assertEqual('s', l[0].b)
- d = srdd.map(lambda x: x.d).first()
+ d = df.map(lambda x: x.d).first()
self.assertEqual(1, len(d))
self.assertEqual(1.0, d["key"].c)
- row = srdd.map(lambda x: x.d["key"]).first()
+ row = df.map(lambda x: x.d["key"]).first()
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
@@ -908,26 +894,26 @@ def test_infer_schema(self):
d = [Row(l=[], d={}),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- self.assertEqual([], srdd.map(lambda r: r.l).first())
- self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
- srdd.registerTempTable("test")
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], df.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+ df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
- srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
- self.assertEqual(srdd.schema(), srdd2.schema())
- self.assertEqual({}, srdd2.map(lambda r: r.d).first())
- self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
- srdd2.registerTempTable("test2")
+ df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual({}, df2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+ df2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
- self.assertEqual(1, result.first()[0])
+ self.assertEqual(1, result.head()[0])
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- k, v = srdd.first().m.items()[0]
+ df = self.sqlCtx.inferSchema(rdd)
+ k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -935,9 +921,9 @@ def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- srdd.registerTempTable("test")
- row = self.sqlCtx.sql("select l, d from test").first()
+ df = self.sqlCtx.inferSchema(rdd)
+ df.registerTempTable("test")
+ row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
@@ -945,12 +931,12 @@ def test_infer_schema_with_udt(self):
from pyspark.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd = self.sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ df = self.sqlCtx.inferSchema(rdd)
+ schema = df.schema()
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
- srdd.registerTempTable("labeled_point")
- point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
def test_apply_schema_with_udt(self):
@@ -959,21 +945,52 @@ def test_apply_schema_with_udt(self):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- srdd = self.sqlCtx.applySchema(rdd, schema)
- point = srdd.first().point
+ df = self.sqlCtx.applySchema(rdd, schema)
+ point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
rdd = self.sc.parallelize([row])
- srdd0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sqlCtx.inferSchema(rdd)
output_dir = os.path.join(self.tempdir.name, "labeled_point")
- srdd0.saveAsParquetFile(output_dir)
- srdd1 = self.sqlCtx.parquetFile(output_dir)
- point = srdd1.first().point
+ df0.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ def test_column_operators(self):
+ from pyspark.sql import Column, LongType
+ ci = self.df.key
+ cs = self.df.value
+ c = ci == cs
+ self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ self.assertTrue(all(isinstance(c, Column) for c in rcc))
+ cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+ self.assertTrue(all(isinstance(c, Column) for c in cb))
+ cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
+ self.assertTrue(all(isinstance(c, Column) for c in cbit))
+ css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+ self.assertTrue(all(isinstance(c, Column) for c in css))
+ self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+ def test_column_select(self):
+ df = self.df
+ self.assertEqual(self.testData, df.select("*").collect())
+ self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+ self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+ # TODO(davies): fix aggregators
+ from pyspark.sql import Aggregator as Agg
+ # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+
class InputFormatTests(ReusedPySparkTestCase):
diff --git a/repl/pom.xml b/repl/pom.xml
index 0bc8bccf90a6d..ae7c31aef4f5f 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -92,13 +92,6 @@
target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
-
- org.apache.maven.plugins
- maven-deploy-plugin
-
- true
-
- org.codehaus.mojo
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
index 05816941b54b3..6480e2d24e044 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
@@ -19,14 +19,21 @@ package org.apache.spark.repl
import scala.tools.nsc.{Settings, CompilerCommand}
import scala.Predef._
+import org.apache.spark.annotation.DeveloperApi
/**
* Command class enabling Spark-specific command line options (provided by
* org.apache.spark.repl.SparkRunnerSettings).
+ *
+ * @example new SparkCommandLine(Nil).settings
+ *
+ * @param args The list of command line arguments
+ * @param settings The underlying settings to associate with this set of
+ * command-line options
*/
+@DeveloperApi
class SparkCommandLine(args: List[String], override val settings: Settings)
extends CompilerCommand(args, settings) {
-
def this(args: List[String], error: String => Unit) {
this(args, new SparkRunnerSettings(error))
}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
index f8432c8af6ed2..5fb378112ef92 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
@@ -15,7 +15,7 @@ import scala.tools.nsc.ast.parser.Tokens.EOF
import org.apache.spark.Logging
-trait SparkExprTyper extends Logging {
+private[repl] trait SparkExprTyper extends Logging {
val repl: SparkIMain
import repl._
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
index 5340951d91331..955be17a73b85 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
@@ -17,6 +17,23 @@
package scala.tools.nsc
+import org.apache.spark.annotation.DeveloperApi
+
+// NOTE: Forced to be public (and in the scala.tools.nsc package) to access the
+// "explicitParentLoader" method on Settings
+
+/**
+ * Provides exposure for the explicitParentLoader method on settings instances.
+ */
+@DeveloperApi
object SparkHelper {
+ /**
+ * Retrieves the explicit parent loader for the provided settings.
+ *
+ * @param settings The settings whose explicit parent loader to retrieve
+ *
+ * @return The Optional classloader representing the explicit parent loader
+ */
+ @DeveloperApi
def explicitParentLoader(settings: Settings) = settings.explicitParentLoader
}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index e56b74edba88c..72c1a989999b4 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -10,6 +10,8 @@ package org.apache.spark.repl
import java.net.URL
+import org.apache.spark.annotation.DeveloperApi
+
import scala.reflect.io.AbstractFile
import scala.tools.nsc._
import scala.tools.nsc.backend.JavaPlatform
@@ -57,20 +59,22 @@ import org.apache.spark.util.Utils
* @author Lex Spoon
* @version 1.2
*/
-class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
- val master: Option[String])
- extends AnyRef
- with LoopCommands
- with SparkILoopInit
- with Logging
-{
+@DeveloperApi
+class SparkILoop(
+ private val in0: Option[BufferedReader],
+ protected val out: JPrintWriter,
+ val master: Option[String]
+) extends AnyRef with LoopCommands with SparkILoopInit with Logging {
def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master))
def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None)
def this() = this(None, new JPrintWriter(Console.out, true), None)
- var in: InteractiveReader = _ // the input stream from which commands come
- var settings: Settings = _
- var intp: SparkIMain = _
+ private var in: InteractiveReader = _ // the input stream from which commands come
+
+ // NOTE: Exposed in package for testing
+ private[repl] var settings: Settings = _
+
+ private[repl] var intp: SparkIMain = _
@deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
@deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i
@@ -123,6 +127,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
}
+ // NOTE: Must be public for visibility
+ @DeveloperApi
var sparkContext: SparkContext = _
override def echoCommandMessage(msg: String) {
@@ -130,45 +136,45 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
// def isAsync = !settings.Yreplsync.value
- def isAsync = false
+ private[repl] def isAsync = false
// lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
- def history = in.history
+ private def history = in.history
/** The context class loader at the time this object was created */
protected val originalClassLoader = Utils.getContextOrSparkClassLoader
// classpath entries added via :cp
- var addedClasspath: String = ""
+ private var addedClasspath: String = ""
/** A reverse list of commands to replay if the user requests a :replay */
- var replayCommandStack: List[String] = Nil
+ private var replayCommandStack: List[String] = Nil
/** A list of commands to replay if the user requests a :replay */
- def replayCommands = replayCommandStack.reverse
+ private def replayCommands = replayCommandStack.reverse
/** Record a command for replay should the user request a :replay */
- def addReplay(cmd: String) = replayCommandStack ::= cmd
+ private def addReplay(cmd: String) = replayCommandStack ::= cmd
- def savingReplayStack[T](body: => T): T = {
+ private def savingReplayStack[T](body: => T): T = {
val saved = replayCommandStack
try body
finally replayCommandStack = saved
}
- def savingReader[T](body: => T): T = {
+ private def savingReader[T](body: => T): T = {
val saved = in
try body
finally in = saved
}
- def sparkCleanUp(){
+ private def sparkCleanUp(){
echo("Stopping spark context.")
intp.beQuietDuring {
command("sc.stop()")
}
}
/** Close the interpreter and set the var to null. */
- def closeInterpreter() {
+ private def closeInterpreter() {
if (intp ne null) {
sparkCleanUp()
intp.close()
@@ -179,14 +185,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
class SparkILoopInterpreter extends SparkIMain(settings, out) {
outer =>
- override lazy val formatting = new Formatting {
+ override private[repl] lazy val formatting = new Formatting {
def prompt = SparkILoop.this.prompt
}
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
}
- /** Create a new interpreter. */
- def createInterpreter() {
+ /**
+ * Constructs a new interpreter.
+ */
+ protected def createInterpreter() {
require(settings != null)
if (addedClasspath != "") settings.classpath.append(addedClasspath)
@@ -207,7 +215,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
/** print a friendly help message */
- def helpCommand(line: String): Result = {
+ private def helpCommand(line: String): Result = {
if (line == "") helpSummary()
else uniqueCommand(line) match {
case Some(lc) => echo("\n" + lc.longHelp)
@@ -258,7 +266,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
/** Show the history */
- lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
+ private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
override def usage = "[num]"
def defaultLines = 20
@@ -279,21 +287,21 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
// When you know you are most likely breaking into the middle
// of a line being typed. This softens the blow.
- protected def echoAndRefresh(msg: String) = {
+ private[repl] def echoAndRefresh(msg: String) = {
echo("\n" + msg)
in.redrawLine()
}
- protected def echo(msg: String) = {
+ private[repl] def echo(msg: String) = {
out println msg
out.flush()
}
- protected def echoNoNL(msg: String) = {
+ private def echoNoNL(msg: String) = {
out print msg
out.flush()
}
/** Search the history */
- def searchHistory(_cmdline: String) {
+ private def searchHistory(_cmdline: String) {
val cmdline = _cmdline.toLowerCase
val offset = history.index - history.size + 1
@@ -302,14 +310,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
private var currentPrompt = Properties.shellPromptString
+
+ /**
+ * Sets the prompt string used by the REPL.
+ *
+ * @param prompt The new prompt string
+ */
+ @DeveloperApi
def setPrompt(prompt: String) = currentPrompt = prompt
- /** Prompt to print when awaiting input */
+
+ /**
+ * Represents the current prompt string used by the REPL.
+ *
+ * @return The current prompt string
+ */
+ @DeveloperApi
def prompt = currentPrompt
import LoopCommand.{ cmd, nullary }
/** Standard commands */
- lazy val standardCommands = List(
+ private lazy val standardCommands = List(
cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
historyCommand,
@@ -333,7 +354,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
)
/** Power user commands */
- lazy val powerCommands: List[LoopCommand] = List(
+ private lazy val powerCommands: List[LoopCommand] = List(
// cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
)
@@ -459,7 +480,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
}
- protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) {
+ private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) {
override def tryClass(path: String): Array[Byte] = {
val hd :: rest = path split '.' toList;
// If there are dots in the name, the first segment is the
@@ -581,7 +602,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
// }
// }
- /** Available commands */
+ /**
+ * Provides a list of available commands.
+ *
+ * @return The list of commands
+ */
+ @DeveloperApi
def commands: List[LoopCommand] = standardCommands /*++ (
if (isReplPower) powerCommands else Nil
)*/
@@ -613,7 +639,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
* command() for each line of input, and stops when
* command() returns false.
*/
- def loop() {
+ private def loop() {
def readOneLine() = {
out.flush()
in readLine prompt
@@ -642,7 +668,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
/** interpret all lines from a specified file */
- def interpretAllFrom(file: File) {
+ private def interpretAllFrom(file: File) {
savingReader {
savingReplayStack {
file applyReader { reader =>
@@ -655,7 +681,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
/** create a new interpreter and replay the given commands */
- def replay() {
+ private def replay() {
reset()
if (replayCommandStack.isEmpty)
echo("Nothing to replay.")
@@ -665,7 +691,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
echo("")
}
}
- def resetCommand() {
+ private def resetCommand() {
echo("Resetting repl state.")
if (replayCommandStack.nonEmpty) {
echo("Forgetting this session history:\n")
@@ -681,13 +707,13 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
reset()
}
- def reset() {
+ private def reset() {
intp.reset()
// unleashAndSetPhase()
}
/** fork a shell and run a command */
- lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
+ private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
override def usage = ""
def apply(line: String): Result = line match {
case "" => showUsage()
@@ -698,14 +724,14 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
}
- def withFile(filename: String)(action: File => Unit) {
+ private def withFile(filename: String)(action: File => Unit) {
val f = File(filename)
if (f.exists) action(f)
else echo("That file does not exist")
}
- def loadCommand(arg: String) = {
+ private def loadCommand(arg: String) = {
var shouldReplay: Option[String] = None
withFile(arg)(f => {
interpretAllFrom(f)
@@ -714,7 +740,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
Result(true, shouldReplay)
}
- def addAllClasspath(args: Seq[String]): Unit = {
+ private def addAllClasspath(args: Seq[String]): Unit = {
var added = false
var totalClasspath = ""
for (arg <- args) {
@@ -729,7 +755,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
}
- def addClasspath(arg: String): Unit = {
+ private def addClasspath(arg: String): Unit = {
val f = File(arg).normalize
if (f.exists) {
addedClasspath = ClassPath.join(addedClasspath, f.path)
@@ -741,12 +767,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
- def powerCmd(): Result = {
+ private def powerCmd(): Result = {
if (isReplPower) "Already in power mode."
else enablePowerMode(false)
}
- def enablePowerMode(isDuringInit: Boolean) = {
+ private[repl] def enablePowerMode(isDuringInit: Boolean) = {
// replProps.power setValue true
// unleashAndSetPhase()
// asyncEcho(isDuringInit, power.banner)
@@ -759,12 +785,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
// }
// }
- def asyncEcho(async: Boolean, msg: => String) {
+ private def asyncEcho(async: Boolean, msg: => String) {
if (async) asyncMessage(msg)
else echo(msg)
}
- def verbosity() = {
+ private def verbosity() = {
// val old = intp.printResults
// intp.printResults = !old
// echo("Switched " + (if (old) "off" else "on") + " result printing.")
@@ -773,7 +799,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
/** Run one command submitted by the user. Two values are returned:
* (1) whether to keep running, (2) the line to record for replay,
* if any. */
- def command(line: String): Result = {
+ private[repl] def command(line: String): Result = {
if (line startsWith ":") {
val cmd = line.tail takeWhile (x => !x.isWhitespace)
uniqueCommand(cmd) match {
@@ -789,7 +815,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
}
- def pasteCommand(): Result = {
+ private def pasteCommand(): Result = {
echo("// Entering paste mode (ctrl-D to finish)\n")
val code = readWhile(_ => true) mkString "\n"
echo("\n// Exiting paste mode, now interpreting.\n")
@@ -820,7 +846,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
* read, go ahead and interpret it. Return the full string
* to be recorded for replay, if any.
*/
- def interpretStartingWith(code: String): Option[String] = {
+ private def interpretStartingWith(code: String): Option[String] = {
// signal completion non-completion input has been received
in.completion.resetVerbosity()
@@ -874,7 +900,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
// runs :load `file` on any files passed via -i
- def loadFiles(settings: Settings) = settings match {
+ private def loadFiles(settings: Settings) = settings match {
case settings: SparkRunnerSettings =>
for (filename <- settings.loadfiles.value) {
val cmd = ":load " + filename
@@ -889,7 +915,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
* unless settings or properties are such that it should start
* with SimpleReader.
*/
- def chooseReader(settings: Settings): InteractiveReader = {
+ private def chooseReader(settings: Settings): InteractiveReader = {
if (settings.Xnojline.value || Properties.isEmacsShell)
SimpleReader()
else try new SparkJLineReader(
@@ -903,8 +929,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
}
- val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
- val m = u.runtimeMirror(Utils.getSparkClassLoader)
+ private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+ private val m = u.runtimeMirror(Utils.getSparkClassLoader)
private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
u.TypeTag[T](
m,
@@ -913,7 +939,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
})
- def process(settings: Settings): Boolean = savingContextLoader {
+ private def process(settings: Settings): Boolean = savingContextLoader {
if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
this.settings = settings
@@ -972,6 +998,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
true
}
+ // NOTE: Must be public for visibility
+ @DeveloperApi
def createSparkContext(): SparkContext = {
val execUri = System.getenv("SPARK_EXECUTOR_URI")
val jars = SparkILoop.getAddedJars
@@ -979,7 +1007,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
.setMaster(getMaster())
.setAppName("Spark shell")
.setJars(jars)
- .set("spark.repl.class.uri", intp.classServer.uri)
+ .set("spark.repl.class.uri", intp.classServerUri)
if (execUri != null) {
conf.set("spark.executor.uri", execUri)
}
@@ -1014,7 +1042,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
@deprecated("Use `process` instead", "2.9.0")
- def main(settings: Settings): Unit = process(settings)
+ private def main(settings: Settings): Unit = process(settings)
}
object SparkILoop {
@@ -1033,7 +1061,7 @@ object SparkILoop {
// Designed primarily for use by test code: take a String with a
// bunch of code, and prints out a transcript of what it would look
// like if you'd just typed it into the repl.
- def runForTranscript(code: String, settings: Settings): String = {
+ private[repl] def runForTranscript(code: String, settings: Settings): String = {
import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
stringFromStream { ostream =>
@@ -1071,7 +1099,7 @@ object SparkILoop {
/** Creates an interpreter loop with default settings and feeds
* the given code to it as input.
*/
- def run(code: String, sets: Settings = new Settings): String = {
+ private[repl] def run(code: String, sets: Settings = new Settings): String = {
import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
stringFromStream { ostream =>
@@ -1087,5 +1115,5 @@ object SparkILoop {
}
}
}
- def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
+ private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index da4286c5e4874..99bd777c04fdb 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -19,7 +19,7 @@ import org.apache.spark.SPARK_VERSION
/**
* Machinery for the asynchronous initialization of the repl.
*/
-trait SparkILoopInit {
+private[repl] trait SparkILoopInit {
self: SparkILoop =>
/** Print a welcome message */
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index b646f0b6f0868..35fb625645022 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -39,6 +39,7 @@ import scala.util.control.ControlThrowable
import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
import org.apache.spark.util.Utils
+import org.apache.spark.annotation.DeveloperApi
// /** directory to save .class files to */
// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) {
@@ -84,17 +85,18 @@ import org.apache.spark.util.Utils
* @author Moez A. Abdel-Gawad
* @author Lex Spoon
*/
+ @DeveloperApi
class SparkIMain(
initialSettings: Settings,
val out: JPrintWriter,
propagateExceptions: Boolean = false)
extends SparkImports with Logging { imain =>
- val conf = new SparkConf()
+ private val conf = new SparkConf()
- val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+ private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
/** Local directory to save .class files too */
- lazy val outputDir = {
+ private lazy val outputDir = {
val tmp = System.getProperty("java.io.tmpdir")
val rootDir = conf.get("spark.repl.classdir", tmp)
Utils.createTempDir(rootDir)
@@ -103,13 +105,20 @@ import org.apache.spark.util.Utils
echo("Output directory: " + outputDir)
}
- val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
+ /**
+ * Returns the path to the output directory containing all generated
+ * class files that will be served by the REPL class server.
+ */
+ @DeveloperApi
+ lazy val getClassOutputDirectory = outputDir
+
+ private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
/** Jetty server that will serve our classes to worker nodes */
- val classServerPort = conf.getInt("spark.replClassServer.port", 0)
- val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
+ private val classServerPort = conf.getInt("spark.replClassServer.port", 0)
+ private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
private var currentSettings: Settings = initialSettings
- var printResults = true // whether to print result lines
- var totalSilence = false // whether to print anything
+ private var printResults = true // whether to print result lines
+ private var totalSilence = false // whether to print anything
private var _initializeComplete = false // compiler is initialized
private var _isInitialized: Future[Boolean] = null // set up initialization future
private var bindExceptions = true // whether to bind the lastException variable
@@ -123,6 +132,14 @@ import org.apache.spark.util.Utils
echo("Class server started, URI = " + classServer.uri)
}
+ /**
+ * URI of the class server used to feed REPL compiled classes.
+ *
+ * @return The string representing the class server uri
+ */
+ @DeveloperApi
+ def classServerUri = classServer.uri
+
/** We're going to go to some trouble to initialize the compiler asynchronously.
* It's critical that nothing call into it until it's been initialized or we will
* run into unrecoverable issues, but the perceived repl startup time goes
@@ -141,17 +158,18 @@ import org.apache.spark.util.Utils
() => { counter += 1 ; counter }
}
- def compilerClasspath: Seq[URL] = (
+ private def compilerClasspath: Seq[URL] = (
if (isInitializeComplete) global.classPath.asURLs
else new PathResolver(settings).result.asURLs // the compiler's classpath
)
- def settings = currentSettings
- def mostRecentLine = prevRequestList match {
+ // NOTE: Exposed to repl package since accessed indirectly from SparkIMain
+ private[repl] def settings = currentSettings
+ private def mostRecentLine = prevRequestList match {
case Nil => ""
case req :: _ => req.originalLine
}
// Run the code body with the given boolean settings flipped to true.
- def withoutWarnings[T](body: => T): T = beQuietDuring {
+ private def withoutWarnings[T](body: => T): T = beQuietDuring {
val saved = settings.nowarn.value
if (!saved)
settings.nowarn.value = true
@@ -164,16 +182,28 @@ import org.apache.spark.util.Utils
def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
def this() = this(new Settings())
- lazy val repllog: Logger = new Logger {
+ private lazy val repllog: Logger = new Logger {
val out: JPrintWriter = imain.out
val isInfo: Boolean = BooleanProp keyExists "scala.repl.info"
val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug"
val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace"
}
- lazy val formatting: Formatting = new Formatting {
+ private[repl] lazy val formatting: Formatting = new Formatting {
val prompt = Properties.shellPromptString
}
- lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
+
+ // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop
+ private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
+
+ /**
+ * Determines if errors were reported (typically during compilation).
+ *
+ * @note This is not for runtime errors
+ *
+ * @return True if errors were reported, otherwise false
+ */
+ @DeveloperApi
+ def isReportingErrors = reporter.hasErrors
import formatting._
import reporter.{ printMessage, withoutTruncating }
@@ -193,7 +223,8 @@ import org.apache.spark.util.Utils
private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
// argument is a thunk to execute after init is done
- def initialize(postInitSignal: => Unit) {
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def initialize(postInitSignal: => Unit) {
synchronized {
if (_isInitialized == null) {
_isInitialized = io.spawn {
@@ -203,15 +234,27 @@ import org.apache.spark.util.Utils
}
}
}
+
+ /**
+ * Initializes the underlying compiler/interpreter in a blocking fashion.
+ *
+ * @note Must be executed before using SparkIMain!
+ */
+ @DeveloperApi
def initializeSynchronous(): Unit = {
if (!isInitializeComplete) {
_initialize()
assert(global != null, global)
}
}
- def isInitializeComplete = _initializeComplete
+ private def isInitializeComplete = _initializeComplete
/** the public, go through the future compiler */
+
+ /**
+ * The underlying compiler used to generate ASTs and execute code.
+ */
+ @DeveloperApi
lazy val global: Global = {
if (isInitializeComplete) _compiler
else {
@@ -226,13 +269,13 @@ import org.apache.spark.util.Utils
}
}
@deprecated("Use `global` for access to the compiler instance.", "2.9.0")
- lazy val compiler: global.type = global
+ private lazy val compiler: global.type = global
import global._
import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember}
import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass}
- implicit class ReplTypeOps(tp: Type) {
+ private implicit class ReplTypeOps(tp: Type) {
def orElse(other: => Type): Type = if (tp ne NoType) tp else other
def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
}
@@ -240,7 +283,8 @@ import org.apache.spark.util.Utils
// TODO: If we try to make naming a lazy val, we run into big time
// scalac unhappiness with what look like cycles. It has not been easy to
// reduce, but name resolution clearly takes different paths.
- object naming extends {
+ // NOTE: Exposed to repl package since used by SparkExprTyper
+ private[repl] object naming extends {
val global: imain.global.type = imain.global
} with Naming {
// make sure we don't overwrite their unwisely named res3 etc.
@@ -254,22 +298,43 @@ import org.apache.spark.util.Utils
}
import naming._
- object deconstruct extends {
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] object deconstruct extends {
val global: imain.global.type = imain.global
} with StructuredTypeStrings
- lazy val memberHandlers = new {
+ // NOTE: Exposed to repl package since used by SparkImports
+ private[repl] lazy val memberHandlers = new {
val intp: imain.type = imain
} with SparkMemberHandlers
import memberHandlers._
- /** Temporarily be quiet */
+ /**
+ * Suppresses printing of results while executing the operation.
+ *
+ * @param body The block to execute
+ * @tparam T The return type of the block
+ *
+ * @return The result from executing the block
+ */
+ @DeveloperApi
def beQuietDuring[T](body: => T): T = {
val saved = printResults
printResults = false
try body
finally printResults = saved
}
+
+ /**
+ * Completely masks all output during the operation (minus JVM standard
+ * out and error).
+ *
+ * @param operation The block to execute
+ * @tparam T The return type of the block
+ *
+ * @return The result from executing the block
+ */
+ @DeveloperApi
def beSilentDuring[T](operation: => T): T = {
val saved = totalSilence
totalSilence = true
@@ -277,10 +342,10 @@ import org.apache.spark.util.Utils
finally totalSilence = saved
}
- def quietRun[T](code: String) = beQuietDuring(interpret(code))
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code))
-
- private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = {
+ private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = {
case t: ControlThrowable => throw t
case t: Throwable =>
logDebug(label + ": " + unwrap(t))
@@ -298,14 +363,44 @@ import org.apache.spark.util.Utils
finally bindExceptions = true
}
+ /**
+ * Contains the code (in string form) representing a wrapper around all
+ * code executed by this instance.
+ *
+ * @return The wrapper code as a string
+ */
+ @DeveloperApi
def executionWrapper = _executionWrapper
+
+ /**
+ * Sets the code to use as a wrapper around all code executed by this
+ * instance.
+ *
+ * @param code The wrapper code as a string
+ */
+ @DeveloperApi
def setExecutionWrapper(code: String) = _executionWrapper = code
+
+ /**
+ * Clears the code used as a wrapper around all code executed by
+ * this instance.
+ */
+ @DeveloperApi
def clearExecutionWrapper() = _executionWrapper = ""
/** interpreter settings */
- lazy val isettings = new SparkISettings(this)
+ private lazy val isettings = new SparkISettings(this)
- /** Instantiate a compiler. Overridable. */
+ /**
+ * Instantiates a new compiler used by SparkIMain. Overridable to provide
+ * your own compiler instance.
+ *
+ * @param settings The settings to provide the compiler
+ * @param reporter The reporter to use for compiler output
+ *
+ * @return The compiler as a Global
+ */
+ @DeveloperApi
protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = {
settings.outputDirs setSingleOutput virtualDirectory
settings.exposeEmptyPackage.value = true
@@ -320,13 +415,14 @@ import org.apache.spark.util.Utils
* @note Currently only supports jars, not directories
* @param urls The list of items to add to the compile and runtime classpaths
*/
+ @DeveloperApi
def addUrlsToClassPath(urls: URL*): Unit = {
new Run // Needed to force initialization of "something" to correctly load Scala classes from jars
urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution
updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling
}
- protected def updateCompilerClassPath(urls: URL*): Unit = {
+ private def updateCompilerClassPath(urls: URL*): Unit = {
require(!global.forMSIL) // Only support JavaPlatform
val platform = global.platform.asInstanceOf[JavaPlatform]
@@ -342,7 +438,7 @@ import org.apache.spark.util.Utils
global.invalidateClassPathEntries(urls.map(_.getPath): _*)
}
- protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = {
+ private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = {
// Collect our new jars/directories and add them to the existing set of classpaths
val allClassPaths = (
platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++
@@ -365,7 +461,13 @@ import org.apache.spark.util.Utils
new MergedClassPath(allClassPaths, platform.classPath.context)
}
- /** Parent classloader. Overridable. */
+ /**
+ * Represents the parent classloader used by this instance. Can be
+ * overridden to provide an alternative classloader.
+ *
+ * @return The classloader used as the parent loader of this instance
+ */
+ @DeveloperApi
protected def parentClassLoader: ClassLoader =
SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
@@ -382,16 +484,18 @@ import org.apache.spark.util.Utils
shadow the old ones, and old code objects refer to the old
definitions.
*/
- def resetClassLoader() = {
+ private def resetClassLoader() = {
logDebug("Setting new classloader: was " + _classLoader)
_classLoader = null
ensureClassLoader()
}
- final def ensureClassLoader() {
+ private final def ensureClassLoader() {
if (_classLoader == null)
_classLoader = makeClassLoader()
}
- def classLoader: AbstractFileClassLoader = {
+
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def classLoader: AbstractFileClassLoader = {
ensureClassLoader()
_classLoader
}
@@ -418,27 +522,58 @@ import org.apache.spark.util.Utils
_runtimeClassLoader
})
- def getInterpreterClassLoader() = classLoader
+ private def getInterpreterClassLoader() = classLoader
// Set the current Java "context" class loader to this interpreter's class loader
- def setContextClassLoader() = classLoader.setAsContext()
+ // NOTE: Exposed to repl package since used by SparkILoopInit
+ private[repl] def setContextClassLoader() = classLoader.setAsContext()
- /** Given a simple repl-defined name, returns the real name of
- * the class representing it, e.g. for "Bippy" it may return
- * {{{
- * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
- * }}}
+ /**
+ * Returns the real name of a class based on its repl-defined name.
+ *
+ * ==Example==
+ * Given a simple repl-defined name, returns the real name of
+ * the class representing it, e.g. for "Bippy" it may return
+ * {{{
+ * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
+ * }}}
+ *
+ * @param simpleName The repl-defined name whose real name to retrieve
+ *
+ * @return Some real name if the simple name exists, else None
*/
+ @DeveloperApi
def generatedName(simpleName: String): Option[String] = {
if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING)
else optFlatName(simpleName)
}
- def flatName(id: String) = optFlatName(id) getOrElse id
- def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def flatName(id: String) = optFlatName(id) getOrElse id
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
+
+ /**
+ * Retrieves all simple names contained in the current instance.
+ *
+ * @return A list of sorted names
+ */
+ @DeveloperApi
def allDefinedNames = definedNameMap.keys.toList.sorted
- def pathToType(id: String): String = pathToName(newTypeName(id))
- def pathToTerm(id: String): String = pathToName(newTermName(id))
+
+ private def pathToType(id: String): String = pathToName(newTypeName(id))
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id))
+
+ /**
+ * Retrieves the full code path to access the specified simple name
+ * content.
+ *
+ * @param name The simple name of the target whose path to determine
+ *
+ * @return The full path used to access the specified target (name)
+ */
+ @DeveloperApi
def pathToName(name: Name): String = {
if (definedNameMap contains name)
definedNameMap(name) fullPath name
@@ -457,13 +592,13 @@ import org.apache.spark.util.Utils
}
/** Stubs for work in progress. */
- def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
+ private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) {
logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2))
}
}
- def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
+ private def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) {
// Printing the types here has a tendency to cause assertion errors, like
// assertion failed: fatal: has owner value x, but a class owner is required
@@ -473,7 +608,7 @@ import org.apache.spark.util.Utils
}
}
- def recordRequest(req: Request) {
+ private def recordRequest(req: Request) {
if (req == null || referencedNameMap == null)
return
@@ -504,12 +639,12 @@ import org.apache.spark.util.Utils
}
}
- def replwarn(msg: => String) {
+ private def replwarn(msg: => String) {
if (!settings.nowarnings.value)
printMessage(msg)
}
- def isParseable(line: String): Boolean = {
+ private def isParseable(line: String): Boolean = {
beSilentDuring {
try parse(line) match {
case Some(xs) => xs.nonEmpty // parses as-is
@@ -522,22 +657,32 @@ import org.apache.spark.util.Utils
}
}
- def compileSourcesKeepingRun(sources: SourceFile*) = {
+ private def compileSourcesKeepingRun(sources: SourceFile*) = {
val run = new Run()
reporter.reset()
run compileSources sources.toList
(!reporter.hasErrors, run)
}
- /** Compile an nsc SourceFile. Returns true if there are
- * no compilation errors, or false otherwise.
+ /**
+ * Compiles specified source files.
+ *
+ * @param sources The sequence of source files to compile
+ *
+ * @return True if successful, otherwise false
*/
+ @DeveloperApi
def compileSources(sources: SourceFile*): Boolean =
compileSourcesKeepingRun(sources: _*)._1
- /** Compile a string. Returns true if there are no
- * compilation errors, or false otherwise.
+ /**
+ * Compiles a string of code.
+ *
+ * @param code The string of code to compile
+ *
+ * @return True if successful, otherwise false
*/
+ @DeveloperApi
def compileString(code: String): Boolean =
compileSources(new BatchSourceFile("