diff --git a/assembly/pom.xml b/assembly/pom.xml
index 5ec9da22ae83f..31a01e4d8e1de 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -349,5 +349,15 @@
+    <profile>
+      <id>kinesis-asl</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.httpcomponents</groupId>
+          <artifactId>httpclient</artifactId>
+          <version>${commons.httpclient.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 30b4baa4d714a..789869f72e3b0 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 5ad52452a5c98..3cd0579aea8d3 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -36,7 +36,13 @@ rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%;%FWDIR%conf
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
if exist "%FWDIR%RELEASE" (
for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 0f63e36d8aeca..905bbaf99b374 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -27,8 +27,14 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
. "$FWDIR"/bin/load-spark-env.sh
+CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+
# Build up classpath
-CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf"
+if [ -n "$SPARK_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR"
+else
+ CLASSPATH="$CLASSPATH:$FWDIR/conf"
+fi
ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
diff --git a/bin/pyspark b/bin/pyspark
index 5142411e36974..96f30a260a09e 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -50,9 +50,44 @@ fi
. "$FWDIR"/bin/load-spark-env.sh
-# Figure out which Python executable to use
+# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
+# executable, while the worker would still be launched using PYSPARK_PYTHON.
+#
+# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added
+# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver.
+# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set
+# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
+# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
+# and executor Python executables.
+#
+# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables.
+
+# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set:
+if hash python2.7 2>/dev/null; then
+ # Attempt to use Python 2.7, if installed:
+ DEFAULT_PYTHON="python2.7"
+else
+ DEFAULT_PYTHON="python"
+fi
+
+# Determine the Python executable to use for the driver:
+if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then
+ # If IPython options are specified, assume user wants to run IPython
+ # (for backwards-compatibility)
+ PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS"
+ PYSPARK_DRIVER_PYTHON="ipython"
+elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then
+ PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}"
+fi
+
+# Determine the Python executable to use for the executors:
if [[ -z "$PYSPARK_PYTHON" ]]; then
- PYSPARK_PYTHON="python"
+ if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then
+ echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2
+ exit 1
+ else
+ PYSPARK_PYTHON="$DEFAULT_PYTHON"
+ fi
fi
export PYSPARK_PYTHON
@@ -64,11 +99,6 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py"
-# If IPython options are specified, assume user wants to run IPython
-if [[ -n "$IPYTHON_OPTS" ]]; then
- IPYTHON=1
-fi
-
# Build up arguments list manually to preserve quotes and backslashes.
# We export Spark submit arguments as an environment variable because shell.py must run as a
# PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks.
@@ -88,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_PYTHON" -m doctest $1
+ exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
else
- exec "$PYSPARK_PYTHON" $1
+ exec "$PYSPARK_DRIVER_PYTHON" $1
fi
exit
fi
@@ -106,10 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then
else
# PySpark shell requires special handling downstream
export PYSPARK_SHELL=1
- # Only use ipython if no command line arguments were provided [SPARK-1134]
- if [[ "$IPYTHON" = "1" ]]; then
- exec ${PYSPARK_PYTHON:-ipython} $IPYTHON_OPTS
- else
- exec "$PYSPARK_PYTHON"
- fi
+ exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 2c4b08af8d4c3..a0e66abcc26c9 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -33,7 +33,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*
)
if [%FOUND_JAR%] == [0] (
echo Failed to find Spark assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
:skip_build_test
diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd
index b29bf90c64e90..b49d0dcb4ff2d 100644
--- a/bin/run-example2.cmd
+++ b/bin/run-example2.cmd
@@ -52,7 +52,7 @@ if exist "%FWDIR%RELEASE" (
)
if "x%SPARK_EXAMPLES_JAR%"=="x" (
echo Failed to find Spark examples assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
diff --git a/bin/spark-class b/bin/spark-class
index 613dc9c4566f2..91d858bc063d0 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -105,7 +105,7 @@ else
exit 1
fi
fi
-JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
+JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
# Set JAVA_OPTS to be able to load native libraries and to set heap size
if [ "$JAVA_VERSION" -ge 18 ]; then
@@ -146,7 +146,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2
- echo "You need to build spark before running $1." 1>&2
+ echo "You need to build Spark before running $1." 1>&2
exit 1
fi
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 6c5672819172b..da46543647efd 100644
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -104,7 +104,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*
)
if "%FOUND_JAR%"=="0" (
echo Failed to find Spark assembly JAR.
- echo You need to build Spark with sbt\sbt assembly before running this program.
+ echo You need to build Spark before running this program.
goto exit
)
:skip_build_test
diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd
index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755
--- a/bin/spark-shell.cmd
+++ b/bin/spark-shell.cmd
@@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SPARK_HOME=%~dp0..
+rem This is the entry point for running Spark shell. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
+cmd /V /E /C %~dp0spark-shell2.cmd %*
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
new file mode 100644
index 0000000000000..2ee60b4e2a2b3
--- /dev/null
+++ b/bin/spark-shell2.cmd
@@ -0,0 +1,22 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+set SPARK_HOME=%~dp0..
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
diff --git a/bin/spark-sql b/bin/spark-sql
index 9d66140b6aa17..63d00437d508d 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -24,7 +24,6 @@
set -o posix
CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
-CLASS_NOT_FOUND_EXIT_STATUS=101
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
@@ -53,13 +52,4 @@ source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
-"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
-exit_status=$?
-
-if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then
- echo
- echo "Failed to load Spark SQL CLI main class $CLASS."
- echo "You need to build Spark with -Phive."
-fi
-
-exit $exit_status
+exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd
index cf6046d1547ad..8f3b84c7b971d 100644
--- a/bin/spark-submit.cmd
+++ b/bin/spark-submit.cmd
@@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala!
+rem This is the entry point for running Spark submit. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
-set SPARK_HOME=%~dp0..
-set ORIG_ARGS=%*
-
-rem Reset the values of all variables used
-set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
-set SPARK_SUBMIT_DRIVER_MEMORY=
-set SPARK_SUBMIT_LIBRARY_PATH=
-set SPARK_SUBMIT_CLASSPATH=
-set SPARK_SUBMIT_OPTS=
-set SPARK_SUBMIT_BOOTSTRAP_DRIVER=
-
-:loop
-if [%1] == [] goto continue
- if [%1] == [--deploy-mode] (
- set SPARK_SUBMIT_DEPLOY_MODE=%2
- ) else if [%1] == [--properties-file] (
- set SPARK_SUBMIT_PROPERTIES_FILE=%2
- ) else if [%1] == [--driver-memory] (
- set SPARK_SUBMIT_DRIVER_MEMORY=%2
- ) else if [%1] == [--driver-library-path] (
- set SPARK_SUBMIT_LIBRARY_PATH=%2
- ) else if [%1] == [--driver-class-path] (
- set SPARK_SUBMIT_CLASSPATH=%2
- ) else if [%1] == [--driver-java-options] (
- set SPARK_SUBMIT_OPTS=%2
- )
- shift
-goto loop
-:continue
-
-rem For client mode, the driver will be launched in the same JVM that launches
-rem SparkSubmit, so we may need to read the properties file for any extra class
-rem paths, library paths, java options and memory early on. Otherwise, it will
-rem be too late by the time the driver JVM has started.
-
-if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] (
- if exist %SPARK_SUBMIT_PROPERTIES_FILE% (
- rem Parse the properties file only if the special configs exist
- for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^
- %SPARK_SUBMIT_PROPERTIES_FILE%') do (
- set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1
- )
- )
-)
-
-cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS%
+cmd /V /E /C %~dp0spark-submit2.cmd %*
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
new file mode 100644
index 0000000000000..cf6046d1547ad
--- /dev/null
+++ b/bin/spark-submit2.cmd
@@ -0,0 +1,68 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala!
+
+set SPARK_HOME=%~dp0..
+set ORIG_ARGS=%*
+
+rem Reset the values of all variables used
+set SPARK_SUBMIT_DEPLOY_MODE=client
+set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+set SPARK_SUBMIT_DRIVER_MEMORY=
+set SPARK_SUBMIT_LIBRARY_PATH=
+set SPARK_SUBMIT_CLASSPATH=
+set SPARK_SUBMIT_OPTS=
+set SPARK_SUBMIT_BOOTSTRAP_DRIVER=
+
+:loop
+if [%1] == [] goto continue
+ if [%1] == [--deploy-mode] (
+ set SPARK_SUBMIT_DEPLOY_MODE=%2
+ ) else if [%1] == [--properties-file] (
+ set SPARK_SUBMIT_PROPERTIES_FILE=%2
+ ) else if [%1] == [--driver-memory] (
+ set SPARK_SUBMIT_DRIVER_MEMORY=%2
+ ) else if [%1] == [--driver-library-path] (
+ set SPARK_SUBMIT_LIBRARY_PATH=%2
+ ) else if [%1] == [--driver-class-path] (
+ set SPARK_SUBMIT_CLASSPATH=%2
+ ) else if [%1] == [--driver-java-options] (
+ set SPARK_SUBMIT_OPTS=%2
+ )
+ shift
+goto loop
+:continue
+
+rem For client mode, the driver will be launched in the same JVM that launches
+rem SparkSubmit, so we may need to read the properties file for any extra class
+rem paths, library paths, java options and memory early on. Otherwise, it will
+rem be too late by the time the driver JVM has started.
+
+if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] (
+ if exist %SPARK_SUBMIT_PROPERTIES_FILE% (
+ rem Parse the properties file only if the special configs exist
+ for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^
+ %SPARK_SUBMIT_PROPERTIES_FILE%') do (
+ set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1
+ )
+ )
+)
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS%
diff --git a/bin/utils.sh b/bin/utils.sh
index 0804b1ed9f231..22ea2b9a6d586 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -17,7 +17,7 @@
# limitations under the License.
#
-# Gather all all spark-submit options into SUBMISSION_OPTS
+# Gather all spark-submit options into SUBMISSION_OPTS
function gatherSparkSubmitOpts() {
if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then
diff --git a/core/pom.xml b/core/pom.xml
index e012c5e673b74..a5a178079bc57 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -322,6 +322,17 @@
+      <plugin>
+        <artifactId>maven-clean-plugin</artifactId>
+        <configuration>
+          <filesets>
+            <fileset>
+              <directory>${basedir}/../python/build</directory>
+            </fileset>
+          </filesets>
+          <verbose>true</verbose>
+        </configuration>
+      </plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-shade-plugin</artifactId>
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 4e6d708af0ea7..2d998d4c7a5d9 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -18,131 +18,55 @@
package org.apache.spark;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import scala.Function0;
import scala.Function1;
import scala.Unit;
-import scala.collection.JavaConversions;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.util.TaskCompletionListener;
-import org.apache.spark.util.TaskCompletionListenerException;
/**
-* :: DeveloperApi ::
-* Contextual information about a task which can be read or mutated during execution.
-*/
-@DeveloperApi
-public class TaskContext implements Serializable {
-
- private int stageId;
- private int partitionId;
- private long attemptId;
- private boolean runningLocally;
- private TaskMetrics taskMetrics;
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
- TaskMetrics taskMetrics) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = taskMetrics;
- }
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
- }
-
+ * Contextual information about a task which can be read or mutated during
+ * execution. To access the TaskContext for a running task use
+ * TaskContext.get().
+ */
+public abstract class TaskContext implements Serializable {
/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
+ * Return the currently active TaskContext. This can be called inside of
+ * user functions to access contextual information about running tasks.
*/
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = false;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
+ public static TaskContext get() {
+ return taskContext.get();
}
private static ThreadLocal<TaskContext> taskContext =
new ThreadLocal<TaskContext>();
- /**
- * :: Internal API ::
- * This is spark internal API, not intended to be called from user programs.
- */
- public static void setTaskContext(TaskContext tc) {
+ static void setTaskContext(TaskContext tc) {
taskContext.set(tc);
}
- public static TaskContext get() {
- return taskContext.get();
- }
-
- /** :: Internal API :: */
- public static void unset() {
+ static void unset() {
taskContext.remove();
}
- // List of callback functions to execute when the task completes.
- private transient List<TaskCompletionListener> onCompleteCallbacks =
- new ArrayList<TaskCompletionListener>();
-
- // Whether the corresponding task has been killed.
- private volatile boolean interrupted = false;
-
- // Whether the task has completed.
- private volatile boolean completed = false;
-
/**
- * Checks whether the task has completed.
+ * Whether the task has completed.
*/
- public boolean isCompleted() {
- return completed;
- }
+ public abstract boolean isCompleted();
/**
- * Checks whether the task has been killed.
+ * Whether the task has been killed.
*/
- public boolean isInterrupted() {
- return interrupted;
- }
+ public abstract boolean isInterrupted();
+
+ /** @deprecated: use isRunningLocally() */
+ @Deprecated
+ public abstract boolean runningLocally();
+
+ public abstract boolean isRunningLocally();
/**
* Add a (Java friendly) listener to be executed on task completion.
@@ -150,10 +74,7 @@ public boolean isInterrupted() {
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
- onCompleteCallbacks.add(listener);
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
@@ -161,109 +82,27 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply(context);
- }
- });
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f);
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
- * Deprecated: use addTaskCompletionListener
- *
+ * @deprecated: use addTaskCompletionListener
+ *
* @param f Callback function.
*/
@Deprecated
- public void addOnCompleteCallback(final Function0<Unit> f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply();
- }
- });
- }
-
- /**
- * ::Internal API::
- * Marks the task as completed and triggers the listeners.
- */
- public void markTaskCompleted() throws TaskCompletionListenerException {
- completed = true;
- List<String> errorMsgs = new ArrayList<String>(2);
- // Process complete callbacks in the reverse order of registration
- List<TaskCompletionListener> revlist =
- new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
- Collections.reverse(revlist);
- for (TaskCompletionListener tcl: revlist) {
- try {
- tcl.onTaskCompletion(this);
- } catch (Throwable e) {
- errorMsgs.add(e.getMessage());
- }
- }
-
- if (!errorMsgs.isEmpty()) {
- throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
- }
- }
-
- /**
- * ::Internal API::
- * Marks the task for interruption, i.e. cancellation.
- */
- public void markInterrupted() {
- interrupted = true;
- }
-
- @Deprecated
- /** Deprecated: use getStageId() */
- public int stageId() {
- return stageId;
- }
-
- @Deprecated
- /** Deprecated: use getPartitionId() */
- public int partitionId() {
- return partitionId;
- }
-
- @Deprecated
- /** Deprecated: use getAttemptId() */
- public long attemptId() {
- return attemptId;
- }
-
- @Deprecated
- /** Deprecated: use isRunningLocally() */
- public boolean runningLocally() {
- return runningLocally;
- }
-
- public boolean isRunningLocally() {
- return runningLocally;
- }
+ public abstract void addOnCompleteCallback(final Function0<Unit> f);
- public int getStageId() {
- return stageId;
- }
+ public abstract int stageId();
- public int getPartitionId() {
- return partitionId;
- }
+ public abstract int partitionId();
- public long getAttemptId() {
- return attemptId;
- }
+ public abstract long attemptId();
- /** ::Internal API:: */
- public TaskMetrics taskMetrics() {
- return taskMetrics;
- }
+ /** ::DeveloperApi:: */
+ @DeveloperApi
+ public abstract TaskMetrics taskMetrics();
}
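
Editor's note: the new abstract TaskContext above is reached through the static TaskContext.get() accessor instead of being constructed by user code. A minimal Scala sketch of how a job might use it (the object name, app name and local[2] master are illustrative assumptions, not part of the patch):

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextGetSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("task-context-sketch").setMaster("local[2]"))
    val tagged = sc.parallelize(1 to 100, 4).map { x =>
      // Look up the context of the currently running task instead of having it passed in.
      val ctx = TaskContext.get()
      // Completion listeners fire whether the task succeeds, fails, or is cancelled.
      ctx.addTaskCompletionListener((_: TaskContext) => println(s"partition ${ctx.partitionId()} finished"))
      (ctx.partitionId(), x)
    }.collect()
    println(tagged.take(5).toSeq)
    sc.stop()
  }
}
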
diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
new file mode 100644
index 0000000000000..0ad189633e427
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java;
+
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public interface JavaFutureAction<T> extends Future<T> {
+
+ /**
+ * Returns the job IDs run by the underlying async operation.
+ *
+ * This returns the current snapshot of the job list. Certain operations may run multiple
+ * jobs, so multiple calls to this method may return different lists.
+ */
+ List<Integer> jobIds();
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 445110d63e184..152bde5f6994f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -51,6 +51,11 @@ table.sortable thead {
cursor: pointer;
}
+table.sortable td {
+ word-wrap: break-word;
+ max-width: 600px;
+}
+
.progress {
margin-bottom: 0px; position: relative
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index f8584b90cabe6..d89bb50076c9a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
arr.iterator.asInstanceOf[Iterator[T]]
case Right(it) =>
// There is not enough space to cache this partition in memory
- logWarning(s"Not enough space to cache partition $key in memory! " +
- s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.")
val returnValues = it.asInstanceOf[Iterator[T]]
if (putLevel.useDisk) {
logWarning(s"Persisting partition $key to disk instead.")
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 75ea535f2f57b..d5c8f9d76c476 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -17,20 +17,21 @@
package org.apache.spark
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.Try
+import java.util.Collections
+import java.util.concurrent.TimeUnit
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaFutureAction
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.{Failure, Try}
+
/**
- * :: Experimental ::
* A future for the result of an action to support cancellation. This is an extension of the
* Scala Future interface to support cancellation.
*/
-@Experimental
trait FutureAction[T] extends Future[T] {
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
// documentation (with reference to the word "action").
@@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] {
*/
override def isCompleted: Boolean
+ /**
+ * Returns whether the action has been cancelled.
+ */
+ def isCancelled: Boolean
+
/**
* The value of this Future.
*
@@ -83,19 +89,29 @@ trait FutureAction[T] extends Future[T] {
*/
@throws(classOf[Exception])
def get(): T = Await.result(this, Duration.Inf)
+
+ /**
+ * Returns the job IDs run by the underlying async operation.
+ *
+ * This returns the current snapshot of the job list. Certain operations may run multiple
+ * jobs, so multiple calls to this method may return different lists.
+ */
+ def jobIds: Seq[Int]
+
}
/**
- * :: Experimental ::
* A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
-@Experimental
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
+ @volatile private var _cancelled: Boolean = false
+
override def cancel() {
+ _cancelled = true
jobWaiter.cancel()
}
@@ -134,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def isCancelled: Boolean = _cancelled
override def value: Option[Try[T]] = {
if (jobWaiter.jobFinished) {
@@ -150,18 +168,15 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
}
- /** Get the corresponding job id for this action. */
- def jobId = jobWaiter.jobId
+ def jobIds = Seq(jobWaiter.jobId)
}
/**
- * :: Experimental ::
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
*/
-@Experimental
class ComplexFutureAction[T] extends FutureAction[T] {
// Pointer to the thread that is executing the action. It is set when the action is run.
@@ -171,6 +186,8 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// is cancelled before the action was even run (and thus we have no thread to interrupt).
@volatile private var _cancelled: Boolean = false
+ @volatile private var jobs: Seq[Int] = Nil
+
// A promise used to signal the future.
private val p = promise[T]()
@@ -212,13 +229,15 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
- if (!cancelled) {
+ if (!isCancelled) {
rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
} else {
throw new SparkException("Action has been cancelled")
}
}
+ this.jobs = jobs ++ job.jobIds
+
// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
@@ -231,10 +250,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
}
}
- /**
- * Returns whether the promise has been cancelled.
- */
- def cancelled: Boolean = _cancelled
+ override def isCancelled: Boolean = _cancelled
@throws(classOf[InterruptedException])
@throws(classOf[scala.concurrent.TimeoutException])
@@ -255,4 +271,59 @@ class ComplexFutureAction[T] extends FutureAction[T] {
override def isCompleted: Boolean = p.isCompleted
override def value: Option[Try[T]] = p.future.value
+
+ def jobIds = jobs
+
+}
+
+private[spark]
+class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
+ extends JavaFutureAction[T] {
+
+ import scala.collection.JavaConverters._
+
+ override def isCancelled: Boolean = futureAction.isCancelled
+
+ override def isDone: Boolean = {
+ // According to java.util.Future's Javadoc, this returns True if the task was completed,
+ // whether that completion was due to successful execution, an exception, or a cancellation.
+ futureAction.isCancelled || futureAction.isCompleted
+ }
+
+ override def jobIds(): java.util.List[java.lang.Integer] = {
+ Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava)
+ }
+
+ private def getImpl(timeout: Duration): T = {
+ // This will throw TimeoutException on timeout:
+ Await.ready(futureAction, timeout)
+ futureAction.value.get match {
+ case scala.util.Success(value) => converter(value)
+ case Failure(exception) =>
+ if (isCancelled) {
+ throw new CancellationException("Job cancelled").initCause(exception)
+ } else {
+ // java.util.Future.get() wraps exceptions in ExecutionException
+ throw new ExecutionException("Exception thrown by job", exception)
+ }
+ }
+ }
+
+ override def get(): T = getImpl(Duration.Inf)
+
+ override def get(timeout: Long, unit: TimeUnit): T =
+ getImpl(Duration.fromNanos(unit.toNanos(timeout)))
+
+ override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
+ if (isDone) {
+ // According to java.util.Future's Javadoc, this should return false if the task is completed.
+ false
+ } else {
+ // We're limited in terms of the semantics we can provide here; our cancellation is
+ // asynchronous and doesn't provide a mechanism to not cancel if the job is running.
+ futureAction.cancel()
+ true
+ }
+ }
+
}
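
Editor's note: FutureAction now exposes isCancelled and a jobIds snapshot alongside the usual Future methods. A hedged driver-side sketch (names and the local[2] master are assumptions; countAsync comes from the existing AsyncRDDActions implicits):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // rddToAsyncRDDActions provides countAsync()

object FutureActionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("future-action-sketch").setMaster("local[2]"))
    val future = sc.parallelize(1 to 1000000, 8).countAsync()  // FutureAction[Long]
    // jobIds is a snapshot; actions that submit several jobs may return a growing list.
    println(s"jobs so far: ${future.jobIds}")
    println(s"count = ${future.get()}")           // blocks, like java.util.concurrent.Future#get
    println(s"cancelled? ${future.isCancelled}")  // false here; cancel() would flip it
    sc.stop()
  }
}
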
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 3832a780ec4bc..0e0f1a7b2377e 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * If its acting as a client and trying to send a message to another ConnectionManager,
- * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake.
+ * and waits for the response from the server and does the handshake before sending
+ * the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 979d178c35969..dd3157990ef2d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,6 +21,7 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
+import java.util.Arrays
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
@@ -187,6 +188,15 @@ class SparkContext(config: SparkConf) extends Logging {
val master = conf.get("spark.master")
val appName = conf.get("spark.app.name")
+ private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false)
+ private[spark] val eventLogDir: Option[String] = {
+ if (isEventLogEnabled) {
+ Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/"))
+ } else {
+ None
+ }
+ }
+
// Generate the random name for a temp folder in Tachyon
// Add a timestamp as the suffix here to make it more safe
val tachyonFolderName = "spark-" + randomUUID.toString()
@@ -200,6 +210,7 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] val listenerBus = new LiveListenerBus
// Create the Spark execution environment (cache, map output tracker, etc)
+ conf.set("spark.executor.id", "driver")
private[spark] val env = SparkEnv.create(
conf,
"",
@@ -227,24 +238,10 @@ class SparkContext(config: SparkConf) extends Logging {
// For tests, do not enable the UI
None
}
- ui.foreach(_.bind())
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
- // Optionally log Spark events
- private[spark] val eventLogger: Option[EventLoggingListener] = {
- if (conf.getBoolean("spark.eventLog.enabled", false)) {
- val logger = new EventLoggingListener(appName, conf, hadoopConfiguration)
- logger.start()
- listenerBus.addListener(logger)
- Some(logger)
- } else None
- }
-
- // At this point, all relevant SparkListeners have been registered, so begin releasing events
- listenerBus.start()
-
val startTime = System.currentTimeMillis()
// Add each JAR given through the constructor
@@ -309,6 +306,29 @@ class SparkContext(config: SparkConf) extends Logging {
// constructor
taskScheduler.start()
+ val applicationId: String = taskScheduler.applicationId()
+ conf.set("spark.app.id", applicationId)
+
+ val metricsSystem = env.metricsSystem
+
+ // The metrics system for Driver need to be set spark.app.id to app ID.
+ // So it should start after we get app ID from the task scheduler and set spark.app.id.
+ metricsSystem.start()
+
+ // Optionally log Spark events
+ private[spark] val eventLogger: Option[EventLoggingListener] = {
+ if (isEventLogEnabled) {
+ val logger =
+ new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration)
+ logger.start()
+ listenerBus.addListener(logger)
+ Some(logger)
+ } else None
+ }
+
+ // At this point, all relevant SparkListeners have been registered, so begin releasing events
+ listenerBus.start()
+
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
@@ -321,6 +341,10 @@ class SparkContext(config: SparkConf) extends Logging {
postEnvironmentUpdate()
postApplicationStart()
+ // Bind the SparkUI after starting the task scheduler
+ // because certain pages and listeners depend on it
+ ui.foreach(_.bind())
+
private[spark] var checkpointDir: Option[String] = None
// Thread Local variable that can be used by users to pass information down the stack
@@ -411,8 +435,8 @@ class SparkContext(config: SparkConf) extends Logging {
// Post init
taskScheduler.postStartHook()
- private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager)
private def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -759,20 +783,20 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
* with `+=`. Only the driver can access the accumuable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param)
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
* Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can
* access the accumuable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param, Some(name))
/**
@@ -794,6 +818,8 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+ val callSite = getCallSite
+ logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
@@ -1278,7 +1304,7 @@ class SparkContext(config: SparkConf) extends Logging {
private def postApplicationStart() {
// Note: this code assumes that the task scheduler has been initialized and has contacted
// the cluster manager to get an application ID (in case the cluster manager provides one).
- listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(),
+ listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId),
startTime, sparkUser))
}
@@ -1409,7 +1435,10 @@ object SparkContext extends Logging {
simpleWritableConverter[Boolean, BooleanWritable](_.get)
implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
- simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
+ simpleWritableConverter[Array[Byte], BytesWritable](bw =>
+ // getBytes method returns array which is longer then data to be returned
+ Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
+ )
}
implicit def stringWritableConverter(): WritableConverter[String] =
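
Editor's note: the accumulable type parameters are swapped so that R is the accumulator's result type and T is the element type added from tasks, matching Accumulable[R, T] itself. A small sketch under that corrected signature (SetParam and the other names are illustrative assumptions):

import org.apache.spark.{AccumulableParam, SparkConf, SparkContext}

// R = Set[String] is what the driver reads back, T = String is what tasks add.
object SetParam extends AccumulableParam[Set[String], String] {
  def addAccumulator(r: Set[String], t: String): Set[String] = r + t
  def addInPlace(r1: Set[String], r2: Set[String]): Set[String] = r1 ++ r2
  def zero(initialValue: Set[String]): Set[String] = Set.empty[String]
}

object AccumulableSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("accumulable-sketch").setMaster("local[2]"))
    val seen = sc.accumulable(Set.empty[String])(SetParam)  // accumulable[R, T] after the swap
    sc.parallelize(Seq("a", "b", "a"), 2).foreach(word => seen += word)
    println(seen.value)  // only the driver may read the accumulable's value
    sc.stop()
  }
}
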
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 009ed64775844..aba713cb4267a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils}
* :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
- * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
- * objects needs to have the right SparkEnv set. You can get the current environment with
- * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ * Spark code finds the SparkEnv through a global variable, so all the threads can access the same
+ * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
*
* NOTE: This is not intended for external use. This is exposed for Shark and may be made private
* in a future release.
@@ -119,30 +118,28 @@ class SparkEnv (
}
object SparkEnv extends Logging {
- private val env = new ThreadLocal[SparkEnv]
- @volatile private var lastSetSparkEnv : SparkEnv = _
+ @volatile private var env: SparkEnv = _
private[spark] val driverActorSystemName = "sparkDriver"
private[spark] val executorActorSystemName = "sparkExecutor"
def set(e: SparkEnv) {
- lastSetSparkEnv = e
- env.set(e)
+ env = e
}
/**
- * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
- * previously set in any thread.
+ * Returns the SparkEnv.
*/
def get: SparkEnv = {
- Option(env.get()).getOrElse(lastSetSparkEnv)
+ env
}
/**
* Returns the ThreadLocal SparkEnv.
*/
+ @deprecated("Use SparkEnv.get instead", "1.2")
def getThreadLocal: SparkEnv = {
- env.get()
+ env
}
private[spark] def create(
@@ -259,11 +256,15 @@ object SparkEnv extends Logging {
}
val metricsSystem = if (isDriver) {
+ // Don't start metrics system right now for Driver.
+ // We need to wait for the task scheduler to give us an app ID.
+ // Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf, securityManager)
+ val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
+ ms.start()
+ ms
}
- metricsSystem.start()
// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
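
Editor's note: with SparkEnv stored in a single @volatile field rather than a ThreadLocal, every thread sees the same environment without a prior SparkEnv.set call. A minimal sketch (names are illustrative):

import org.apache.spark.{SparkConf, SparkContext, SparkEnv}

object SparkEnvSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sparkenv-sketch").setMaster("local[2]"))
    val envOnMain = SparkEnv.get
    val t = new Thread(new Runnable {
      // With the ThreadLocal, a fresh thread fell back to whichever env was set last anywhere;
      // with the global variable it simply reads the one shared SparkEnv.
      override def run(): Unit = assert(SparkEnv.get eq envOnMain)
    })
    t.start()
    t.join()
    sc.stop()
  }
}
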
diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
new file mode 100644
index 0000000000000..4636c4600a01a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * This class exists to restrict the visibility of TaskContext setters.
+ */
+private [spark] object TaskContextHelper {
+
+ def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)
+
+ def unset(): Unit = TaskContext.unset()
+
+}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
new file mode 100644
index 0000000000000..afd2b85d33a77
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[spark] class TaskContextImpl(val stageId: Int,
+ val partitionId: Int,
+ val attemptId: Long,
+ val runningLocally: Boolean = false,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty)
+ extends TaskContext
+ with Logging {
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
+
+ // Whether the corresponding task has been killed.
+ @volatile private var interrupted: Boolean = false
+
+ // Whether the task has completed.
+ @volatile private var completed: Boolean = false
+
+ override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+ onCompleteCallbacks += listener
+ this
+ }
+
+ override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ }
+ this
+ }
+
+ @deprecated("use addTaskCompletionListener", "1.1.0")
+ override def addOnCompleteCallback(f: () => Unit) {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f()
+ }
+ }
+
+ /** Marks the task as completed and triggers the listeners. */
+ private[spark] def markTaskCompleted(): Unit = {
+ completed = true
+ val errorMsgs = new ArrayBuffer[String](2)
+ // Process complete callbacks in the reverse order of registration
+ onCompleteCallbacks.reverse.foreach { listener =>
+ try {
+ listener.onTaskCompletion(this)
+ } catch {
+ case e: Throwable =>
+ errorMsgs += e.getMessage
+ logError("Error in TaskCompletionListener", e)
+ }
+ }
+ if (errorMsgs.nonEmpty) {
+ throw new TaskCompletionListenerException(errorMsgs)
+ }
+ }
+
+ /** Marks the task for interruption, i.e. cancellation. */
+ private[spark] def markInterrupted(): Unit = {
+ interrupted = true
+ }
+
+ override def isCompleted: Boolean = completed
+
+ override def isRunningLocally: Boolean = runningLocally
+
+ override def isInterrupted: Boolean = interrupted
+}
+
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 8ca731038e528..e72826dc25f41 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -26,6 +26,8 @@ import scala.collection.JavaConversions._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
import com.google.common.io.Files
+import org.apache.spark.util.Utils
+
/**
* Utilities for tests. Included in main codebase since it's used by multiple
* projects.
@@ -42,8 +44,7 @@ private[spark] object TestUtils {
* in order to avoid interference between tests.
*/
def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value)
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
createJar(files, jarFile)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 0846225e4f992..c38b96528d037 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
@@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
- mapAsJavaMap(rdd.reduceByKeyLocally(func))
+ mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
/** Count the number of elements for each key, and return the result to the master as a Map. */
- def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+ def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
/**
* :: Experimental ::
@@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
@Experimental
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
/**
* :: Experimental ::
@@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
@Experimental
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
- def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+ def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
+
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 545bc0e9e99ed..efb8978f7ce12 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,15 +21,18 @@ import java.util.{Comparator, List => JList, Iterator => JIterator}
import java.lang.{Iterable => JIterable, Long => JLong}
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext}
+import org.apache.spark._
+import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
@@ -293,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
- val cleanF = rdd.context.clean((x: T) => f.call(x))
- rdd.foreach(cleanF)
+ rdd.foreach(x => f.call(x))
}
/**
@@ -390,7 +392,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
- mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+ mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().
@@ -399,13 +401,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
@@ -575,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def name(): String = rdd.name
/**
- * :: Experimental ::
- * The asynchronous version of the foreach action.
- *
- * @param f the function to apply to all the elements of the RDD
- * @return a FutureAction for the action
+ * The asynchronous version of `count`, which returns a
+ * future for counting the number of elements in this RDD.
*/
- @Experimental
- def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = {
- import org.apache.spark.SparkContext._
- rdd.foreachAsync(x => f.call(x))
+ def countAsync(): JavaFutureAction[JLong] = {
+ new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf)
+ }
+
+ /**
+ * The asynchronous version of `collect`, which returns a future for
+ * retrieving an array containing all of the elements in this RDD.
+ */
+ def collectAsync(): JavaFutureAction[JList[T]] = {
+ new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava)
+ }
+
+ /**
+ * The asynchronous version of the `take` action, which returns a
+ * future for retrieving the first `num` elements of this RDD.
+ */
+ def takeAsync(num: Int): JavaFutureAction[JList[T]] = {
+ new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava)
}
+ /**
+ * The asynchronous version of the `foreach` action, which
+ * applies a function f to all the elements of this RDD.
+ */
+ def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = {
+ new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)),
+ { x => null.asInstanceOf[Void] })
+ }
+
+ /**
+ * The asynchronous version of the `foreachPartition` action, which
+ * applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = {
+ new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)),
+ { x => null.asInstanceOf[Void] })
+ }
}
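For reference, a minimal usage sketch of the new asynchronous Java API (not part of the patch). It assumes a local master and that JavaFutureAction follows the standard java.util.concurrent.Future contract (get/cancel/isDone); the object and app names are illustrative.

import java.util.Arrays
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

object AsyncActionsSketch {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext(new SparkConf().setMaster("local[2]").setAppName("async-sketch"))
    val rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5))
    val countFuture = rdd.countAsync()      // JavaFutureAction[java.lang.Long]
    val collectFuture = rdd.collectAsync()  // JavaFutureAction[java.util.List[...]]
    // Block for the results; a caller could instead poll isDone() or cancel(true).
    println(s"count = ${countFuture.get()}, elements = ${collectFuture.get()}")
    jsc.stop()
  }
}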
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c662d..b52d0a5028e84 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,20 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
+import scala.collection.convert.Wrappers.MapWrapper
+
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
case None => Optional.absent()
}
+
+ // Workaround for SPARK-3926 / SI-8911
+ def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+ new SerializableMapWrapper(underlying)
+
+ class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+ extends MapWrapper(underlying) with java.io.Serializable
+
}
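To illustrate why the wrapper is needed (SPARK-3926 / SI-8911): the plain JavaConversions map wrapper is not java.io.Serializable, so java.util.Map views handed to Java callers could not be serialized. Below is a self-contained sketch of the same pattern; SerializableMapView is a stand-in name, since the real helper is private[spark].

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.convert.Wrappers.MapWrapper

// Mirrors SerializableMapWrapper: a java.util.Map view that is also Serializable.
class SerializableMapView[A, B](underlying: collection.Map[A, B])
  extends MapWrapper(underlying) with java.io.Serializable

object SerializableMapSketch {
  def main(args: Array[String]): Unit = {
    val view: java.util.Map[String, Long] = new SerializableMapView(Map("a" -> 1L, "b" -> 2L))
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(view)  // succeeds; the plain JavaConversions wrapper would throw here
    out.close()
    println(view.get("a")) // behaves like an ordinary java.util.Map: prints 1
  }
}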
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f9ff4ea6ca157..29ca751519abd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,10 +23,9 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
-import scala.reflect.ClassTag
-import scala.util.{Try, Success, Failure}
import net.razorvine.pickle.{Pickler, Unpickler}
@@ -42,7 +41,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD(
- parent: RDD[_],
+ @transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@@ -55,9 +54,9 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
- override def getPartitions = parent.partitions
+ override def getPartitions = firstParent.partitions
- override val partitioner = if (preservePartitoning) parent.partitioner else None
+ override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@@ -196,7 +195,6 @@ private[spark] class PythonRDD(
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -235,7 +233,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
@@ -248,6 +246,11 @@ private[spark] class PythonRDD(
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
worker.shutdownOutput()
+ } finally {
+ // Release memory used by this thread for shuffles
+ env.shuffleMemoryManager.releaseMemoryForThisThread()
+ // Release memory used by this thread for unrolling blocks
+ env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
@@ -339,26 +342,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
- while (true) {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- objs.append(obj)
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
}
- } catch {
- case eof: EOFException => {}
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ } finally {
+ file.close()
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
+ try {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ sc.broadcast(obj)
+ } finally {
+ file.close()
+ }
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -736,6 +747,7 @@ private[spark] object PythonRDD extends Logging {
def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
pyRDD.rdd.mapPartitions { iter =>
val unpickle = new Unpickler
+ SerDeUtil.initialize()
iter.flatMap { row =>
unpickle.loads(row) match {
// in case of objects are pickled in batch mode
@@ -775,7 +787,7 @@ private[spark] object PythonRDD extends Logging {
}.toJavaRDD()
}
- private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
private val pickle = new Pickler()
private var batch = 1
private val buffer = new mutable.ArrayBuffer[Any]
@@ -812,11 +824,12 @@ private[spark] object PythonRDD extends Logging {
*/
def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
pyRDD.rdd.mapPartitions { iter =>
+ SerDeUtil.initialize()
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
- obj.asInstanceOf[JArrayList[_]]
+ obj.asInstanceOf[JArrayList[_]].asScala
} else {
Seq(obj)
}
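The readRDDFromFile / readBroadcastFromFile changes above follow one pattern: always close the stream in a finally block, and treat EOFException as the normal end of a length-prefixed record stream. A standalone sketch of that pattern (names are illustrative):

import java.io.{DataInputStream, EOFException, FileInputStream}
import scala.collection.mutable.ArrayBuffer

object LengthPrefixedReader {
  def readAll(filename: String): Seq[Array[Byte]] = {
    val in = new DataInputStream(new FileInputStream(filename))
    try {
      val records = new ArrayBuffer[Array[Byte]]
      try {
        while (true) {
          val length = in.readInt()    // 4-byte big-endian length header
          val record = new Array[Byte](length)
          in.readFully(record)
          records += record
        }
      } catch {
        case _: EOFException => // normal termination: no more records
      }
      records
    } finally {
      in.close()                       // runs on both success and failure
    }
  }
}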
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 71bdf0fe1b917..e314408c067e9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
val worker = pb.start()
// Redirect worker stdout and stderr
@@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
// Create and start the daemon
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
val in = new DataInputStream(daemon.getInputStream)
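Both launch paths now rely on PYTHONUNBUFFERED instead of -u, because IPython rejects -u while the environment variable has the same effect for any Python executable. A hedged sketch of the launch pattern (the object name is illustrative):

import scala.collection.JavaConversions._

object UnbufferedPythonLauncher {
  def launch(pythonExec: String, module: String): Process = {
    val pb = new ProcessBuilder(Seq(pythonExec, "-m", module))
    // Any non-empty value enables unbuffered stdout/stderr, equivalent to `python -u`.
    pb.environment().put("PYTHONUNBUFFERED", "YES")
    pb.redirectErrorStream(true)
    pb.start()
  }
}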
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index 7903457b17e13..ebdc3533e0992 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
/** Utilities for serialization / deserialization between Python and Java, using Pickle. */
-private[python] object SerDeUtil extends Logging {
+private[spark] object SerDeUtil extends Logging {
// Unpickle array.array generated by Python 2.6
class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor {
// /* Description of types */
@@ -76,9 +76,18 @@ private[python] object SerDeUtil extends Logging {
}
}
+ private var initialized = false
+ // This should be called before trying to unpickle array.array from Python.
+ // In cluster mode, the call should be placed inside the closure so it runs on the executors.
def initialize() = {
- Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ synchronized {
+ if (!initialized) {
+ Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ initialized = true
+ }
+ }
}
+ initialize()
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
@@ -143,6 +152,7 @@ private[python] object SerDeUtil extends Logging {
obj.asInstanceOf[Array[_]].length == 2
}
pyRDD.mapPartitions { iter =>
+ initialize()
val unpickle = new Unpickler
val unpickled =
if (batchSerialized) {
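The initialize() change above is a guarded one-time registration: a flag checked inside a synchronized block makes the call idempotent, so it can safely be invoked from every partition closure. A minimal sketch of that pattern (OneTimeInit is an illustrative name):

object OneTimeInit {
  private var initialized = false

  def initialize(register: () => Unit): Unit = synchronized {
    if (!initialized) {
      register() // e.g. Unpickler.registerConstructor("array", "array", new ArrayConstructor())
      initialized = true
    }
  }
}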
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 942dc7d7eac87..4cd4f4f96fd16 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging {
private def write(id: Long, value: Any) {
val file = getFile(id)
- val out: OutputStream = {
- if (compress) {
- compressionCodec.compressedOutputStream(new FileOutputStream(file))
- } else {
- new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+ val fileOutputStream = new FileOutputStream(file)
+ try {
+ val out: OutputStream = {
+ if (compress) {
+ compressionCodec.compressedOutputStream(fileOutputStream)
+ } else {
+ new BufferedOutputStream(fileOutputStream, bufferSize)
+ }
}
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ files += file
+ } finally {
+ fileOutputStream.close()
}
- val ser = SparkEnv.get.serializer.newInstance()
- val serOut = ser.serializeStream(out)
- serOut.writeObject(value)
- serOut.close()
- files += file
}
private def read[T: ClassTag](id: Long): T = {
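The write() rewrite above keeps a handle on the raw FileOutputStream so it can be closed in a finally block even if wrapping it (compression, buffering) or writing fails. A self-contained sketch of the same pattern, with illustrative names:

import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}

object SafeFileWrite {
  def write(file: File, payload: Array[Byte], wrap: OutputStream => OutputStream): Unit = {
    val fileOut = new FileOutputStream(file)
    try {
      val out = wrap(fileOut) // e.g. a compression codec's stream or a BufferedOutputStream
      out.write(payload)
      out.close()             // flushes the wrapper; closing fileOut again below is harmless
    } finally {
      fileOut.close()         // runs even if the wrapper constructor or write() throws
    }
  }

  def main(args: Array[String]): Unit = {
    write(new File("/tmp/safe-write-demo.bin"), Array[Byte](1, 2, 3), new BufferedOutputStream(_))
  }
}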
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 065ddda50e65e..f2687ce6b42b4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
System.exit(-1)
- case AssociationErrorEvent(cause, _, remoteAddress, _) =>
+ case AssociationErrorEvent(cause, _, remoteAddress, _, _) =>
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
println(s"Cause was: $cause")
System.exit(-1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index b66c3ba4d5fb0..af94b05ce3847 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -34,7 +34,8 @@ object PythonRunner {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
- val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf
+ val pythonExec =
+ sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python"))
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
@@ -54,9 +55,11 @@ object PythonRunner {
val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)
// Launch Python process
- val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs)
+ val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
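The driver executable is now resolved with a two-level fallback: PYSPARK_DRIVER_PYTHON wins, then PYSPARK_PYTHON, then plain python. A small sketch of that resolution logic (the object name is illustrative):

object DriverPythonResolver {
  def resolve(env: Map[String, String] = sys.env): String =
    env.getOrElse("PYSPARK_DRIVER_PYTHON", env.getOrElse("PYSPARK_PYTHON", "python"))
}

// resolve(Map.empty)                                         == "python"
// resolve(Map("PYSPARK_PYTHON" -> "python2.7"))              == "python2.7"
// resolve(Map("PYSPARK_DRIVER_PYTHON" -> "ipython",
//             "PYSPARK_PYTHON" -> "python2.7"))              == "ipython"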
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 580a439c9a892..f97bf67fa5a3b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -320,6 +320,10 @@ object SparkSubmit {
} catch {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
+ if (childMainClass.contains("thriftserver")) {
+ println(s"Failed to load main class $childMainClass.")
+ println("You need to build Spark with -Phive.")
+ }
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 2b72c61cc8177..72a452e0aefb5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,20 +17,18 @@
package org.apache.spark.deploy
-import java.io.{File, FileInputStream, IOException}
-import java.util.Properties
import java.util.jar.JarFile
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import org.apache.spark.SparkException
import org.apache.spark.util.Utils
/**
* Parses and encapsulates arguments from the spark-submit script.
+ * The env argument is used for testing.
*/
-private[spark] class SparkSubmitArguments(args: Seq[String]) {
+private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) {
var master: String = null
var deployMode: String = null
var executorMemory: String = null
@@ -62,9 +60,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
val defaultProperties = new HashMap[String, String]()
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
- val file = new File(filename)
- SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) =>
- if (k.startsWith("spark")) {
+ Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
defaultProperties(k) = v
if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
} else {
@@ -89,27 +86,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
*/
private def mergeSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
- if (propertiesFile == null) {
- sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir =>
- val sep = File.separator
- val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf"
- val file = new File(defaultPath)
- if (file.exists()) {
- propertiesFile = file.getAbsolutePath
- }
- }
- }
-
- if (propertiesFile == null) {
- sys.env.get("SPARK_HOME").foreach { sparkHome =>
- val sep = File.separator
- val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf"
- val file = new File(defaultPath)
- if (file.exists()) {
- propertiesFile = file.getAbsolutePath
- }
- }
- }
+ propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
val properties = HashMap[String, String]()
properties.putAll(defaultSparkProperties)
@@ -117,19 +94,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
// Use properties file as fallback for values which have a direct analog to
// arguments in this script.
- master = Option(master).getOrElse(properties.get("spark.master").orNull)
- executorMemory = Option(executorMemory)
- .getOrElse(properties.get("spark.executor.memory").orNull)
- executorCores = Option(executorCores)
- .getOrElse(properties.get("spark.executor.cores").orNull)
+ master = Option(master).orElse(properties.get("spark.master")).orNull
+ executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
+ executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
totalExecutorCores = Option(totalExecutorCores)
- .getOrElse(properties.get("spark.cores.max").orNull)
- name = Option(name).getOrElse(properties.get("spark.app.name").orNull)
- jars = Option(jars).getOrElse(properties.get("spark.jars").orNull)
+ .orElse(properties.get("spark.cores.max"))
+ .orNull
+ name = Option(name).orElse(properties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(properties.get("spark.jars")).orNull
// This supports env vars in older versions of Spark
- master = Option(master).getOrElse(System.getenv("MASTER"))
- deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE"))
+ master = Option(master).orElse(env.get("MASTER")).orNull
+ deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
@@ -182,7 +158,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
if (master.startsWith("yarn")) {
- val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")
+ val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR")
if (!hasHadoopEnv && !Utils.isTesting) {
throw new Exception(s"When running with master '$master' " +
"either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.")
@@ -405,23 +381,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
SparkSubmit.exitFn()
}
}
-
-object SparkSubmitArguments {
- /** Load properties present in the given file. */
- def getPropertiesFromFile(file: File): Seq[(String, String)] = {
- require(file.exists(), s"Properties file $file does not exist")
- require(file.isFile(), s"Properties file $file is not a normal file")
- val inputStream = new FileInputStream(file)
- try {
- val properties = new Properties()
- properties.load(inputStream)
- properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim))
- } catch {
- case e: IOException =>
- val message = s"Failed when loading Spark properties file $file"
- throw new SparkException(message, e)
- } finally {
- inputStream.close()
- }
- }
-}
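The Option(...).orElse(...).orNull chain used above expresses the intended precedence: an explicit command-line value wins, otherwise the properties file is consulted, otherwise the field stays null (the "unset" marker used throughout SparkSubmitArguments). A small illustrative sketch:

object ArgumentFallback {
  def resolve(cliValue: String, properties: Map[String, String], key: String): String =
    Option(cliValue).orElse(properties.get(key)).orNull

  def main(args: Array[String]): Unit = {
    val props = Map("spark.master" -> "local[4]")
    println(resolve(null, props, "spark.master"))   // local[4]  (falls back to properties)
    println(resolve("yarn", props, "spark.master")) // yarn      (explicit argument wins)
    println(resolve(null, props, "spark.app.name")) // null      (remains unset)
  }
}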
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 38b5d8e1739d0..0125330589da5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper {
assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set")
// Parse the properties file for the equivalent spark.driver.* configs
- val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap
+ val properties = Utils.getPropertiesFromFile(propertiesFile)
val confDriverMemory = properties.get("spark.driver.memory")
val confLibraryPath = properties.get("spark.driver.extraLibraryPath")
val confClasspath = properties.get("spark.driver.extraClassPath")
@@ -154,7 +154,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
process.destroy()
}
}
- process.waitFor()
+ val returnCode = process.waitFor()
+ sys.exit(returnCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 32790053a6be8..98a93d1fcb2a3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -154,7 +154,7 @@ private[spark] class AppClient(
logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
- case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) =>
+ case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) =>
logWarning(s"Could not connect to $address: $cause")
case StopAppClient =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 25fc76c23e0fb..5bce32a04d16d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -18,12 +18,14 @@
package org.apache.spark.deploy.history
import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
/**
* Command-line parser for the master.
*/
private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) {
private var logDir: String = null
+ private var propertiesFile: String = null
parse(args.toList)
@@ -32,11 +34,16 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
case ("--dir" | "-d") :: value :: tail =>
logDir = value
conf.set("spark.history.fs.logDirectory", value)
+ System.setProperty("spark.history.fs.logDirectory", value)
parse(tail)
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
case Nil =>
case _ =>
@@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
}
}
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
private def printUsageAndExit(exitCode: Int) {
System.err.println(
"""
- |Usage: HistoryServer
+ |Usage: HistoryServer [options]
+ |
+ |Options:
+ | --properties-file FILE Path to a custom Spark properties file.
+ | Default is conf/spark-defaults.conf.
|
|Configuration options can be set by setting the corresponding JVM system property.
|History Server options are always available; additional options depend on the provider.
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index aa85aa060d9c1..08a99bbe68578 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -83,15 +83,21 @@ private[spark] class FileSystemPersistenceEngine(
val serialized = serializer.toBinary(value)
val out = new FileOutputStream(file)
- out.write(serialized)
- out.close()
+ try {
+ out.write(serialized)
+ } finally {
+ out.close()
+ }
}
def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
- dis.readFully(fileData)
- dis.close()
+ try {
+ dis.readFully(fileData)
+ } finally {
+ dis.close()
+ }
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 432b552c58cd8..f98b531316a3d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -33,8 +33,8 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState,
- SparkHadoopUtil}
+import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
+ ExecutorState, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -693,16 +693,18 @@ private[spark] class Master(
app.desc.appUiUrl = notFoundBasePath
return false
}
- val fileSystem = Utils.getHadoopFileSystem(eventLogDir,
+
+ val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id)
+ val fileSystem = Utils.getHadoopFileSystem(appEventLogDir,
SparkHadoopUtil.get.newConfiguration(conf))
- val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem)
+ val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem)
val eventLogPaths = eventLogInfo.logPaths
val compressionCodec = eventLogInfo.compressionCodec
if (eventLogPaths.isEmpty) {
// Event logging is enabled for this application, but no event logs are found
val title = s"Application history not found (${app.id})"
- var msg = s"No event logs found for application $appName in $eventLogDir."
+ var msg = s"No event logs found for application $appName in $appEventLogDir."
logWarning(msg)
msg += " Did you specify the correct logging directory?"
msg = URLEncoder.encode(msg, "UTF-8")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 4b0dbbe543d3f..e34bee7854292 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_MASTER_HOST") != null) {
@@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt
}
+
+ parse(args.toList)
+
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
if (conf.contains("spark.master.ui.port")) {
webUiPort = conf.get("spark.master.ui.port").toInt
}
- parse(args.toList)
-
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case Nil => {}
@@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
- " --webui-port PORT Port for web UI (default: 8080)")
+ " --webui-port PORT Port for web UI (default: 8080)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 00a43673e5cd3..71d7385b08eb9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -42,7 +42,7 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File,
+ val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
var state: ExecutorState.Value)
@@ -111,13 +111,14 @@ private[spark] class ExecutorRunner(
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => host
case "{{CORES}}" => cores.toString
+ case "{{APP_ID}}" => appId
case other => other
}
def getCommandSeq = {
val command = Command(
appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables) ++ Seq(appId),
+ appDesc.command.arguments.map(substituteVariables),
appDesc.command.environment,
appDesc.command.classPathEntries,
appDesc.command.libraryPathEntries,
@@ -130,12 +131,6 @@ private[spark] class ExecutorRunner(
*/
def fetchAndRunExecutor() {
try {
- // Create the executor's working directory
- val executorDir = new File(workDir, appId + "/" + execId)
- if (!executorDir.mkdirs()) {
- throw new IOException("Failed to create directory " + executorDir)
- }
-
// Launch the process
val command = getCommandSeq
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 0c454e4138c96..9b52cb06fb6fa 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -18,9 +18,11 @@
package org.apache.spark.deploy.worker
import java.io.File
+import java.io.IOException
import java.text.SimpleDateFormat
import java.util.Date
+import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -191,6 +193,7 @@ private[spark] class Worker(
changeMaster(masterUrl, masterWebUiUrl)
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
if (CLEANUP_ENABLED) {
+ logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
}
@@ -201,10 +204,23 @@ private[spark] class Worker(
case WorkDirCleanup =>
// Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
val cleanupFuture = concurrent.future {
- logInfo("Cleaning up oldest application directories in " + workDir + " ...")
- Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS)
- .foreach(Utils.deleteRecursively)
+ val appDirs = workDir.listFiles()
+ if (appDirs == null) {
+ throw new IOException("ERROR: Failed to list files in " + workDir)
+ }
+ appDirs.filter { dir =>
+ // the directory is used by an application - check that the application is not running
+ // when cleaning up
+ val appIdFromDir = dir.getName
+ val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
+ dir.isDirectory && !isAppStillRunning &&
+ !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
+ }.foreach { dir =>
+ logInfo(s"Removing directory: ${dir.getPath}")
+ Utils.deleteRecursively(dir)
+ }
}
+
cleanupFuture onFailure {
case e: Throwable =>
logError("App dir cleanup failed: " + e.getMessage, e)
@@ -233,8 +249,15 @@ private[spark] class Worker(
} else {
try {
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+
+ // Create the executor's working directory
+ val executorDir = new File(workDir, appId + "/" + execId)
+ if (!executorDir.mkdirs()) {
+ throw new IOException("Failed to create directory " + executorDir)
+ }
+
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.LOADING)
+ self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -242,12 +265,13 @@ private[spark] class Worker(
master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
} catch {
case e: Exception => {
- logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
if (executors.contains(appId + "/" + execId)) {
executors(appId + "/" + execId).kill()
executors -= appId + "/" + execId
}
- master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None)
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
+ Some(e.toString), None)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 1e295aaa48c30..019cd70f2a229 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
var memory = inferDefaultMemory()
var masters: Array[String] = null
var workDir: String = null
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@@ -41,21 +42,27 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
if (System.getenv("SPARK_WORKER_CORES") != null) {
cores = System.getenv("SPARK_WORKER_CORES").toInt
}
- if (System.getenv("SPARK_WORKER_MEMORY") != null) {
- memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
+ if (conf.getenv("SPARK_WORKER_MEMORY") != null) {
+ memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY"))
}
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
}
- if (conf.contains("spark.worker.ui.port")) {
- webUiPort = conf.get("spark.worker.ui.port").toInt
- }
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
}
parse(args.toList)
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
+ if (conf.contains("spark.worker.ui.port")) {
+ webUiPort = conf.get("spark.worker.ui.port").toInt
+ }
+
+ checkWorkerMemory()
+
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -87,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case value :: tail =>
@@ -122,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
- " --webui-port PORT Port for web UI (default: 8081)")
+ " --webui-port PORT Port for web UI (default: 8081)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
@@ -153,4 +166,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
+
+ def checkWorkerMemory(): Unit = {
+ if (memory <= 0) {
+ val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+ throw new IllegalStateException(message)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 6d0d0bbe5ecec..63a8ac817b618 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String)
case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
logInfo(s"Successfully connected to $workerUrl")
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound)
+ case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _)
if isWorker(remoteAddress) =>
// These logs may not be seen if the worker (and associated pipe) has died
logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 13af5b6f5812d..c40a3e16675ad 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -106,6 +106,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
executorId: String,
hostname: String,
cores: Int,
+ appId: String,
workerUrl: Option[String]) {
SignalLogger.register(log)
@@ -122,7 +123,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val driver = fetcher.actorSelection(driverUrl)
val timeout = AkkaUtils.askTimeout(executorConf)
val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
- val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]]
+ val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
+ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
// Create a new ActorSystem using driver's Spark properties to run the backend.
@@ -144,16 +146,19 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
def main(args: Array[String]) {
args.length match {
- case x if x < 4 =>
+ case x if x < 5 =>
System.err.println(
// Worker url is used in spark standalone mode to enforce fate-sharing with worker
"Usage: CoarseGrainedExecutorBackend " +
- " []")
+ " [] ")
System.exit(1)
- case 4 =>
- run(args(0), args(1), args(2), args(3).toInt, None)
- case x if x > 4 =>
- run(args(0), args(1), args(2), args(3).toInt, Some(args(4)))
+
+ // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode)
+ // and CoarseMesosSchedulerBackend (for mesos mode).
+ case 5 =>
+ run(args(0), args(1), args(2), args(3).toInt, args(4), None)
+ case x if x > 5 =>
+ run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5)))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index d7211ae465902..616c7e6a46368 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -74,6 +74,7 @@ private[spark] class Executor(
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
+ conf.set("spark.executor.id", "executor." + executorId)
private val env = {
if (!isLocal) {
val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
@@ -147,7 +148,6 @@ private[spark] class Executor(
override def run() {
val startTime = System.currentTimeMillis()
- SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
@@ -157,7 +157,6 @@ private[spark] class Executor(
val startGCTime = gcTime
try {
- SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index d6721586566c2..c4d73622c4727 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -37,8 +37,7 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String)
override val metricRegistry = new MetricRegistry()
- // TODO: It would be nice to pass the application name here
- override val sourceName = "executor.%s".format(executorId)
+ override val sourceName = "executor"
// Gauge for executor thread pool's actively executing task counts
metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index a42c8b43bbf7f..bca0b152268ad 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -52,7 +52,8 @@ private[spark] class MesosExecutorBackend
slaveInfo: SlaveInfo) {
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
- val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
+ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
executor = new Executor(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index fd316a89a1a10..5dd67b0cbf683 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -83,10 +83,10 @@ private[spark] class MetricsSystem private (
def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
metricsConfig.initialize()
- registerSources()
- registerSinks()
def start() {
+ registerSources()
+ registerSinks()
sinks.foreach(_.start)
}
@@ -98,10 +98,39 @@ private[spark] class MetricsSystem private (
sinks.foreach(_.report())
}
+ /**
+ * Build a name that uniquely identifies each metric source.
+ * The name is structured as follows: <app ID>.<executor ID>.<source name>.
+ * If either ID is not available, this defaults to just using <source name>.
+ *
+ * @param source Metric source to be named by this method.
+ * @return A unique metric name for each combination of
+ * application, executor/driver and metric source.
+ */
+ def buildRegistryName(source: Source): String = {
+ val appId = conf.getOption("spark.app.id")
+ val executorId = conf.getOption("spark.executor.id")
+ val defaultName = MetricRegistry.name(source.sourceName)
+
+ if (instance == "driver" || instance == "executor") {
+ if (appId.isDefined && executorId.isDefined) {
+ MetricRegistry.name(appId.get, executorId.get, source.sourceName)
+ } else {
+ // Only the driver and executors set spark.app.id and spark.executor.id.
+ // Other instances, such as Master and Worker, are not tied to a specific application.
+ val warningMsg = s"Using default name $defaultName for source because %s is not set."
+ if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
+ if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
+ defaultName
+ }
+ } else { defaultName }
+ }
+
def registerSource(source: Source) {
sources += source
try {
- registry.register(source.sourceName, source.metricRegistry)
+ val regName = buildRegistryName(source)
+ registry.register(regName, source.metricRegistry)
} catch {
case e: IllegalArgumentException => logInfo("Metrics already registered", e)
}
@@ -109,8 +138,9 @@ private[spark] class MetricsSystem private (
def removeSource(source: Source) {
sources -= source
+ val regName = buildRegistryName(source)
registry.removeMatching(new MetricFilter {
- def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName)
+ def matches(name: String, metric: Metric): Boolean = name.startsWith(regName)
})
}
@@ -125,7 +155,7 @@ private[spark] class MetricsSystem private (
val source = Class.forName(classPath).newInstance()
registerSource(source.asInstanceOf[Source])
} catch {
- case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e)
+ case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
}
}
}
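buildRegistryName above composes metric names with MetricRegistry.name, which joins its non-empty arguments with dots. A hedged sketch of the resulting naming scheme (the IDs shown are illustrative):

import com.codahale.metrics.MetricRegistry

object MetricNameSketch {
  def main(args: Array[String]): Unit = {
    // Driver/executor sources: <app ID>.<executor ID>.<source name>
    println(MetricRegistry.name("app-20141006085157-0000", "1", "executor"))
    // Other instances (Master, Worker) keep just the source name
    println(MetricRegistry.name("jvm"))
  }
}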
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
index a4409181ec907..4c9ca97a2a6b7 100644
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer {
final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
extends ManagedBuffer {
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid unless there's a good reason not to.
+ */
+ private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024
+
override def size: Long = length
override def nioByteBuffer(): ByteBuffer = {
var channel: FileChannel = null
try {
channel = new RandomAccessFile(file, "r").getChannel
- channel.map(MapMode.READ_ONLY, offset, length)
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ val buf = ByteBuffer.allocate(length.toInt)
+ channel.read(buf, offset)
+ buf.flip()
+ buf
+ } else {
+ channel.map(MapMode.READ_ONLY, offset, length)
+ }
} catch {
case e: IOException =>
Try(channel.size).toOption match {
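The nioByteBuffer change avoids memory mapping for small segments, since mmap has a high fixed cost (SPARK-1145, SPARK-3889); small segments are copied into a heap buffer instead. A standalone sketch of that strategy, with an assumed 2 MB threshold and illustrative names:

import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

object SegmentReader {
  private val MinMemoryMapBytes = 2 * 1024 * 1024 // threshold is a tuning assumption

  def read(path: String, offset: Long, length: Long): ByteBuffer = {
    val channel = new RandomAccessFile(path, "r").getChannel
    try {
      if (length < MinMemoryMapBytes) {
        val buf = ByteBuffer.allocate(length.toInt)
        channel.position(offset)
        while (buf.hasRemaining && channel.read(buf) != -1) {} // read may return short counts
        buf.flip()
        buf
      } else {
        channel.map(MapMode.READ_ONLY, offset, length) // mapping stays valid after close()
      }
    } finally {
      channel.close()
    }
  }
}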
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 18172d359cb35..4f6f5e235811d 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -20,23 +20,30 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.LinkedList
import org.apache.spark._
-import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.control.NonFatal
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
+ val securityMgr: SecurityManager)
extends Logging {
var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null
- def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
+ securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+ channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
+ id_, securityMgr_)
}
channel.configureBlocking(false)
@@ -47,19 +54,11 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
@volatile private var closed = false
var onCloseCallback: Connection => Unit = null
- var onExceptionCallback: (Connection, Exception) => Unit = null
+ val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit]
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
- /**
- * Used to synchronize client requests: client's work-related requests must
- * wait until SASL authentication completes.
- */
- private val authenticated = new Object()
-
- def getAuthenticated(): Object = authenticated
-
def isSaslComplete(): Boolean
def resetForceReregister(): Boolean
@@ -134,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
onCloseCallback = callback
}
- def onException(callback: (Connection, Exception) => Unit) {
- onExceptionCallback = callback
+ def onException(callback: (Connection, Throwable) => Unit) {
+ onExceptionCallbacks.add(callback)
}
def onKeyInterestChange(callback: (Connection, Int) => Unit) {
onKeyInterestChangeCallback = callback
}
- def callOnExceptionCallback(e: Exception) {
- if (onExceptionCallback != null) {
- onExceptionCallback(this, e)
- } else {
- logError("Error in connection to " + getRemoteConnectionManagerId() +
- " and OnExceptionCallback not registered", e)
+ def callOnExceptionCallbacks(e: Throwable) {
+ onExceptionCallbacks foreach {
+ callback =>
+ try {
+ callback(this, e)
+ } catch {
+ case NonFatal(e) => {
+ logWarning("Ignored error in onExceptionCallback", e)
+ }
+ }
}
}
@@ -192,22 +195,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId, id_ : ConnectionId)
- extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}
private class Outbox {
- val messages = new Queue[Message]()
+ val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized {
- /* messages += message */
- messages.enqueue(message)
+ messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
@@ -218,10 +221,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
- val message = messages.dequeue()
+
+ val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
+ // only allow sending of security messages until sasl is complete
+ var pos = 0
+ var securityMsg: Message = null
+ while (pos < messages.size() && securityMsg == null) {
+ if (messages.get(pos).isSecurityNeg) {
+ securityMsg = messages.remove(pos)
+ }
+ pos = pos + 1
+ }
+ // didn't find any security messages and auth isn't completed so return
+ if (securityMsg == null) return None
+ securityMsg
+ } else {
+ messages.removeFirst()
+ }
+
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
- messages.enqueue(message)
+ messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
@@ -273,6 +293,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
+ def registerAfterAuth(): Unit = {
+ outbox.synchronized {
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@@ -301,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logError("Error connecting to " + address, e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
}
}
}
@@ -326,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
}
}
true
@@ -371,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
return false
}
@@ -398,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
case e: Exception =>
logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
}
@@ -415,8 +444,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
- id_ : ConnectionId)
- extends Connection(channel_, selector_, id_) {
+ id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(channel_, selector_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
@@ -554,7 +584,7 @@ private[spark] class ReceivingConnection(
} catch {
case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
return false
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 5aa7e94943561..bda4bf50932c3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -32,8 +32,10 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import org.apache.spark._
-import org.apache.spark.util.{SystemClock, Utils}
+import org.apache.spark.util.Utils
+import scala.util.Try
+import scala.util.control.NonFatal
private[nio] class ConnectionManager(
port: Int,
@@ -51,22 +53,29 @@ private[nio] class ConnectionManager(
class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
- completionHandler: MessageStatus => Unit) {
+ completionHandler: Try[Message] => Unit) {
- /** This is non-None if message has been ack'd */
- var ackMessage: Option[Message] = None
+ def success(ackMessage: Message) {
+ if (ackMessage == null) {
+ failure(new NullPointerException)
+ }
+ else {
+ completionHandler(scala.util.Success(ackMessage))
+ }
+ }
+
+ def failWithoutAck() {
+ completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd")))
+ }
- def markDone(ackMessage: Option[Message]) {
- this.ackMessage = ackMessage
- completionHandler(this)
+ def failure(e: Throwable) {
+ completionHandler(scala.util.Failure(e))
}
}
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
- // default to 30 second timeout waiting for authentication
- private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -74,14 +83,32 @@ private[nio] class ConnectionManager(
conf.getInt("spark.core.connection.handler.threads.max", 60),
conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-message-executor"))
+ Utils.namedThreadFactory("handle-message-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleMessageExecutor is not handled properly", t)
+ }
+ }
+
+ }
private val handleReadWriteExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.io.threads.min", 4),
conf.getInt("spark.core.connection.io.threads.max", 32),
conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-read-write-executor"))
+ Utils.namedThreadFactory("handle-read-write-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleReadWriteExecutor is not handled properly", t)
+ }
+ }
+
+ }
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
// which should be executed asap
@@ -90,7 +117,16 @@ private[nio] class ConnectionManager(
conf.getInt("spark.core.connection.connect.threads.max", 8),
conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-connect-executor"))
+ Utils.namedThreadFactory("handle-connect-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleConnectExecutor is not handled properly", t)
+ }
+ }
+
+ }
private val serverChannel = ServerSocketChannel.open()
// used to track the SendingConnections waiting to do SASL negotiation
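
A minimal standalone sketch of the afterExecute pattern installed on the three pools above (message handling, read/write, connect): any non-fatal throwable escaping a pool task gets reported instead of vanishing. Names are illustrative and println stands in for logError; this is not the ConnectionManager code itself.

import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import scala.util.control.NonFatal

object LoggingPoolSketch {
  // A pool that reports tasks whose run() threw, mirroring the overrides above.
  val pool = new ThreadPoolExecutor(
    1, 4, 60, TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) {
    override def afterExecute(r: Runnable, t: Throwable): Unit = {
      super.afterExecute(r, t)
      if (t != null && NonFatal(t)) {
        println(s"Error in pool task is not handled properly: $t")
      }
    }
  }

  def main(args: Array[String]): Unit = {
    pool.execute(new Runnable { def run(): Unit = throw new RuntimeException("boom") })
    pool.shutdown()
  }
}
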
@@ -155,17 +191,24 @@ private[nio] class ConnectionManager(
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
- var register: Boolean = false
try {
- register = conn.write()
- } finally {
- writeRunnableStarted.synchronized {
- writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
- if (needReregister && conn.changeInterestForWrite()) {
- conn.registerInterest()
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ val needReregister = register || conn.resetForceReregister()
+ if (needReregister && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
}
}
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
}
} )
@@ -189,16 +232,23 @@ private[nio] class ConnectionManager(
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
- var register: Boolean = false
try {
- register = conn.read()
- } finally {
- readRunnableStarted.synchronized {
- readRunnableStarted -= key
- if (register && conn.changeInterestForRead()) {
- conn.registerInterest()
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
}
}
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
}
} )
@@ -215,19 +265,25 @@ private[nio] class ConnectionManager(
handleConnectExecutor.execute(new Runnable {
override def run() {
+ try {
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
- var tries: Int = 10
- while (tries >= 0) {
- if (conn.finishConnect(false)) return
- // Sleep ?
- Thread.sleep(1)
- tries -= 1
+ // Fall back to the previous behavior: we should not normally reach this point, since this
+ // method is triggered when the channel becomes connectable, but at times the first
+ // finishConnect does not succeed, hence the loop above to retry a few times.
+ conn.finishConnect(true)
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
-
- // fallback to previous behavior : we should not really come here since this method was
- // triggered since channel became connectable : but at times, the first finishConnect need
- // not succeed : hence the loop to retry a few 'times'.
- conn.finishConnect(true)
}
} )
}
@@ -248,16 +304,16 @@ private[nio] class ConnectionManager(
handleConnectExecutor.execute(new Runnable {
override def run() {
try {
- conn.callOnExceptionCallback(e)
+ conn.callOnExceptionCallbacks(e)
} catch {
// ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
+ case NonFatal(e) => logDebug("Ignoring exception", e)
}
try {
conn.close()
} catch {
// ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
+ case NonFatal(e) => logDebug("Ignoring exception", e)
}
}
})
@@ -409,7 +465,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
+ securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -449,7 +506,7 @@ private[nio] class ConnectionManager(
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
- status.markDone(None)
+ status.failWithoutAck()
})
messageStatuses.retain((i, status) => {
@@ -478,7 +535,7 @@ private[nio] class ConnectionManager(
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
- s.markDone(None)
+ s.failWithoutAck()
}
messageStatuses.retain((i, status) => {
@@ -493,7 +550,7 @@ private[nio] class ConnectionManager(
}
}
- def handleConnectionError(connection: Connection, e: Exception) {
+ def handleConnectionError(connection: Connection, e: Throwable) {
logInfo("Handling connection error on connection to " +
connection.getRemoteConnectionManagerId())
removeConnection(connection)
@@ -511,9 +568,17 @@ private[nio] class ConnectionManager(
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
- logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message, connection)
- logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ try {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message, connection)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when handling messages from " +
+ connection.getRemoteConnectionManagerId(), e)
+ connection.callOnExceptionCallbacks(e)
+ }
+ }
}
}
handleMessageExecutor.execute(runnable)
@@ -527,9 +592,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
@@ -538,9 +602,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -574,9 +637,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
- logDebug("Server sasl completed: " + connection.connectionId)
+ logDebug("Server sasl completed: " + connection.connectionId +
+ " for: " + connectionId)
} else {
- logDebug("Server sasl not completed: " + connection.connectionId)
+ logDebug("Server sasl not completed: " + connection.connectionId +
+ " for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -652,7 +717,7 @@ private[nio] class ConnectionManager(
messageStatuses.get(bufferMessage.ackId) match {
case Some(status) => {
messageStatuses -= bufferMessage.ackId
- status.markDone(Some(message))
+ status.success(message)
}
case None => {
/**
@@ -692,9 +757,7 @@ private[nio] class ConnectionManager(
} catch {
case e: Exception => {
logError(s"Exception was thrown while processing message", e)
- val m = Message.createBufferMessage(bufferMessage.id)
- m.hasError = true
- ackMessage = Some(m)
+ ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id))
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
@@ -723,7 +786,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
- logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
+ " to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
@@ -744,7 +808,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)
@@ -769,64 +833,55 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
+ newConnection.onException {
+ case (conn, e) => {
+ logError("Exception while sending message.", e)
+ reportSendingMessageFailure(message.id, e)
+ }
+ }
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
- if (authEnabled) {
- checkSendAuthFirst(connectionManagerId, connection)
- }
+
message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)
if (authEnabled) {
- // if we aren't authenticated yet lets block the senders until authentication completes
try {
- connection.getAuthenticated().synchronized {
- val clock = SystemClock
- val startTime = clock.getTime()
-
- while (!connection.isSaslComplete()) {
- logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
- // have timeout in case remote side never responds
- connection.getAuthenticated().wait(500)
- if (((clock.getTime() - startTime) >= (authTimeout * 1000))
- && (!connection.isSaslComplete())) {
- // took to long to authenticate the connection, something probably went wrong
- throw new Exception("Took to long for authentication to " + connectionManagerId +
- ", waited " + authTimeout + "seconds, failing.")
- }
- }
- }
+ checkSendAuthFirst(connectionManagerId, connection)
} catch {
- case e: Exception => logError("Exception while waiting for authentication.", e)
-
- // need to tell sender it failed
- messageStatuses.synchronized {
- val s = messageStatuses.get(message.id)
- s match {
- case Some(msgStatus) => {
- messageStatuses -= message.id
- logInfo("Notifying " + msgStatus.connectionManagerId)
- msgStatus.markDone(None)
- }
- case None => {
- logError("no messageStatus for failed message id: " + message.id)
- }
- }
+ case NonFatal(e) => {
+ reportSendingMessageFailure(message.id, e)
}
}
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
-
wakeupSelector()
}
+ private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = {
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(messageId)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= messageId
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.failure(e)
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + messageId)
+ }
+ }
+ }
+ }
+
private def wakeupSelector() {
selector.wakeup()
}
@@ -845,9 +900,11 @@ private[nio] class ConnectionManager(
override def run(): Unit = {
messageStatuses.synchronized {
messageStatuses.remove(message.id).foreach ( s => {
- promise.failure(
- new IOException("sendMessageReliably failed because ack " +
- s"was not received within $ackTimeout sec"))
+ val e = new IOException("sendMessageReliably failed because ack " +
+ s"was not received within $ackTimeout sec")
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore error because promise is completed", e)
+ }
})
}
}
@@ -855,15 +912,27 @@ private[nio] class ConnectionManager(
val status = new MessageStatus(message, connectionManagerId, s => {
timeoutTask.cancel()
- s.ackMessage match {
- case None => // Indicates a failure where we either never sent or never got ACK'd
- promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
- case Some(ackMessage) =>
+ s match {
+ case scala.util.Failure(e) =>
+ // Indicates a failure where we either never sent or never got ACK'd
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore error because promise is completed", e)
+ }
+ case scala.util.Success(ackMessage) =>
if (ackMessage.hasError) {
- promise.failure(
- new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+ val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
+ val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
+ errorMsgByteBuf.get(errorMsgBytes)
+ val errorMsg = new String(errorMsgBytes, "utf-8")
+ val e = new IOException(
+ s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore error because promise is completed", e)
+ }
} else {
- promise.success(ackMessage)
+ if (!promise.trySuccess(ackMessage)) {
+ logWarning("Drop ackMessage because promise is completed")
+ }
}
}
})
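
The MessageStatus rewrite above replaces the mutable Option[Message] ack with a Try[Message]-based completion handler, and sendMessageReliably switches to tryFailure/trySuccess so that a late ack and the ack-timeout task cannot both complete the promise. A small standalone sketch of that idempotent-completion idea (names are illustrative; this is not the NIO code itself):

import java.io.IOException
import scala.concurrent.Promise
import scala.util.{Failure, Success, Try}

object AckPromiseSketch {
  def main(args: Array[String]): Unit = {
    val promise = Promise[String]()

    // Completion handler in the style of the new MessageStatus: it receives a Try.
    val completionHandler: Try[String] => Unit = {
      case Success(ack) =>
        if (!promise.trySuccess(ack)) println("Drop ack because promise is completed")
      case Failure(e) =>
        if (!promise.tryFailure(e)) println(s"Ignore error because promise is completed: $e")
    }

    // First completion wins; the timeout path afterwards is silently ignored.
    completionHandler(Success("ACK"))
    completionHandler(Failure(new IOException("ack was not received within 60 sec")))

    println(promise.future.value) // Some(Success(ACK))
  }
}
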
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
index 0b874c2891255..3ad04591da658 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.util.Utils
private[nio] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -84,6 +85,19 @@ private[nio] object Message {
createBufferMessage(new Array[ByteBuffer](0), ackId)
}
+ /**
+ * Create a "negative acknowledgment" to notify a sender that an error occurred
+ * while processing its message. The exception's stacktrace will be formatted
+ * as a string, serialized into a byte array, and sent as the message payload.
+ */
+ def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
+ val exceptionString = Utils.exceptionString(exception)
+ val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8"))
+ val errorMessage = createBufferMessage(serializedExceptionString, ackId)
+ errorMessage.hasError = true
+ errorMessage
+ }
+
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
case BUFFER_MESSAGE => new BufferMessage(header.id,
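
createErrorMessage above serializes the exception's stack-trace string as UTF-8 bytes into a BufferMessage flagged with hasError, and the receiving side (see the sendMessageReliably change earlier) copies those bytes back out of the first buffer. A minimal sketch of that encode/decode round trip using a bare ByteBuffer, with Throwable.toString standing in for Utils.exceptionString:

import java.nio.ByteBuffer

object ErrorPayloadSketch {
  def main(args: Array[String]): Unit = {
    // Sender side: serialize the error text into a byte buffer, as createErrorMessage does.
    val exception = new IllegalStateException("remote processing failed")
    val payload = ByteBuffer.wrap(exception.toString.getBytes("utf-8"))

    // Receiver side: copy the buffer contents back into a string, as sendMessageReliably does.
    val bytes = new Array[Byte](payload.limit())
    payload.get(bytes)
    val errorMsg = new String(bytes, "utf-8")

    println(s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
  }
}
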
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index b389b9a2022c6..5add4fc433fb3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
} catch {
case e: Exception => {
logError("Exception handling buffer message", e)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
+ Some(Message.createErrorMessage(e, msg.id))
}
}
case otherMessage: Any =>
- logError("Unknown type message received: " + otherMessage)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
+ val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}"
+ logError(errorMsg)
+ Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id))
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index b62f3fbdc4a15..9f9f10b7ebc3a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
-import org.apache.spark.annotation.Experimental
/**
- * :: Experimental ::
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-@Experimental
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
@@ -78,16 +75,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the first iteration, just try all partitions next.
+ // If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
- // by 50%.
+ // by 50%. We also cap the estimation in the end.
if (results.size == 0) {
- numPartsToTry = totalParts - 1
+ numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max(1,
+ (1.5 * num * partsScanned / results.size).toInt - partsScanned)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
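
The new estimate above grows the scan geometrically instead of jumping straight to all partitions: when nothing was found, the next pass tries 4x the partitions already scanned; otherwise it interpolates (1.5 * num * partsScanned / results.size) - partsScanned and clamps the result to [1, 4 * partsScanned]. A small sketch of just that rule, with worked values (the function name is illustrative):

object TakeGrowthSketch {
  // Mirrors the updated estimate in take()/takeAsync(): how many more partitions to scan next.
  def nextPartsToTry(partsScanned: Int, resultsSoFar: Int, num: Int): Int = {
    if (partsScanned == 0) {
      1
    } else if (resultsSoFar == 0) {
      partsScanned * 4
    } else {
      val interpolated = (1.5 * num * partsScanned / resultsSoFar).toInt - partsScanned
      math.min(math.max(1, interpolated), partsScanned * 4)
    }
  }

  def main(args: Array[String]): Unit = {
    // Asking for 1000 rows, 4 partitions scanned, only 10 rows found so far:
    // interpolate (1.5 * 1000 * 4 / 10) - 4 = 596, capped at 4 * 4 = 16.
    println(nextPartsToTry(partsScanned = 4, resultsSoFar = 10, num = 1000)) // 16
    // Nothing found yet after scanning 4 partitions: quadruple to 16.
    println(nextPartsToTry(partsScanned = 4, resultsSoFar = 0, num = 1000))  // 16
  }
}
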
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 21d0cc7b5cbaa..775141775e06c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -23,6 +23,7 @@ import java.io.EOFException
import scala.collection.immutable.Map
import scala.reflect.ClassTag
+import scala.collection.mutable.ListBuffer
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
+import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
/**
@@ -130,27 +132,47 @@ class HadoopRDD[K, V](
// used to build JobTracker ID
private val createTime = new Date()
+ private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean
+
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
val conf: Configuration = broadcastedConf.value.value
- if (conf.isInstanceOf[JobConf]) {
- // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it.
- conf.asInstanceOf[JobConf]
- } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- // getJobConf() has been called previously, so there is already a local cache of the JobConf
- // needed by this RDD.
- HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
- } else {
- // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
- // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456).
+ if (shouldCloneJobConf) {
+ // Hadoop Configuration objects are not thread-safe, which may lead to various problems if
+ // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs
+ // somewhat rarely because most jobs treat the configuration as though it's immutable. One
+ // solution, implemented here, is to clone the Configuration object. Unfortunately, this
+ // clone can be very expensive. To avoid unexpected performance regressions for workloads and
+ // Hadoop versions that do not suffer from these thread-safety issues, this cloning is
+ // disabled by default.
HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Cloning Hadoop Configuration")
val newJobConf = new JobConf(conf)
- initLocalJobConfFuncOpt.map(f => f(newJobConf))
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ if (!conf.isInstanceOf[JobConf]) {
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ }
newJobConf
}
+ } else {
+ if (conf.isInstanceOf[JobConf]) {
+ logDebug("Re-using user-broadcasted JobConf")
+ conf.asInstanceOf[JobConf]
+ } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
+ logDebug("Re-using cached JobConf")
+ HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
+ } else {
+ // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
+ // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
+ // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456).
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Creating new JobConf and caching it for later re-use")
+ val newJobConf = new JobConf(conf)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ newJobConf
+ }
+ }
}
}
@@ -194,7 +216,7 @@ class HadoopRDD[K, V](
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -249,9 +271,21 @@ class HadoopRDD[K, V](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- // TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopPartition]
- hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
+ val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+ val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e: Exception =>
+ logDebug("Failed to use InputSplitWithLocations.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
override def checkpoint() {
@@ -261,8 +295,11 @@ class HadoopRDD[K, V](
def getConf: Configuration = getJobConf()
}
-private[spark] object HadoopRDD {
- /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
+private[spark] object HadoopRDD extends Logging {
+ /**
+ * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456).
+ * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration().
+ */
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
/**
@@ -309,4 +346,42 @@ private[spark] object HadoopRDD {
f(inputSplit, firstParent[T].iterator(split, context))
}
}
+
+ private[spark] class SplitInfoReflections {
+ val inputSplitWithLocationInfo =
+ Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+ val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+ val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val isInMemory = splitLocationInfo.getMethod("isInMemory")
+ val getLocation = splitLocationInfo.getMethod("getLocation")
+ }
+
+ private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+ Some(new SplitInfoReflections)
+ } catch {
+ case e: Exception =>
+ logDebug("SplitLocationInfo and other new Hadoop classes are " +
+ "unavailable. Using the older Hadoop location info code.", e)
+ None
+ }
+
+ private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+ val out = ListBuffer[String]()
+ infos.foreach { loc => {
+ val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
+ getLocation.invoke(loc).asInstanceOf[String]
+ if (locationStr != "localhost") {
+ if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory.
+ invoke(loc).asInstanceOf[Boolean]) {
+ logDebug("Partition " + locationStr + " is cached by Hadoop.")
+ out += new HDFSCacheTaskLocation(locationStr).toString
+ } else {
+ out += new HostTaskLocation(locationStr).toString
+ }
+ }
+ }}
+ out.seq
+ }
}
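
The new spark.hadoop.cloneConf flag above trades JobConf construction cost for thread safety: when it is true, every getJobConf() call clones the broadcast Configuration under the shared lock instead of reusing a cached JobConf. A hedged sketch of how a job could opt in, assuming the flag is simply set on the SparkConf before the context is created:

import org.apache.spark.{SparkConf, SparkContext}

object CloneConfSketch {
  def main(args: Array[String]): Unit = {
    // Opt in to cloning Hadoop Configurations per getJobConf() call (default is "false").
    val conf = new SparkConf()
      .setAppName("clone-conf-sketch")
      .setMaster("local[2]")
      .set("spark.hadoop.cloneConf", "true")

    val sc = new SparkContext(conf)
    // Any Hadoop-backed input now goes through the cloning branch of HadoopRDD.getJobConf().
    // sc.textFile("hdfs://...").count()   // path elided; supply a real input to run this
    sc.stop()
  }
}
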
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 4c84b3f62354d..0cccdefc5ee09 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -173,9 +173,21 @@ class NewHadoopRDD[K, V](
new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val theSplit = split.asInstanceOf[NewHadoopPartition]
- theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ override def getPreferredLocations(hsplit: Partition): Seq[String] = {
+ val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
+ val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e : Exception =>
+ logDebug("Failed to use InputSplit#getLocationInfo.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
def getConf: Configuration = confBroadcast.value.value
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 929ded58a3bd5..ac96de86dd6d4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1032,10 +1032,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
- var count = 0
while (iter.hasNext) {
val record = iter.next()
- count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
} finally {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 5d77d37378458..56ac7a69be0d3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
- SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
// input the pipe context firstly
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab9e97c8fe409..71cabf61d4ee0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * Get the preferred locations of a partition, taking into account whether the
* RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
@@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag](
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
+ // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
// interpolate the number of partitions we need to try, but overestimate it by 50%.
+ // We also cap the estimation in the end.
if (buf.size == 0) {
numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5a96f52a10cd4..f81fa6d8089fc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -630,18 +630,17 @@ class DAGScheduler(
protected def runLocallyWithinThread(job: ActiveJob) {
var jobResult: JobResult = JobSucceeded
try {
- SparkEnv.set(env)
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
- TaskContext.setTaskContext(taskContext)
+ new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
} catch {
case e: Exception =>
@@ -1303,7 +1302,7 @@ class DAGScheduler(
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (!rddPrefs.isEmpty) {
- return rddPrefs.map(host => TaskLocation(host))
+ return rddPrefs.map(TaskLocation(_))
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 94944399b134a..12668b6c0988e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry}
import org.apache.spark.SparkContext
import org.apache.spark.metrics.source.Source
-private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext)
+private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler)
extends Source {
override val metricRegistry = new MetricRegistry()
- override val sourceName = "%s.DAGScheduler".format(sc.appName)
+ override val sourceName = "DAGScheduler"
metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] {
override def getValue: Int = dagScheduler.failedStages.size
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 64b32ae0edaac..100c9ba9b7809 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -43,38 +43,29 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
* spark.eventLog.buffer.kb - Buffer size to use when writing to output streams
*/
private[spark] class EventLoggingListener(
- appName: String,
+ appId: String,
+ logBaseDir: String,
sparkConf: SparkConf,
hadoopConf: Configuration)
extends SparkListener with Logging {
import EventLoggingListener._
- def this(appName: String, sparkConf: SparkConf) =
- this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf))
+ def this(appId: String, logBaseDir: String, sparkConf: SparkConf) =
+ this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf))
private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false)
private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false)
private val testing = sparkConf.getBoolean("spark.eventLog.testing", false)
private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024
- private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/")
- private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_")
- .toLowerCase + "-" + System.currentTimeMillis
- val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
-
+ val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId)
+ val logDirName: String = logDir.split("/").last
protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize,
shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS))
// For testing. Keep track of all JSON serialized events that have been logged.
private[scheduler] val loggedEvents = new ArrayBuffer[JValue]
- /**
- * Return only the unique application directory without the base directory.
- */
- def getApplicationLogDir(): String = {
- name
- }
-
/**
* Begin logging events.
* If compression is used, log a file that indicates which compression library is used.
@@ -184,6 +175,18 @@ private[spark] object EventLoggingListener extends Logging {
} else ""
}
+ /**
+ * Return a file-system-safe path to the log directory for the given application.
+ *
+ * @param logBaseDir A base directory for the path to the log directory for the given application.
+ * @param appId A unique app ID.
+ * @return A path which consists of file-system-safe characters.
+ */
+ def getLogDirPath(logBaseDir: String, appId: String): String = {
+ val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase
+ Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
+ }
+
/**
* Parse the event logging information associated with the logs in the given directory.
*
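
getLogDirPath above now derives the event-log directory purely from the application ID, applying the same two sanitizing replaceAll calls that previously ran on the app name. A quick standalone sketch of that sanitization on hypothetical IDs:

object LogDirNameSketch {
  def sanitize(appId: String): String =
    appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase

  def main(args: Array[String]): Unit = {
    // A hypothetical YARN-style application ID survives unchanged:
    println(sanitize("application_1412605023000_0001")) // application_1412605023000_0001
    // Spaces, colons and slashes become dashes; shell-sensitive characters become underscores.
    println(sanitize("My App:2014/10/06 ${user}"))       // my-app-2014-10-06-__user_
  }
}
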
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index a0be8307eff27..992c477493d8e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -23,6 +23,8 @@ package org.apache.spark.scheduler
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
+ private val appId = "spark-application-" + System.currentTimeMillis
+
def start(): Unit
def stop(): Unit
def reviveOffers(): Unit
@@ -33,10 +35,10 @@ private[spark] trait SchedulerBackend {
def isReady(): Boolean = true
/**
- * The application ID associated with the job, if any.
+ * Get an application ID associated with the job.
*
- * @return The application ID, or None if the backend does not provide an ID.
+ * @return An application ID
*/
- def applicationId(): Option[String] = None
+ def applicationId(): String = appId
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c6e47c84a0cb2..2552d03d18d06 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, false)
- TaskContext.setTaskContext(context)
+ context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
}
@@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
var metrics: Option[TaskMetrics] = None
// Task context, to be initialized in run().
- @transient protected var context: TaskContext = _
+ @transient protected var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 67c9a6760b1b3..10c685f29d3ac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -22,13 +22,51 @@ package org.apache.spark.scheduler
* In the latter case, we will prefer to launch the task on that executorID, but our next level
* of preference will be executors on the same host if this is not possible.
*/
-private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
- override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+private[spark] sealed trait TaskLocation {
+ def host: String
+}
+
+/**
+ * A location that includes both a host and an executor id on that host.
+ */
+private [spark] case class ExecutorCacheTaskLocation(override val host: String,
+ val executorId: String) extends TaskLocation {
+}
+
+/**
+ * A location on a host.
+ */
+private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
+ override def toString = host
+}
+
+/**
+ * A location on a host that is cached by HDFS.
+ */
+private [spark] case class HDFSCacheTaskLocation(override val host: String)
+ extends TaskLocation {
+ override def toString = TaskLocation.inMemoryLocationTag + host
}
private[spark] object TaskLocation {
- def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+ // We identify hosts on which the block is cached with this prefix. Because this prefix contains
+ // underscores, which are not legal characters in hostnames, there should be no potential for
+ // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames.
+ val inMemoryLocationTag = "hdfs_cache_"
+
+ def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId)
- def apply(host: String) = new TaskLocation(host, None)
+ /**
+ * Create a TaskLocation from a string returned by getPreferredLocations.
+ * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
+ * location is cached.
+ */
+ def apply(str: String) = {
+ val hstr = str.stripPrefix(inMemoryLocationTag)
+ if (hstr.equals(str)) {
+ new HostTaskLocation(str)
+ } else {
+ new HDFSCacheTaskLocation(hstr)
+ }
+ }
}
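
The sealed TaskLocation hierarchy above round-trips preferred locations through strings: the HDFS-cached form carries the hdfs_cache_ prefix and apply(str) strips it again on the scheduler side. A standalone sketch of the intended mapping, assuming the cached form parses back to the cache-aware location; the case classes here only shadow the Spark ones for illustration.

// Illustrative stand-ins for HostTaskLocation / HDFSCacheTaskLocation.
sealed trait Loc { def host: String }
case class HostLoc(host: String) extends Loc { override def toString = host }
case class HdfsCacheLoc(host: String) extends Loc { override def toString = "hdfs_cache_" + host }

object TaskLocationSketch {
  val tag = "hdfs_cache_"

  def parse(str: String): Loc = {
    val hstr = str.stripPrefix(tag)
    if (hstr == str) HostLoc(str) else HdfsCacheLoc(hstr)
  }

  def describe(loc: Loc): String = loc match {
    case HdfsCacheLoc(h) => s"HDFS-cached on $h"
    case HostLoc(h) => s"host-only on $h"
  }

  def main(args: Array[String]): Unit = {
    println(describe(parse("host1")))            // host-only on host1
    println(describe(parse("hdfs_cache_host1"))) // HDFS-cached on host1
    // Round trip back to the string form used by getPreferredLocations:
    println(HdfsCacheLoc("host1").toString)      // hdfs_cache_host1
  }
}
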
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index df59f444b7a0e..3f345ceeaaf7a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import scala.util.control.NonFatal
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.serializer.SerializerInstance
@@ -32,7 +34,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
- THREADS, "Result resolver thread")
+ THREADS, "task-result-getter")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
@@ -70,7 +72,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
- case ex: Exception =>
+ // Matching NonFatal so we don't catch the ControlThrowable from the "return" above.
+ case NonFatal(ex) =>
logError("Exception while getting task result", ex)
taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1c1ce666eab0f..a129a434c9a1a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -31,6 +31,8 @@ import org.apache.spark.storage.BlockManagerId
*/
private[spark] trait TaskScheduler {
+ private val appId = "spark-application-" + System.currentTimeMillis
+
def rootPool: Pool
def schedulingMode: SchedulingMode
@@ -66,10 +68,10 @@ private[spark] trait TaskScheduler {
blockManagerId: BlockManagerId): Boolean
/**
- * The application ID associated with the job, if any.
+ * Get an application ID associated with the job.
*
- * @return The application ID, or None if the backend does not provide an ID.
+ * @return An application ID
*/
- def applicationId(): Option[String] = None
+ def applicationId(): String = appId
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 633e892554c50..6d697e3d003f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl(
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
- SparkEnv.set(sc.env)
-
// Mark each slave as alive and remember its hostname
// Also track if new executor is added
var newExecAvail = false
@@ -492,7 +490,7 @@ private[spark] class TaskSchedulerImpl(
}
}
- override def applicationId(): Option[String] = backend.applicationId()
+ override def applicationId(): String = backend.applicationId()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843ff..a6c23fc85a1b0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -181,8 +181,24 @@ private[spark] class TaskSetManager(
}
for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ loc match {
+ case e: ExecutorCacheTaskLocation =>
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
+ case e: HDFSCacheTaskLocation => {
+ val exe = sched.getExecutorsAliveOnHost(loc.host)
+ exe match {
+ case Some(set) => {
+ for (e <- set) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
+ }
+ logInfo(s"Pending task $index has a cached location at ${e.host} " +
+ ", where there are executors " + set.mkString(","))
+ }
+ case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
+ ", but there are no executors alive there.")
+ }
+ }
+ case _ => Unit
}
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
@@ -283,7 +299,10 @@ private[spark] class TaskSetManager(
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if canRunOnHost(index)) {
val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
+ val executors = prefs.flatMap(_ match {
+ case e: ExecutorCacheTaskLocation => Some(e.executorId)
+ case _ => None
+ });
if (executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 6abf6d930c155..fb8160abc59db 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -66,7 +66,7 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
- case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String)
+ case class AddWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 89089e7d6f8a8..59aed6b72fe42 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -275,15 +275,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
// Add filters to the SparkUI
- def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) {
+ def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) {
if (proxyBase != null && proxyBase.nonEmpty) {
System.setProperty("spark.ui.proxyBase", proxyBase)
}
- if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) {
+ val hasFilter = (filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty)
+ if (hasFilter) {
logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
conf.set("spark.ui.filters", filterName)
- conf.set(s"spark.$filterName.params", filterParams)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
}
}
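
With the Map-based AddWebUIFilter above, each filter parameter lands in its own spark.<filterName>.param.<key> entry rather than one packed string. A sketch of the resulting SparkConf entries for a hypothetical filter class and two made-up parameters (this is not the YARN AmIpFilter wiring itself):

import org.apache.spark.SparkConf

object FilterParamsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
    val filterName = "org.example.DemoFilter"            // hypothetical filter class
    val filterParams = Map("proxyHost" -> "rm-host", "proxyUriBase" -> "/proxy/app_1")

    conf.set("spark.ui.filters", filterName)
    filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }

    conf.getAll.sorted.foreach { case (k, v) => println(s"$k=$v") }
    // spark.org.example.DemoFilter.param.proxyHost=rm-host
    // spark.org.example.DemoFilter.param.proxyUriBase=/proxy/app_1
    // spark.ui.filters=org.example.DemoFilter
  }
}
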
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 5c5ecc8434d78..8c7de75600b5f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend(
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}",
+ "{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp =>
@@ -68,9 +69,8 @@ private[spark] class SparkDeploySchedulerBackend(
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
- val eventLogDir = sc.eventLogger.map(_.logDir)
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
- appUIAddress, eventLogDir)
+ appUIAddress, sc.eventLogDir)
client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
client.start()
@@ -129,7 +129,11 @@ private[spark] class SparkDeploySchedulerBackend(
totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio
}
- override def applicationId(): Option[String] = Option(appId)
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
private def waitForRegistration() = {
registrationLock.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 64568409dbafd..d7f88de4b40aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -76,6 +76,8 @@ private[spark] class CoarseMesosSchedulerBackend(
var nextMesosTaskId = 0
+ @volatile var appId: String = _
+
def newMesosTaskId(): Int = {
val id = nextMesosTaskId
nextMesosTaskId += 1
@@ -148,17 +150,17 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
- runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
+ runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
("cd %s*; " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d")
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
.format(basename, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores))
+ offer.getHostname, numCores, appId))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
command.build()
@@ -167,7 +169,8 @@ private[spark] class CoarseMesosSchedulerBackend(
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- logInfo("Registered as framework ID " + frameworkId.getValue)
+ appId = frameworkId.getValue
+ logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
@@ -198,7 +201,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val slaveId = offer.getSlaveId.toString
val mem = getResource(offer.getResourcesList, "mem")
val cpus = getResource(offer.getResourcesList, "cpus").toInt
- if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 &&
+ if (totalCoresAcquired < maxCores &&
+ mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ cpus >= 1 &&
failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
!slaveIdsWithExecutors.contains(slaveId)) {
// Launch an executor on the slave
@@ -214,7 +219,8 @@ private[spark] class CoarseMesosSchedulerBackend(
.setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
.setName("Task " + taskId)
.addResources(createResource("cpus", cpusToUse))
- .addResources(createResource("mem", sc.executorMemory))
+ .addResources(createResource("mem",
+ MemoryUtils.calculateTotalMemory(sc)))
.build()
d.launchTasks(
Collections.singleton(offer.getId), Collections.singletonList(task), filters)
@@ -310,4 +316,10 @@ private[spark] class CoarseMesosSchedulerBackend(
slaveLost(d, s)
}
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
new file mode 100644
index 0000000000000..5101ec8352e79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import org.apache.spark.SparkContext
+
+private[spark] object MemoryUtils {
+ // These defaults are copied from YARN
+ val OVERHEAD_FRACTION = 1.07
+ val OVERHEAD_MINIMUM = 384
+
+ def calculateTotalMemory(sc: SparkContext) = {
+ math.max(
+ sc.conf.getOption("spark.mesos.executor.memoryOverhead")
+ .getOrElse(OVERHEAD_MINIMUM.toString)
+ .toInt + sc.executorMemory,
+ OVERHEAD_FRACTION * sc.executorMemory
+ )
+ }
+}
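
calculateTotalMemory above mirrors the YARN overhead rule: the requested executor memory plus a configured overhead (384 MB by default), but never less than 1.07 times the executor memory. Two worked values in a standalone sketch, assuming spark.mesos.executor.memoryOverhead is left unset:

object MesosMemorySketch {
  val OVERHEAD_FRACTION = 1.07
  val OVERHEAD_MINIMUM = 384

  // Same arithmetic as MemoryUtils.calculateTotalMemory, with executorMemory passed in directly.
  def totalMemory(executorMemoryMb: Int, overheadMb: Option[Int] = None): Double =
    math.max(overheadMb.getOrElse(OVERHEAD_MINIMUM) + executorMemoryMb,
      OVERHEAD_FRACTION * executorMemoryMb)

  def main(args: Array[String]): Unit = {
    println(totalMemory(2048)) // 2432.0  (2048 + 384 dominates 1.07 * 2048 = 2191.36)
    println(totalMemory(8192)) // 8765.44 (1.07 * 8192 dominates 8192 + 384 = 8576)
  }
}
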
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index a9ef126f5de0e..e0f2fd622f54c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,7 +30,7 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
-import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
/**
@@ -62,6 +62,8 @@ private[spark] class MesosSchedulerBackend(
var classLoader: ClassLoader = null
+ @volatile var appId: String = _
+
override def start() {
synchronized {
classLoader = Thread.currentThread.getContextClassLoader
@@ -124,15 +126,24 @@ private[spark] class MesosSchedulerBackend(
command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
+ val cpus = Resource.newBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder()
+ .setValue(scheduler.CPUS_PER_TASK).build())
+ .build()
val memory = Resource.newBuilder()
.setName("mem")
.setType(Value.Type.SCALAR)
- .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build())
+ .setScalar(
+ Value.Scalar.newBuilder()
+ .setValue(MemoryUtils.calculateTotalMemory(sc)).build())
.build()
ExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
+ .addResources(cpus)
.addResources(memory)
.build()
}
@@ -168,7 +179,8 @@ private[spark] class MesosSchedulerBackend(
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
val oldClassLoader = setClassLoader()
try {
- logInfo("Registered as framework ID " + frameworkId.getValue)
+ appId = frameworkId.getValue
+ logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
@@ -204,18 +216,31 @@ private[spark] class MesosSchedulerBackend(
val offerableWorkers = new ArrayBuffer[WorkerOffer]
val offerableIndices = new HashMap[String, Int]
- def enoughMemory(o: Offer) = {
+ def sufficientOffer(o: Offer) = {
val mem = getResource(o.getResourcesList, "mem")
+ val cpus = getResource(o.getResourcesList, "cpus")
val slaveId = o.getSlaveId.getValue
- mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId)
+ (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ // need at least 1 for executor, 1 for task
+ cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ (slaveIdsWithExecutors.contains(slaveId) &&
+ cpus >= scheduler.CPUS_PER_TASK)
}
- for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
- offerableIndices.put(offer.getSlaveId.getValue, index)
+ for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
+ val slaveId = offer.getSlaveId.getValue
+ offerableIndices.put(slaveId, index)
+ val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
+ getResource(offer.getResourcesList, "cpus").toInt
+ } else {
+ // If the executor doesn't exist yet, subtract CPU for executor
+ getResource(offer.getResourcesList, "cpus").toInt -
+ scheduler.CPUS_PER_TASK
+ }
offerableWorkers += new WorkerOffer(
offer.getSlaveId.getValue,
offer.getHostname,
- getResource(offer.getResourcesList, "cpus").toInt)
+ cpus)
}
// Call into the TaskSchedulerImpl
@@ -347,7 +372,20 @@ private[spark] class MesosSchedulerBackend(
recordSlaveLost(d, slaveId, ExecutorExited(status))
}
+ override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
+ driver.killTask(
+ TaskID.newBuilder()
+ .setValue(taskId.toString).build()
+ )
+ }
+
// TODO: query Mesos for number of cores
override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8)
+ override def applicationId(): String =
+ Option(appId).getOrElse {
+ logWarning("Application ID is not initialized yet.")
+ super.applicationId
+ }
+
}
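
A rough sketch of the offer-acceptance rule introduced above, with an illustrative case class and parameter names standing in for the backend's actual types: an offer is usable either when it can host a brand-new executor (enough memory plus CPUs for the executor and at least one task) or when the slave already runs an executor and the offer only needs CPUs for a task.

    // Sketch only; OfferSketch and the parameters are illustrative, not the backend's API.
    object OfferFilterSketch {
      case class OfferSketch(slaveId: String, memMb: Double, cpus: Double)

      def sufficient(
          o: OfferSketch,
          requiredMemMb: Double,            // MemoryUtils.calculateTotalMemory(sc) in the real code
          cpusPerTask: Int,                 // scheduler.CPUS_PER_TASK in the real code
          slavesWithExecutors: Set[String]): Boolean = {
        val canStartNewExecutor =
          o.memMb >= requiredMemMb && o.cpus >= 2 * cpusPerTask // 1 for the executor, 1 for a task
        val canReuseExistingExecutor =
          slavesWithExecutors.contains(o.slaveId) && o.cpus >= cpusPerTask
        canStartNewExecutor || canReuseExistingExecutor
      }

      def main(args: Array[String]): Unit = {
        val running = Set("slave-1")
        // New slave with plenty of resources: accepted.
        println(sufficient(OfferSketch("slave-2", memMb = 4096, cpus = 4), 3000, 1, running)) // true
        // New slave with only 1 CPU: rejected, a new executor needs 2 * CPUS_PER_TASK.
        println(sufficient(OfferSketch("slave-3", memMb = 4096, cpus = 1), 3000, 1, running)) // false
        // Slave that already runs an executor only needs CPUs for the task itself.
        println(sufficient(OfferSketch("slave-1", memMb = 0, cpus = 1), 3000, 1, running))    // true
      }
    }
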
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 9ea25c2bc7090..58b78f041cd85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -88,6 +88,7 @@ private[spark] class LocalActor(
private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
extends SchedulerBackend with ExecutorBackend {
+ private val appId = "local-" + System.currentTimeMillis
var localActor: ActorRef = null
override def start() {
@@ -115,4 +116,6 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
localActor ! StatusUpdate(taskId, state, serializedData)
}
+ override def applicationId(): String = appId
+
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d1bee3d2c033c..3f5d06e1aeee7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
@@ -112,6 +113,11 @@ private[spark] class BlockManager(
private val broadcastCleaner = new MetadataCleaner(
MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf)
+ // Fields related to peer block managers that are necessary for block replication
+ @volatile private var cachedPeers: Seq[BlockManagerId] = _
+ private val peerFetchLock = new Object
+ private var lastPeerFetchTime = 0L
+
initialize()
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -787,31 +793,111 @@ private[spark] class BlockManager(
}
/**
- * Replicate block to another node.
+ * Get peer block managers in the system.
+ */
+ private def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = {
+ peerFetchLock.synchronized {
+ val cachedPeersTtl = conf.getInt("spark.storage.cachedPeersTtl", 60 * 1000) // milliseconds
+ val timeout = System.currentTimeMillis - lastPeerFetchTime > cachedPeersTtl
+ if (cachedPeers == null || forceFetch || timeout) {
+ cachedPeers = master.getPeers(blockManagerId).sortBy(_.hashCode)
+ lastPeerFetchTime = System.currentTimeMillis
+ logDebug("Fetched peers from master: " + cachedPeers.mkString("[", ",", "]"))
+ }
+ cachedPeers
+ }
+ }
+
+ /**
+ * Replicate block to another node. Note that this is a blocking call that returns after
+ * the block has been replicated.
*/
- @volatile var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = {
+ val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
+ val numPeersToReplicateTo = level.replication - 1
+ val peersForReplication = new ArrayBuffer[BlockManagerId]
+ val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
+ val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
val tLevel = StorageLevel(
level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1)
- if (cachedPeers == null) {
- cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
+ val startTime = System.currentTimeMillis
+ val random = new Random(blockId.hashCode)
+
+ var replicationFailed = false
+ var failures = 0
+ var done = false
+
+ // Get cached list of peers
+ peersForReplication ++= getPeers(forceFetch = false)
+
+ // Get a random peer. Note that the selection of a peer is deterministic on the block id.
+ // So, assuming the list of peers does not change and there are no replication failures,
+ // multiple attempts on the same node to replicate the same block
+ // will select the same set of peers.
+ def getRandomPeer(): Option[BlockManagerId] = {
+ // If a replication attempt failed, force-update the cached list of peers and remove the
+ // peers that have already been used
+ if (replicationFailed) {
+ peersForReplication.clear()
+ peersForReplication ++= getPeers(forceFetch = true)
+ peersForReplication --= peersReplicatedTo
+ peersForReplication --= peersFailedToReplicateTo
+ }
+ if (!peersForReplication.isEmpty) {
+ Some(peersForReplication(random.nextInt(peersForReplication.size)))
+ } else {
+ None
+ }
}
- for (peer: BlockManagerId <- cachedPeers) {
- val start = System.nanoTime
- data.rewind()
- logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " +
- s"To node: $peer")
- try {
- blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- } catch {
- case e: Exception =>
- logError(s"Failed to replicate block to $peer", e)
+ // One by one, choose a random peer and try uploading the block to it.
+ // If replication fails (e.g., the target peer is down), force the list of cached peers
+ // to be re-fetched from the driver and then pick another random peer for replication. Also
+ // temporarily blacklist the peer for which replication failed.
+ //
+ // This selection of a peer and replication is continued in a loop until one of the
+ // following 3 conditions is fulfilled:
+ // (i) specified number of peers have been replicated to
+ // (ii) too many failures in replicating to peers
+ // (iii) no peer left to replicate to
+ //
+ while (!done) {
+ getRandomPeer() match {
+ case Some(peer) =>
+ try {
+ val onePeerStartTime = System.currentTimeMillis
+ data.rewind()
+ logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
+ blockTransferService.uploadBlockSync(
+ peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
+ .format((System.currentTimeMillis - onePeerStartTime)))
+ peersReplicatedTo += peer
+ peersForReplication -= peer
+ replicationFailed = false
+ if (peersReplicatedTo.size == numPeersToReplicateTo) {
+ done = true // specified number of peers have been replicated to
+ }
+ } catch {
+ case e: Exception =>
+ logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
+ failures += 1
+ replicationFailed = true
+ peersFailedToReplicateTo += peer
+ if (failures > maxReplicationFailures) { // too many failures in replicating to peers
+ done = true
+ }
+ }
+ case None => // no peer left to replicate to
+ done = true
}
-
- logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes."
- .format(blockId, (System.nanoTime - start) / 1e6, data.limit()))
+ }
+ val timeTakenMs = System.currentTimeMillis - startTime
+ logDebug(s"Replicating $blockId of ${data.limit()} bytes to " +
+ s"${peersReplicatedTo.size} peer(s) took $timeTakenMs ms")
+ if (peersReplicatedTo.size < numPeersToReplicateTo) {
+ logWarning(s"Block $blockId replicated to only " +
+ s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
}
}
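
A condensed sketch of the replication loop above, with the master lookup and the block upload abstracted behind hypothetical fetchPeers and upload functions; the random generator is seeded with the block id, so repeated attempts on a node pick the same peers unless a failure forces a refresh.

    import scala.collection.mutable.ArrayBuffer
    import scala.util.Random

    // Sketch: replicate until the target count is reached, too many uploads have failed,
    // or no candidate peer remains. fetchPeers stands in for getPeers(forceFetch) and
    // upload for blockTransferService.uploadBlockSync in the real code.
    object ReplicationSketch {
      def replicate[Peer](
          blockIdHash: Int,
          targetReplicas: Int,
          maxFailures: Int,
          fetchPeers: Boolean => Seq[Peer],
          upload: Peer => Unit): Seq[Peer] = {
        val random = new Random(blockIdHash)          // deterministic per block id
        val candidates = ArrayBuffer(fetchPeers(false): _*)
        val replicated = ArrayBuffer.empty[Peer]
        val failed = ArrayBuffer.empty[Peer]
        var failures = 0
        var lastAttemptFailed = false
        var done = false
        while (!done) {
          if (lastAttemptFailed) {
            // Refresh the peer list and drop peers already used or already failed.
            candidates.clear()
            candidates ++= fetchPeers(true)
            candidates --= replicated
            candidates --= failed
            lastAttemptFailed = false
          }
          if (candidates.isEmpty) {
            done = true                               // no peer left to replicate to
          } else {
            val peer = candidates(random.nextInt(candidates.size))
            try {
              upload(peer)
              replicated += peer
              candidates -= peer
              if (replicated.size == targetReplicas) done = true
            } catch {
              case _: Exception =>
                failures += 1
                failed += peer
                lastAttemptFailed = true
                if (failures > maxFailures) done = true
            }
          }
        }
        replicated.toSeq
      }
    }
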
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d4487fce49ab6..142285094342c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -59,6 +59,8 @@ class BlockManagerId private (
def port: Int = port_
+ def isDriver: Boolean = (executorId == "")
+
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
out.writeUTF(host_)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 2e262594b3538..d08e1419e3e41 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -84,13 +84,8 @@ class BlockManagerMaster(
}
/** Get ids of other nodes in the cluster from the driver */
- def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
- val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
- if (result.length != numPeers) {
- throw new SparkException(
- "Error getting peers, only got " + result.size + " instead of " + numPeers)
- }
- result
+ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
+ askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 1a6c7cb24f9ac..088f06e389d83 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -83,8 +83,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetLocationsMultipleBlockIds(blockIds) =>
sender ! getLocationsMultipleBlockIds(blockIds)
- case GetPeers(blockManagerId, size) =>
- sender ! getPeers(blockManagerId, size)
+ case GetPeers(blockManagerId) =>
+ sender ! getPeers(blockManagerId)
case GetMemoryStatus =>
sender ! memoryStatus
@@ -173,11 +173,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
* from the executors, but not from the driver.
*/
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
- // TODO: Consolidate usages of
import context.dispatcher
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
- removeFromDriver || info.blockManagerId.executorId != ""
+ removeFromDriver || !info.blockManagerId.isDriver
}
Future.sequence(
requiredBlockManagers.map { bm =>
@@ -212,7 +211,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val minSeenTime = now - slaveTimeout
val toRemove = new mutable.HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") {
+ if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) {
logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: "
+ (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
toRemove += info.blockManagerId
@@ -232,7 +231,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
*/
private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
- blockManagerId.executorId == "" && !isLocal
+ blockManagerId.isDriver && !isLocal
} else {
blockManagerInfo(blockManagerId).updateLastSeenMs()
true
@@ -355,7 +354,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
tachyonSize: Long) {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.executorId == "" && !isLocal) {
+ if (blockManagerId.isDriver && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
sender ! true
@@ -403,16 +402,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockIds.map(blockId => getLocations(blockId))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = {
- val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
-
- val selfIndex = peers.indexOf(blockManagerId)
- if (selfIndex == -1) {
- throw new SparkException("Self index for " + blockManagerId + " not found")
+ /** Get the list of the peers of the given block manager */
+ private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
+ val blockManagerIds = blockManagerInfo.keySet
+ if (blockManagerIds.contains(blockManagerId)) {
+ blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq
+ } else {
+ Seq.empty
}
-
- // Note that this logic will select the same node multiple times if there aren't enough peers
- Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
}
}
@@ -460,16 +457,18 @@ private[spark] class BlockManagerInfo(
if (_blocks.containsKey(blockId)) {
// The block exists on the slave already.
- val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
+ val blockStatus: BlockStatus = _blocks.get(blockId)
+ val originalLevel: StorageLevel = blockStatus.storageLevel
+ val originalMemSize: Long = blockStatus.memSize
if (originalLevel.useMemory) {
- _remainingMem += memSize
+ _remainingMem += originalMemSize
}
}
if (storageLevel.isValid) {
/* isValid means it is either stored in-memory, on-disk or on-Tachyon.
- * But the memSize here indicates the data size in or dropped from memory,
+ * The memSize here indicates the data size in or dropped from memory,
* tachyonSize here indicates the data size in or dropped from Tachyon,
* and the diskSize here indicates the data size in or dropped to disk.
* They can be both larger than 0, when a block is dropped from memory to disk.
@@ -496,7 +495,6 @@ private[spark] class BlockManagerInfo(
val blockStatus: BlockStatus = _blocks.get(blockId)
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
- _remainingMem += blockStatus.memSize
logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
Utils.bytesToString(_remainingMem)))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 2ba16b8476600..3db5dd9774ae8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -88,7 +88,7 @@ private[spark] object BlockManagerMessages {
case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
- case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+ case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
index 49fea6d9e2a76..8569c6f3cbbc3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
@@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry}
import org.apache.spark.SparkContext
import org.apache.spark.metrics.source.Source
-private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext)
+private[spark] class BlockManagerSource(val blockManager: BlockManager)
extends Source {
override val metricRegistry = new MetricRegistry()
- override val sourceName = "%s.BlockManager".format(sc.appName)
+ override val sourceName = "BlockManager"
metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] {
override def getValue: Long = {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index e9304f6bb45d0..bac459e835a3f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values)
+ try {
+ try {
+ blockManager.dataSerializeStream(blockId, outputStream, values)
+ } finally {
+ // Close outputStream here because it must be closed before the file can be deleted.
+ outputStream.close()
+ }
+ } catch {
+ case e: Throwable =>
+ if (file.exists()) {
+ file.delete()
+ }
+ throw e
+ }
+
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
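The nested try above exists so the stream is always closed before any attempt to delete the partially written file (an open handle can prevent deletion on some platforms). A generic version of the pattern, with hypothetical helper names:

    import java.io.{File, FileOutputStream}

    // Sketch of the write-or-clean-up pattern: always close the stream first,
    // then delete the partial file only if the write itself failed.
    object WriteOrDeleteSketch {
      def writeOrDelete(file: File)(write: FileOutputStream => Unit): Unit = {
        val out = new FileOutputStream(file)
        try {
          try {
            write(out)
          } finally {
            out.close()            // must happen before any delete of `file`
          }
        } catch {
          case e: Throwable =>
            if (file.exists()) {
              file.delete()
            }
            throw e
        }
      }
    }
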
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 0a09c24d61879..edbc729c17ade 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -132,8 +132,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
PutResult(res.size, res.data, droppedBlocks)
case Right(iteratorValues) =>
// Not enough space to unroll this block; drop to disk if applicable
- logWarning(s"Not enough space to store block $blockId in memory! " +
- s"Free memory is $freeMemory bytes.")
if (level.useDisk && allowPersistToDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues)
@@ -265,6 +263,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
Left(vector.toArray)
} else {
// We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, vector.estimateSize())
Right(vector.iterator ++ values)
}
@@ -424,7 +423,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Reserve additional memory for unrolling blocks used by this thread.
* Return whether the request is granted.
*/
- private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
+ def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
accountingLock.synchronized {
val granted = freeMemory > currentUnrollMemory + memory
if (granted) {
@@ -439,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Release memory used by this thread for unrolling blocks.
* If the amount is not specified, remove the current thread's allocation altogether.
*/
- private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
+ def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
val threadId = Thread.currentThread().getId
accountingLock.synchronized {
if (memory < 0) {
@@ -457,16 +456,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
/**
* Return the amount of memory currently occupied for unrolling blocks across all threads.
*/
- private[spark] def currentUnrollMemory: Long = accountingLock.synchronized {
+ def currentUnrollMemory: Long = accountingLock.synchronized {
unrollMemoryMap.values.sum
}
/**
* Return the amount of memory currently occupied for unrolling blocks by this thread.
*/
- private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
+ def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
}
+
+ /**
+ * Return the number of threads currently unrolling blocks.
+ */
+ def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+
+ /**
+ * Log information about current memory usage.
+ */
+ def logMemoryUsage(): Unit = {
+ val blocksMemory = currentMemory
+ val unrollMemory = currentUnrollMemory
+ val totalMemory = blocksMemory + unrollMemory
+ logInfo(
+ s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
+ s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
+ s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " +
+ s"Storage limit = ${Utils.bytesToString(maxMemory)}."
+ )
+ }
+
+ /**
+ * Log a warning for failing to unroll a block.
+ *
+ * @param blockId ID of the block we are trying to unroll.
+ * @param finalVectorSize Final size of the vector before unrolling failed.
+ */
+ def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
+ logWarning(
+ s"Not enough space to cache $blockId in memory! " +
+ s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
+ )
+ logMemoryUsage()
+ }
}
private[spark] case class ResultWithDroppedBlocks(
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 6b4689291097f..2a27d49d2de05 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -21,9 +21,7 @@ import java.net.{InetSocketAddress, URL}
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
-import scala.annotation.tailrec
import scala.language.implicitConversions
-import scala.util.{Failure, Success, Try}
import scala.xml.Node
import org.eclipse.jetty.server.Server
@@ -147,15 +145,19 @@ private[spark] object JettyUtils extends Logging {
val holder : FilterHolder = new FilterHolder()
holder.setClassName(filter)
// Get any parameters for each filter
- val paramName = "spark." + filter + ".params"
- val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
- params.foreach {
- case param : String =>
+ conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach {
+ param: String =>
if (!param.isEmpty) {
val parts = param.split("=")
if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
}
}
+
+ val prefix = s"spark.$filter.param."
+ conf.getAll
+ .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) }
+ .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) }
+
val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
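
With this change a filter's init parameters can be supplied either through the legacy comma-separated list or through one configuration entry per parameter. A small sketch with a hypothetical filter class name; either style below would set the same two parameters:

    import org.apache.spark.SparkConf

    // Sketch only: "com.example.AuthFilter" is a hypothetical filter class name.
    object FilterParamsExample {
      val conf = new SparkConf(false)
        // legacy form: comma-separated "key=value" pairs
        .set("spark.com.example.AuthFilter.params", "realm=spark,timeout=30")
        // new form added here: one configuration entry per parameter
        .set("spark.com.example.AuthFilter.param.realm", "spark")
        .set("spark.com.example.AuthFilter.param.timeout", "30")
    }
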
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index f0006b42aee4f..32e6b15bb0999 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -21,6 +21,7 @@ import java.text.SimpleDateFormat
import java.util.{Locale, Date}
import scala.xml.Node
+
import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
@@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging {
refreshInterval: Option[Int] = None): Seq[Node] = {
val appName = activeTab.appName
+ val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
val header = activeTab.headerTabs.map { tab =>
// scalastyle:on
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index e2d32c859bbda..f41c8d0315cb3 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging {
val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off"
- val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600)
+ val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000)
val akkaFailureDetector =
conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
index 6d1fc05a15d2c..fdc73f08261a6 100644
--- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala
+++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
@@ -51,12 +51,27 @@ private[spark] class FileLogger(
def this(
logDir: String,
sparkConf: SparkConf,
- compress: Boolean = false,
- overwrite: Boolean = true) = {
+ compress: Boolean,
+ overwrite: Boolean) = {
this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress,
overwrite = overwrite)
}
+ def this(
+ logDir: String,
+ sparkConf: SparkConf,
+ compress: Boolean) = {
+ this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress,
+ overwrite = true)
+ }
+
+ def this(
+ logDir: String,
+ sparkConf: SparkConf) = {
+ this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false,
+ overwrite = true)
+ }
+
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index dbe0cfa2b8ff9..53a7512edd852 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -23,6 +23,8 @@ import java.nio.ByteBuffer
import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import org.eclipse.jetty.util.MultiException
+
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
@@ -166,6 +168,20 @@ private[spark] object Utils extends Logging {
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
+ // Add a shutdown hook to delete the temp dirs when the JVM exits
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") {
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ logDebug("Shutdown hook called")
+ shutdownDeletePaths.foreach { dirPath =>
+ try {
+ Utils.deleteRecursively(new File(dirPath))
+ } catch {
+ case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
+ }
+ }
+ }
+ })
+
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
@@ -250,14 +266,6 @@ private[spark] object Utils extends Logging {
}
registerShutdownDeleteDir(dir)
-
- // Add a shutdown hook to delete the temp dir when the JVM exits
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
- override def run() {
- // Attempt to delete if some patch which is parent of this is not already registered.
- if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
- }
- })
dir
}
@@ -332,7 +340,7 @@ private[spark] object Utils extends Logging {
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
- uri.getScheme match {
+ Option(uri.getScheme).getOrElse("file") match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + tempFile)
@@ -366,7 +374,7 @@ private[spark] object Utils extends Logging {
}
}
Files.move(tempFile, targetFile)
- case "file" | null =>
+ case "file" =>
// In the case of a local file, copy the local file to the target directory.
// Note the difference between uri vs url.
val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
@@ -664,15 +672,30 @@ private[spark] object Utils extends Logging {
*/
def deleteRecursively(file: File) {
if (file != null) {
- if (file.isDirectory() && !isSymlink(file)) {
- for (child <- listFilesSafely(file)) {
- deleteRecursively(child)
+ try {
+ if (file.isDirectory && !isSymlink(file)) {
+ var savedIOException: IOException = null
+ for (child <- listFilesSafely(file)) {
+ try {
+ deleteRecursively(child)
+ } catch {
+ // In case of multiple exceptions, only the last one will be thrown
+ case ioe: IOException => savedIOException = ioe
+ }
+ }
+ if (savedIOException != null) {
+ throw savedIOException
+ }
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.remove(file.getAbsolutePath)
+ }
}
- }
- if (!file.delete()) {
- // Delete can also fail if the file simply did not exist
- if (file.exists()) {
- throw new IOException("Failed to delete: " + file.getAbsolutePath)
+ } finally {
+ if (!file.delete()) {
+ // Delete can also fail if the file simply did not exist
+ if (file.exists()) {
+ throw new IOException("Failed to delete: " + file.getAbsolutePath)
+ }
}
}
}
@@ -703,18 +726,23 @@ private[spark] object Utils extends Logging {
}
/**
- * Finds all the files in a directory whose last modified time is older than cutoff seconds.
- * @param dir must be the path to a directory, or IllegalArgumentException is thrown
- * @param cutoff measured in seconds. Files older than this are returned.
+ * Determines whether a directory contains any files modified within the last cutoff seconds.
+ *
+ * @param dir must be the path to a directory, or IllegalArgumentException is thrown
+ * @param cutoff measured in seconds. Returns true if there are any files or directories in the
+ * given directory whose last modified time is later than this many seconds ago
*/
- def findOldFiles(dir: File, cutoff: Long): Seq[File] = {
- val currentTimeMillis = System.currentTimeMillis
- if (dir.isDirectory) {
- val files = listFilesSafely(dir)
- files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) }
- } else {
- throw new IllegalArgumentException(dir + " is not a directory!")
+ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = {
+ if (!dir.isDirectory) {
+ throw new IllegalArgumentException(s"$dir is not a directory!")
}
+ val filesAndDirs = dir.listFiles()
+ val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000)
+
+ filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) ||
+ filesAndDirs.filter(_.isDirectory).exists(
+ subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff)
+ )
}
/**
@@ -1340,16 +1368,17 @@ private[spark] object Utils extends Logging {
if (uri.getPath == null) {
throw new IllegalArgumentException(s"Given path is malformed: $uri")
}
- uri.getScheme match {
- case windowsDrive(d) if windows =>
+
+ Option(uri.getScheme) match {
+ case Some(windowsDrive(d)) if windows =>
new URI("file:/" + uri.toString.stripPrefix("/"))
- case null =>
+ case None =>
// Preserve fragments for HDFS file name substitution (denoted by "#")
// For instance, in "abc.py#xyz.py", "xyz.py" is the name observed by the application
val fragment = uri.getFragment
val part = new File(uri.getPath).toURI
new URI(part.getScheme, part.getPath, fragment)
- case _ =>
+ case Some(other) =>
uri
}
}
@@ -1371,15 +1400,64 @@ private[spark] object Utils extends Logging {
} else {
paths.split(",").filter { p =>
val formattedPath = if (windows) formatWindowsPath(p) else p
- new URI(formattedPath).getScheme match {
+ val uri = new URI(formattedPath)
+ Option(uri.getScheme).getOrElse("file") match {
case windowsDrive(d) if windows => false
- case "local" | "file" | null => false
+ case "local" | "file" => false
case _ => true
}
}
}
}
+ /**
+ * Load default Spark properties from the given file. If no file is provided,
+ * use the common defaults file. This mutates state in the given SparkConf and
+ * in this JVM's system properties if the config specified in the file is not
+ * already set. Return the path of the properties file used.
+ */
+ def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = {
+ val path = Option(filePath).getOrElse(getDefaultPropertiesFile())
+ Option(path).foreach { confFile =>
+ getPropertiesFromFile(confFile).filter { case (k, v) =>
+ k.startsWith("spark.")
+ }.foreach { case (k, v) =>
+ conf.setIfMissing(k, v)
+ sys.props.getOrElseUpdate(k, v)
+ }
+ }
+ path
+ }
+
+ /** Load properties present in the given file. */
+ def getPropertiesFromFile(filename: String): Map[String, String] = {
+ val file = new File(filename)
+ require(file.exists(), s"Properties file $file does not exist")
+ require(file.isFile(), s"Properties file $file is not a normal file")
+
+ val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8")
+ try {
+ val properties = new Properties()
+ properties.load(inReader)
+ properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap
+ } catch {
+ case e: IOException =>
+ throw new SparkException(s"Failed when loading Spark properties from $filename", e)
+ } finally {
+ inReader.close()
+ }
+ }
+
+ /** Return the path of the default Spark properties file. */
+ def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = {
+ env.get("SPARK_CONF_DIR")
+ .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" })
+ .map { t => new File(s"$t${File.separator}spark-defaults.conf")}
+ .filter(_.isFile)
+ .map(_.getAbsolutePath)
+ .orNull
+ }
+
/** Return a nice string representation of the exception, including the stack trace. */
def exceptionString(e: Exception): String = {
if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace)
@@ -1437,7 +1515,12 @@ private[spark] object Utils extends Logging {
val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'"
for (offset <- 0 to maxRetries) {
// Do not increment port if startPort is 0, which is treated as a special port
- val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536
+ val tryPort = if (startPort == 0) {
+ startPort
+ } else {
+ // If the new port wraps around, do not try a privileged port
+ ((startPort + offset - 1024) % (65536 - 1024)) + 1024
+ }
try {
val (service, port) = startService(tryPort)
logInfo(s"Successfully started service$serviceString on port $port.")
@@ -1470,6 +1553,7 @@ private[spark] object Utils extends Logging {
return true
}
isBindCollision(e.getCause)
+ case e: MultiException => e.getThrowables.exists(isBindCollision)
case e: Exception => isBindCollision(e.getCause)
case _ => false
}
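
The retry-port arithmetic above keeps every retried port inside the non-privileged range [1024, 65535] instead of wrapping to 0. A small sketch of just that calculation, with illustrative names:

    // Sketch: the retried port stays in [1024, 65535] even when startPort + offset
    // would wrap past 65535.
    object PortRetrySketch {
      def tryPort(startPort: Int, offset: Int): Int =
        if (startPort == 0) startPort
        else ((startPort + offset - 1024) % (65536 - 1024)) + 1024

      def main(args: Array[String]): Unit = {
        println(tryPort(65535, 0))  // 65535
        println(tryPort(65535, 1))  // wraps to 1024, not 0
        println(tryPort(65535, 5))  // 1028
      }
    }
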
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 4a078435447e5..3190148fb5f43 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -20,6 +20,7 @@
import java.io.*;
import java.net.URI;
import java.util.*;
+import java.util.concurrent.*;
import scala.Tuple2;
import scala.Tuple3;
@@ -29,6 +30,7 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+import com.google.common.base.Throwables;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
@@ -43,10 +45,7 @@
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaDoubleRDD;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.partial.BoundedDouble;
@@ -776,7 +775,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
@@ -1308,6 +1307,92 @@ public void collectUnderlyingScalaRDD() {
Assert.assertEquals(data.size(), collected.length);
}
+ private static final class BuggyMapFunction implements Function {
+
+ @Override
+ public T call(T x) throws Exception {
+ throw new IllegalStateException("Custom exception!");
+ }
+ }
+
+ @Test
+ public void collectAsync() throws Exception {
+ List data = Arrays.asList(1, 2, 3, 4, 5);
+ JavaRDD rdd = sc.parallelize(data, 1);
+ JavaFutureAction> future = rdd.collectAsync();
+ List result = future.get();
+ Assert.assertEquals(data, result);
+ Assert.assertFalse(future.isCancelled());
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(1, future.jobIds().size());
+ }
+
+ @Test
+ public void foreachAsync() throws Exception {
+ List data = Arrays.asList(1, 2, 3, 4, 5);
+ JavaRDD rdd = sc.parallelize(data, 1);
+ JavaFutureAction future = rdd.foreachAsync(
+ new VoidFunction() {
+ @Override
+ public void call(Integer integer) throws Exception {
+ // intentionally left blank.
+ }
+ }
+ );
+ future.get();
+ Assert.assertFalse(future.isCancelled());
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(1, future.jobIds().size());
+ }
+
+ @Test
+ public void countAsync() throws Exception {
+ List data = Arrays.asList(1, 2, 3, 4, 5);
+ JavaRDD rdd = sc.parallelize(data, 1);
+ JavaFutureAction future = rdd.countAsync();
+ long count = future.get();
+ Assert.assertEquals(data.size(), count);
+ Assert.assertFalse(future.isCancelled());
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(1, future.jobIds().size());
+ }
+
+ @Test
+ public void testAsyncActionCancellation() throws Exception {
+ List data = Arrays.asList(1, 2, 3, 4, 5);
+ JavaRDD rdd = sc.parallelize(data, 1);
+ JavaFutureAction future = rdd.foreachAsync(new VoidFunction() {
+ @Override
+ public void call(Integer integer) throws Exception {
+ Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled.
+ }
+ });
+ future.cancel(true);
+ Assert.assertTrue(future.isCancelled());
+ Assert.assertTrue(future.isDone());
+ try {
+ future.get(2000, TimeUnit.MILLISECONDS);
+ Assert.fail("Expected future.get() for cancelled job to throw CancellationException");
+ } catch (CancellationException ignored) {
+ // pass
+ }
+ }
+
+ @Test
+ public void testAsyncActionErrorWrapping() throws Exception {
+ List data = Arrays.asList(1, 2, 3, 4, 5);
+ JavaRDD rdd = sc.parallelize(data, 1);
+ JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync();
+ try {
+ future.get(2, TimeUnit.SECONDS);
+ Assert.fail("Expected future.get() for failed job to throw ExcecutionException");
+ } catch (ExecutionException ee) {
+ Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
+ }
+ Assert.assertTrue(future.isDone());
+ }
+
/**
* Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue,
* since that's the only artifact where Guava classes have been relocated.
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
index 0944bf8cd5c71..e9ec700e32e15 100644
--- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
- context.getStageId();
- context.getPartitionId();
+ context.stageId();
+ context.partitionId();
context.isRunningLocally();
context.addTaskCompletionListener(this);
}
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 26b73a1b39744..9dd05f17f012b 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index d735010d7c9d5..c0735f448d193 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 7e18f45de7b5b..a8867020e457d 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark
import java.io._
import java.util.jar.{JarEntry, JarOutputStream}
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
@@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
override def beforeAll() {
super.beforeAll()
- tmpDir = Files.createTempDir()
- tmpDir.deleteOnExit()
+ tmpDir = Utils.createTempDir()
val testTempDir = new File(tmpDir, "test")
testTempDir.mkdir()
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 4a53d25012ad9..a2b74c4419d46 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -21,7 +21,6 @@ import java.io.{File, FileWriter}
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat}
@@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
override def beforeEach() {
super.beforeEach()
- tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ tempDir = Utils.createTempDir()
}
override def afterEach() {
diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
new file mode 100644
index 0000000000000..db9c25fc457a4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+
+import org.apache.spark.SparkContext._
+
+class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext {
+
+ before {
+ sc = new SparkContext("local", "FutureActionSuite")
+ }
+
+ test("simple async action") {
+ val rdd = sc.parallelize(1 to 10, 2)
+ val job = rdd.countAsync()
+ val res = Await.result(job, Duration.Inf)
+ res should be (10)
+ job.jobIds.size should be (1)
+ }
+
+ test("complex async action") {
+ val rdd = sc.parallelize(1 to 15, 3)
+ val job = rdd.takeAsync(10)
+ val res = Await.result(job, Duration.Inf)
+ res should be (1 to 10)
+ job.jobIds.size should be (2)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 1fef79ad1001f..cbc0bd178d894 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val masterTracker = new MapOutputTrackerMaster(conf)
val actorSystem = ActorSystem("test")
val actorRef = TestActorRef[MapOutputTrackerMasterActor](
- new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+ Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
val masterActor = actorRef.underlyingActor
// Frame size should be ~123B, and no exception should be thrown
@@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val masterTracker = new MapOutputTrackerMaster(conf)
val actorSystem = ActorSystem("test")
val actorRef = TestActorRef[MapOutputTrackerMasterActor](
- new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+ Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem)
val masterActor = actorRef.underlyingActor
// Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
new file mode 100644
index 0000000000000..31edad1c56c73
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.io.BytesWritable
+
+class SparkContextSuite extends FunSuite {
+ // Regression test for SPARK-3121
+ test("BytesWritable implicit conversion is correct") {
+ val bytesWritable = new BytesWritable()
+ val inputArray = (1 to 10).map(_.toByte).toArray
+ bytesWritable.set(inputArray, 0, 10)
+ bytesWritable.set(inputArray, 0, 5)
+
+ val converter = SparkContext.bytesWritableConverter()
+ val byteArray = converter.convert(bytesWritable)
+ assert(byteArray.length === 5)
+
+ bytesWritable.set(inputArray, 0, 0)
+ val byteArray2 = converter.convert(bytesWritable)
+ assert(byteArray2.length === 0)
+ }
+}
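
The behaviour this test appears to guard: BytesWritable reuses and over-allocates its backing array, so getBytes can return more bytes than are currently valid, and a correct converter has to copy only the first getLength bytes. A short sketch of the distinction, assuming only the Hadoop BytesWritable API:

    import java.util.Arrays
    import org.apache.hadoop.io.BytesWritable

    // Sketch: getBytes exposes the (over-allocated) backing array, getLength the valid size.
    object BytesWritableSketch {
      def main(args: Array[String]): Unit = {
        val bw = new BytesWritable()
        bw.set((1 to 10).map(_.toByte).toArray, 0, 10)
        bw.set((1 to 10).map(_.toByte).toArray, 0, 5)   // capacity stays >= 10
        println(bw.getBytes.length)                     // >= 10: capacity, not content
        println(bw.getLength)                           // 5: the valid length
        val valid = Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
        println(valid.length)                           // 5
      }
    }
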
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 978a6ded80829..acaf321de52fb 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -132,7 +132,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === 1)
statuses.head match { case (bm, status) =>
- assert(bm.executorId === "", "Block should only be on the driver")
+ assert(bm.isDriver, "Block should only be on the driver")
assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
assert(status.memSize > 0, "Block should be in memory store on the driver")
assert(status.diskSize === 0, "Block should not be in disk store on the driver")
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 0c324d8bdf6a4..1cdf50d5c08c7 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import java.io.{File, OutputStream, PrintStream}
+import java.io._
import scala.collection.mutable.ArrayBuffer
@@ -306,6 +306,21 @@ class SparkSubmitSuite extends FunSuite with Matchers {
runSparkSubmit(args)
}
+ test("SPARK_CONF_DIR overrides spark-defaults.conf") {
+ forConfDir(Map("spark.executor.memory" -> "2.3g")) { path =>
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val args = Seq(
+ "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
+ "--name", "testApp",
+ "--master", "local",
+ unusedJar.toString)
+ val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path))
+ assert(appArgs.propertiesFile != null)
+ assert(appArgs.propertiesFile.startsWith(path))
+ appArgs.executorMemory should be ("2.3g")
+ }
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
def runSparkSubmit(args: Seq[String]): String = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -314,6 +329,22 @@ class SparkSubmitSuite extends FunSuite with Matchers {
new File(sparkHome),
Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
}
+
+ def forConfDir(defaults: Map[String, String]) (f: String => Unit) = {
+ val tmpDir = Utils.createTempDir()
+
+ val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf")
+ val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf))
+ for ((key, value) <- defaults) writer.write(s"$key $value\n")
+
+ writer.close()
+
+ try {
+ f(tmpDir.getAbsolutePath)
+ } finally {
+ Utils.deleteRecursively(tmpDir)
+ }
+ }
}
object JarCreationTest {
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 39ab53cf0b5b1..5e2592e8d2e8d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -26,14 +26,12 @@ import org.apache.spark.SparkConf
class ExecutorRunnerTest extends FunSuite {
test("command includes appId") {
- def f(s:String) = new File(s)
+ val appId = "12345-worker321-9876"
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
val appDesc = new ApplicationDescription("app name", Some(8), 500,
- Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl")
- val appId = "12345-worker321-9876"
- val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
- f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING)
-
+ Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl")
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321",
+ new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING)
assert(er.getCommandSeq.last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
new file mode 100644
index 0000000000000..1a28a9a187cd7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+
+class WorkerArgumentsTest extends FunSuite {
+
+ test("Memory can't be set to 0 when cmd line args leave off M or G") {
+ val conf = new SparkConf
+ val args = Array("-m", "10000", "spark://localhost:0000 ")
+ intercept[IllegalStateException] {
+ new WorkerArguments(args, conf)
+ }
+ }
+
+
+ test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") {
+ val args = Array("spark://localhost:0000 ")
+
+ class MySparkConf extends SparkConf(false) {
+ override def getenv(name: String) = {
+ if (name == "SPARK_WORKER_MEMORY") "50000"
+ else super.getenv(name)
+ }
+
+ override def clone: SparkConf = {
+ new MySparkConf().setAll(settings)
+ }
+ }
+ val conf = new MySparkConf()
+ intercept[IllegalStateException] {
+ new WorkerArguments(args, conf)
+ }
+ }
+
+ test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") {
+ val args = Array("spark://localhost:0000 ")
+
+ class MySparkConf extends SparkConf(false) {
+ override def getenv(name: String) = {
+ if (name == "SPARK_WORKER_MEMORY") "5G"
+ else super.getenv(name)
+ }
+
+ override def clone: SparkConf = {
+ new MySparkConf().setAll(settings)
+ }
+ }
+ val conf = new MySparkConf()
+ val workerArgs = new WorkerArguments(args, conf)
+ assert(workerArgs.memory === 5120)
+ }
+
+ test("Memory correctly set from args with M appended to memory value") {
+ val conf = new SparkConf
+ val args = Array("-m", "10000M", "spark://localhost:0000 ")
+
+ val workerArgs = new WorkerArguments(args, conf)
+ assert(workerArgs.memory === 10000)
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
index d5ebfb3f3fae1..12d1c7b2faba6 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
@@ -23,8 +23,6 @@ import java.io.FileOutputStream
import scala.collection.immutable.IndexedSeq
-import com.google.common.io.Files
-
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
@@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll {
 * 3) Do the contents match.
*/
test("Correctness of WholeTextFileRecordReader.") {
-
- val dir = Files.createTempDir()
- dir.deleteOnExit()
+ val dir = Utils.createTempDir()
println(s"Local disk address is ${dir.toString}.")
WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index e42b181194727..3925f0ccbdbf0 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -17,14 +17,15 @@
package org.apache.spark.metrics
-import org.apache.spark.metrics.source.Source
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.master.MasterSource
+import org.apache.spark.metrics.source.Source
-import scala.collection.mutable.ArrayBuffer
+import com.codahale.metrics.MetricRegistry
+import scala.collection.mutable.ArrayBuffer
class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{
var filePath: String = _
@@ -39,6 +40,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
test("MetricsSystem with default config") {
val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr)
+ metricsSystem.start()
val sources = PrivateMethod[ArrayBuffer[Source]]('sources)
val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks)
@@ -49,6 +51,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
test("MetricsSystem with sources add") {
val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr)
+ metricsSystem.start()
val sources = PrivateMethod[ArrayBuffer[Source]]('sources)
val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks)
@@ -60,4 +63,125 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
metricsSystem.registerSource(source)
assert(metricsSystem.invokePrivate(sources()).length === 1)
}
+
+ test("MetricsSystem with Driver instance") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "driver"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === s"$appId.$executorId.${source.sourceName}")
+ }
+
+ test("MetricsSystem with Driver instance and spark.app.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val executorId = "driver"
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Driver instance and spark.executor.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ conf.set("spark.app.id", appId)
+
+ val instanceName = "driver"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Executor instance") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "executor.1"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === s"$appId.$executorId.${source.sourceName}")
+ }
+
+ test("MetricsSystem with Executor instance and spark.app.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val executorId = "executor.1"
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with Executor instance and spark.executor.id is not set") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ conf.set("spark.app.id", appId)
+
+ val instanceName = "executor"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+ assert(metricName === source.sourceName)
+ }
+
+ test("MetricsSystem with instance which is neither Driver nor Executor") {
+ val source = new Source {
+ override val sourceName = "dummySource"
+ override val metricRegistry = new MetricRegistry()
+ }
+
+ val appId = "testId"
+ val executorId = "dummyExecutorId"
+ conf.set("spark.app.id", appId)
+ conf.set("spark.executor.id", executorId)
+
+ val instanceName = "testInstance"
+ val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr)
+
+ val metricName = driverMetricsSystem.buildRegistryName(source)
+
+ // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name.
+ assert(metricName != s"$appId.$executorId.${source.sourceName}")
+ assert(metricName === source.sourceName)
+ }
}
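
A small sketch of the naming convention the new MetricsSystem tests above assert. This is only an illustration inferred from the assertions, not the actual MetricsSystem.buildRegistryName implementation, and the helper name is hypothetical:

    import org.apache.spark.SparkConf

    // Hypothetical helper mirroring the convention the tests check: only the
    // "driver" and "executor" instances get an "<appId>.<executorId>." prefix,
    // and only when both spark.app.id and spark.executor.id are set.
    def expectedRegistryName(instance: String, conf: SparkConf, sourceName: String): String = {
      val appId = conf.getOption("spark.app.id")
      val execId = conf.getOption("spark.executor.id")
      if ((instance == "driver" || instance == "executor") && appId.isDefined && execId.isDefined) {
        s"${appId.get}.${execId.get}.$sourceName"  // e.g. "testId.driver.dummySource"
      } else {
        sourceName                                 // fall back to the bare source name
      }
    }
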
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index 9f49587cdc670..b70734dfe37cf 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -27,6 +27,7 @@ import scala.language.postfixOps
import org.scalatest.FunSuite
import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.util.Utils
/**
* Test the ConnectionManager with various security settings.
@@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite {
val manager = new ConnectionManager(0, conf, securityManager)
val managerServer = new ConnectionManager(0, conf, securityManager)
managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- throw new Exception
+ throw new Exception("Custom exception text")
})
val size = 10 * 1024 * 1024
@@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite {
val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
- intercept[IOException] {
+ val exception = intercept[IOException] {
Await.result(future, 1 second)
}
+ assert(Utils.exceptionString(exception).contains("Custom exception text"))
manager.stop()
managerServer.stop()
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 75b01191901b8..3620e251cc139 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
-import com.google.common.io.Files
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
TaskAttemptContext => NewTaskAttempContext}
import org.apache.spark.{Partitioner, SharedSparkContext}
import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
+
import org.scalatest.FunSuite
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
@@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
test("zero-partition RDD") {
- val emptyDir = Files.createTempDir()
- emptyDir.deleteOnExit()
- val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.partitions.size == 0)
- assert(file.collect().toList === Nil)
- // Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
- emptyDir.delete()
+ val emptyDir = Utils.createTempDir()
+ try {
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.isEmpty)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ } finally {
+ Utils.deleteRecursively(emptyDir)
+ }
}
test("keys and values") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index be972c5e97a7e..271a90c6646bb 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContext(0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index e5315bc93e217..abc300fcffaf9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.scheduler
import scala.collection.mutable
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
private var logDirPath: Path = _
before {
- testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ testDir = Utils.createTempDir()
logDirPath = Utils.getFilePath(testDir, "spark-events")
}
@@ -169,7 +167,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
// Verify logging directory exists
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
eventLogger.start()
val logPath = new Path(eventLogger.logDir)
assert(fileSystem.exists(logPath))
@@ -209,7 +209,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
// Verify that all information is correctly parsed before stop()
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
eventLogger.start()
var eventLoggingInfo = EventLoggingListener.parseLoggingInfo(eventLogger.logDir, fileSystem)
assertInfoCorrect(eventLoggingInfo, loggerStopped = false)
@@ -228,7 +230,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
*/
private def testEventLogging(compressionCodec: Option[String] = None) {
val conf = getLoggingConf(logDirPath, compressionCodec)
- val eventLogger = new EventLoggingListener("test", conf)
+ val logBaseDir = conf.get("spark.eventLog.dir")
+ val appId = EventLoggingListenerSuite.getUniqueApplicationId
+ val eventLogger = new EventLoggingListener(appId, logBaseDir, conf)
val listenerBus = new LiveListenerBus
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
125L, "Mickey")
@@ -408,4 +412,6 @@ object EventLoggingListenerSuite {
}
conf
}
+
+ def getUniqueApplicationId = "test-" + System.currentTimeMillis
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 7ab351d1b4d24..e05f373392d4a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
import java.io.{File, PrintWriter}
-import com.google.common.io.Files
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
private var testDir: File = _
before {
- testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ testDir = Utils.createTempDir()
}
after {
@@ -155,7 +153,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
* This child listener inherits only the event buffering functionality, but does not actually
* log the events.
*/
- private class EventMonster(conf: SparkConf) extends EventLoggingListener("test", conf) {
+ private class EventMonster(conf: SparkConf)
+ extends EventLoggingListener("test", "testdir", conf) {
logger.close()
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index faba5508c906c..561a5e9cd90c4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 93e8ddacf8865..c0b07649eb6dd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.resourceOffer("execC", "host3", ANY) !== None)
}
+ test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") {
+ // Regression test for SPARK-2931
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc,
+ ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+ val taskSet = FakeTask.createTaskSet(3,
+ Seq(HostTaskLocation("host1")),
+ Seq(HostTaskLocation("host2")),
+ Seq(HDFSCacheTaskLocation("host3")))
+ val clock = new FakeClock
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execA")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execB")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execC")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+ }
def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
new file mode 100644
index 0000000000000..1f1d53a1ee3b0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
+import akka.actor.{ActorSystem, Props}
+import org.mockito.Mockito.{mock, when}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.storage.StorageLevel._
+import org.apache.spark.util.{AkkaUtils, SizeEstimator}
+
+/** Test suite for block replication in BlockManager. */
+class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter {
+
+ private val conf = new SparkConf(false)
+ var actorSystem: ActorSystem = null
+ var master: BlockManagerMaster = null
+ val securityMgr = new SecurityManager(conf)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf)
+ val shuffleManager = new HashShuffleManager(conf)
+
+ // List of block managers created during a unit test, so that all of them can be stopped
+ // after the unit test.
+ val allStores = new ArrayBuffer[BlockManager]
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ conf.set("spark.kryoserializer.buffer.mb", "1")
+ val serializer = new KryoSerializer(conf)
+
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+
+ private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = {
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer)
+ allStores += store
+ store
+ }
+
+ before {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
+ "test", "localhost", 0, conf = conf, securityManager = securityMgr)
+ this.actorSystem = actorSystem
+
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.storage.unrollFraction", "0.4")
+ conf.set("spark.storage.unrollMemoryThreshold", "512")
+
+ // to make a replication attempt to an inactive store fail fast
+ conf.set("spark.core.connection.ack.wait.timeout", "1")
+ // to make cached peers refresh frequently
+ conf.set("spark.storage.cachedPeersTtl", "10")
+
+ master = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
+ conf, true)
+ allStores.clear()
+ }
+
+ after {
+ allStores.foreach { _.stop() }
+ allStores.clear()
+ actorSystem.shutdown()
+ actorSystem.awaitTermination()
+ actorSystem = null
+ master = null
+ }
+
+
+ test("get peers with addition and removal of block managers") {
+ val numStores = 4
+ val stores = (1 to numStores - 1).map { i => makeBlockManager(1000, s"store$i") }
+ val storeIds = stores.map { _.blockManagerId }.toSet
+ assert(master.getPeers(stores(0).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(0).blockManagerId })
+ assert(master.getPeers(stores(1).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(1).blockManagerId })
+ assert(master.getPeers(stores(2).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(2).blockManagerId })
+
+ // Add driver store and test whether it is filtered out
+ val driverStore = makeBlockManager(1000, "")
+ assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver))
+ assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver))
+ assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver))
+
+ // Add a new store and test whether get peers returns it
+ val newStore = makeBlockManager(1000, s"store$numStores")
+ assert(master.getPeers(stores(0).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(0).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(stores(1).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(1).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(stores(2).blockManagerId).toSet ===
+ storeIds.filterNot { _ == stores(2).blockManagerId } + newStore.blockManagerId)
+ assert(master.getPeers(newStore.blockManagerId).toSet === storeIds)
+
+ // Remove a store and test that getPeers no longer returns it
+ val storeIdToRemove = stores(0).blockManagerId
+ master.removeExecutor(storeIdToRemove.executorId)
+ assert(!master.getPeers(stores(1).blockManagerId).contains(storeIdToRemove))
+ assert(!master.getPeers(stores(2).blockManagerId).contains(storeIdToRemove))
+ assert(!master.getPeers(newStore.blockManagerId).contains(storeIdToRemove))
+
+ // Test whether asking for peers of an unregistered block manager id returns an empty list
+ assert(master.getPeers(stores(0).blockManagerId).isEmpty)
+ assert(master.getPeers(BlockManagerId("", "", 1)).isEmpty)
+ }
+
+
+ test("block replication - 2x replication") {
+ testReplication(2,
+ Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER_2)
+ )
+ }
+
+ test("block replication - 3x replication") {
+ // Generate storage levels with 3x replication
+ val storageLevels = {
+ Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK, MEMORY_AND_DISK_SER).map {
+ level => StorageLevel(
+ level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 3)
+ }
+ }
+ testReplication(3, storageLevels)
+ }
+
+ test("block replication - mixed between 1x to 5x") {
+ // Generate storage levels with varying replication
+ val storageLevels = Seq(
+ MEMORY_ONLY,
+ MEMORY_ONLY_SER_2,
+ StorageLevel(true, false, false, false, 3),
+ StorageLevel(true, true, false, true, 4),
+ StorageLevel(true, true, false, false, 5),
+ StorageLevel(true, true, false, true, 4),
+ StorageLevel(true, false, false, false, 3),
+ MEMORY_ONLY_SER_2,
+ MEMORY_ONLY
+ )
+ testReplication(5, storageLevels)
+ }
+
+ test("block replication - 2x replication without peers") {
+ intercept[org.scalatest.exceptions.TestFailedException] {
+ testReplication(1,
+ Seq(StorageLevel.MEMORY_AND_DISK_2, StorageLevel(true, false, false, false, 3)))
+ }
+ }
+
+ test("block replication - deterministic node selection") {
+ val blockSize = 1000
+ val storeSize = 10000
+ val stores = (1 to 5).map {
+ i => makeBlockManager(storeSize, s"store$i")
+ }
+ val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2
+ val storageLevel3x = StorageLevel(true, true, false, true, 3)
+ val storageLevel4x = StorageLevel(true, true, false, true, 4)
+
+ def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = {
+ stores.head.putSingle(blockId, new Array[Byte](blockSize), level)
+ val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet
+ stores.foreach { _.removeBlock(blockId) }
+ master.removeBlock(blockId)
+ locations
+ }
+
+ // Test whether two attempts at 2x replication return the same set of locations
+ val a1Locs = putBlockAndGetLocations("a1", storageLevel2x)
+ assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs,
+ "Inserting a 2x replicated block second time gave different locations from the first")
+
+ // Test whether two attempts at 3x replication return the same set of locations
+ val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x)
+ assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x,
+ "Inserting a 3x replicated block second time gave different locations from the first")
+
+ // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication
+ val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x)
+ assert(
+ a2Locs2x.subsetOf(a2Locs3x),
+ "Inserting a with 2x replication gave locations that are not a subset of locations" +
+ s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}"
+ )
+
+ // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication
+ val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x)
+ assert(
+ a2Locs3x.subsetOf(a2Locs4x),
+ "Inserting a with 4x replication gave locations that are not a superset of locations " +
+ s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}"
+ )
+
+ // Test if 3x replication of two different blocks gives two different sets of locations
+ val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x)
+ assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication")
+ }
+
+ test("block replication - replication failures") {
+ /*
+ Create a system of three block managers / stores. One of them (say, failableStore)
+ cannot receive blocks, so attempts to use it as a replication target fail.
+
+ +-----------/fails/-----------> failableStore
+ |
+ normalStore
+ |
+ +-----------/works/-----------> anotherNormalStore
+
+ We are first going to add a normal block manager (i.e. normalStore) and the failable block
+ manager (i.e. failableStore), and test whether 2x replication fails to create two
+ copies of a block. Then we are going to add another normal block manager
+ (i.e., anotherNormalStore), and test that 2x replication now works because the
+ new store will be used for replication.
+ */
+
+ // Add a normal block manager
+ val store = makeBlockManager(10000, "store")
+
+ // Insert a block with 2x replication and return the number of copies of the block
+ def replicateAndGetNumCopies(blockId: String): Int = {
+ store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2)
+ val numLocations = master.getLocations(blockId).size
+ allStores.foreach { _.removeBlock(blockId) }
+ numLocations
+ }
+
+ // Add a failable block manager with a mock transfer service that does not
+ // allow receiving of blocks. So attempts to use it as a replication target will fail.
+ val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
+ when(failableTransfer.hostName).thenReturn("some-hostname")
+ when(failableTransfer.port).thenReturn(1000)
+ val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
+ 10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+ allStores += failableStore // so that this gets stopped after test
+ assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
+
+ // Test that 2x replication fails, leaving only one copy of the block
+ assert(replicateAndGetNumCopies("a1") === 1)
+
+ // Add another normal block manager and test that 2x replication works
+ makeBlockManager(10000, "anotherStore")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a2") === 2)
+ }
+ }
+
+ test("block replication - addition and deletion of block managers") {
+ val blockSize = 1000
+ val storeSize = 10000
+ val initialStores = (1 to 2).map { i => makeBlockManager(storeSize, s"store$i") }
+
+ // Insert a block with the given replication factor and return the number of copies of the block
+ def replicateAndGetNumCopies(blockId: String, replicationFactor: Int): Int = {
+ val storageLevel = StorageLevel(true, true, false, true, replicationFactor)
+ initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel)
+ val numLocations = master.getLocations(blockId).size
+ allStores.foreach { _.removeBlock(blockId) }
+ numLocations
+ }
+
+ // 2x replication should work, 3x replication should only replicate 2x
+ assert(replicateAndGetNumCopies("a1", 2) === 2)
+ assert(replicateAndGetNumCopies("a2", 3) === 2)
+
+ // Add another store, 3x replication should work now, 4x replication should only replicate 3x
+ val newStore1 = makeBlockManager(storeSize, s"newstore1")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a3", 3) === 3)
+ }
+ assert(replicateAndGetNumCopies("a4", 4) === 3)
+
+ // Add another store, 4x replication should work now
+ val newStore2 = makeBlockManager(storeSize, s"newstore2")
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a5", 4) === 4)
+ }
+
+ // Remove all but the 1st store, 2x replication should fail
+ (initialStores.tail ++ Seq(newStore1, newStore2)).foreach {
+ store =>
+ master.removeExecutor(store.blockManagerId.executorId)
+ store.stop()
+ }
+ assert(replicateAndGetNumCopies("a6", 2) === 1)
+
+ // Add new stores, 3x replication should work
+ val newStores = (3 to 5).map {
+ i => makeBlockManager(storeSize, s"newstore$i")
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ assert(replicateAndGetNumCopies("a7", 3) === 3)
+ }
+ }
+
+ /**
+ * Test replication of blocks with different storage levels (various combinations of
+ * memory, disk & serialization). For each storage level, this function checks every store
+ * for the presence of the block and verifies that the master's knowledge of the block
+ * is correct. It then drops the block from the memory of each store (via LRU eviction) and
+ * checks again that the master's knowledge is updated accordingly.
+ */
+ private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) {
+ import org.apache.spark.storage.StorageLevel._
+
+ assert(maxReplication > 1,
+ s"Cannot test replication factor $maxReplication")
+
+ // storage levels to test with the given replication factor
+
+ val storeSize = 10000
+ val blockSize = 1000
+
+ // As many stores as the replication factor
+ val stores = (1 to maxReplication).map {
+ i => makeBlockManager(storeSize, s"store$i")
+ }
+
+ storageLevels.foreach { storageLevel =>
+ // Put the block into one of the stores
+ val blockId = new TestBlockId(
+ "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase)
+ stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel)
+
+ // Assert that the master knows the expected number of locations for the block
+ val blockLocations = master.getLocations(blockId).map(_.executorId).toSet
+ assert(blockLocations.size === storageLevel.replication,
+ s"master did not have ${storageLevel.replication} locations for $blockId")
+
+ // Test state of the stores that contain the block
+ stores.filter {
+ testStore => blockLocations.contains(testStore.blockManagerId.executorId)
+ }.foreach { testStore =>
+ val testStoreName = testStore.blockManagerId.executorId
+ assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName")
+ assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName),
+ s"master does not have status for ${blockId.name} in $testStoreName")
+
+ val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId)
+
+ // Assert that the block status in the master for this store has the expected storage level
+ assert(
+ blockStatus.storageLevel.useDisk === storageLevel.useDisk &&
+ blockStatus.storageLevel.useMemory === storageLevel.useMemory &&
+ blockStatus.storageLevel.useOffHeap === storageLevel.useOffHeap &&
+ blockStatus.storageLevel.deserialized === storageLevel.deserialized,
+ s"master does not know correct storage level for ${blockId.name} in $testStoreName")
+
+ // Assert that the block status in the master for this store has correct memory usage info
+ assert(!blockStatus.storageLevel.useMemory || blockStatus.memSize >= blockSize,
+ s"master does not know size of ${blockId.name} stored in memory of $testStoreName")
+
+
+ // If the block is supposed to be in memory, then drop the copy of the block in
+ // this store and test whether the master is updated with zero memory usage for this store
+ if (storageLevel.useMemory) {
+ // Force the block to be dropped by adding a number of dummy blocks
+ (1 to 10).foreach {
+ i =>
+ testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER)
+ }
+ (1 to 10).foreach {
+ i => testStore.removeBlock(s"dummy-block-$i")
+ }
+
+ val newBlockStatusOption = master.getBlockStatus(blockId).get(testStore.blockManagerId)
+
+ // Assert that the block status in the master either does not exist (block removed
+ // from every store) or has zero memory usage for this store
+ assert(
+ newBlockStatusOption.isEmpty || newBlockStatusOption.get.memSize === 0,
+ s"after dropping, master does not know size of ${blockId.name} " +
+ s"stored in memory of $testStoreName"
+ )
+ }
+
+ // If the block is supposed to be on disk (after dropping or otherwise), then
+ // test whether the master has the correct disk usage for this store
+ if (storageLevel.useDisk) {
+ assert(master.getBlockStatus(blockId)(testStore.blockManagerId).diskSize >= blockSize,
+ s"after dropping, master does not know size of ${blockId.name} " +
+ s"stored in disk of $testStoreName"
+ )
+ }
+ }
+ master.removeBlock(blockId)
+ }
+ }
+}
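
A minimal sketch of the put-and-verify pattern that testReplication and the replicateAndGetNumCopies helpers above build on, assuming a BlockManager and BlockManagerMaster wired up as in the suite's before block; the helper name is illustrative only:

    import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, StorageLevel}

    // Store a small block at the requested storage level, then count the distinct
    // executors that the master reports as holding a copy of it.
    def replicationCount(store: BlockManager, master: BlockManagerMaster,
        blockId: BlockId, level: StorageLevel): Int = {
      store.putSingle(blockId, new Array[Byte](1000), level)
      master.getLocations(blockId).map(_.executorId).toSet.size
    }
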
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index e251660dae5de..9d96202a3e7ac 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -21,8 +21,6 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
import java.util.concurrent.TimeUnit
-import org.apache.spark.network.nio.NioBlockTransferService
-
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration._
@@ -35,13 +33,13 @@ import akka.util.Timeout
import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.Matchers
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -189,7 +187,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
store = makeBlockManager(2000, "exec1")
store2 = makeBlockManager(2000, "exec2")
- val peers = master.getPeers(store.blockManagerId, 1)
+ val peers = master.getPeers(store.blockManagerId)
assert(peers.size === 1, "master did not return the other manager as a peer")
assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
@@ -448,7 +446,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
val list2DiskGet = store.get("list2disk")
assert(list2DiskGet.isDefined, "list2memory expected to be in store")
assert(list2DiskGet.get.data.size === 3)
- System.out.println(list2DiskGet)
// We don't know the exact size of the data on disk, but it should certainly be > 0.
assert(list2DiskGet.get.inputMetrics.bytesRead > 0)
assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk)
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index e4522e00a622d..bc5c74c126b74 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,22 +19,13 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
-import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.shuffle.hash.HashShuffleManager
-
-import scala.collection.mutable
import scala.language.reflectiveCalls
-import akka.actor.Props
-import com.google.common.io.Files
import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{AkkaUtils, Utils}
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.util.Utils
class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
private val testConf = new SparkConf(false)
@@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
override def beforeAll() {
super.beforeAll()
- rootDir0 = Files.createTempDir()
- rootDir0.deleteOnExit()
- rootDir1 = Files.createTempDir()
- rootDir1.deleteOnExit()
+ rootDir0 = Utils.createTempDir()
+ rootDir1 = Utils.createTempDir()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 809bd70929656..a8c049d749015 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
import org.mockito.Mockito._
@@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
(bmId, Seq((blId1, 1L), (blId2, 1L))))
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
index c3dd156b40514..72466a3aa1130 100644
--- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
@@ -21,7 +21,6 @@ import java.io.{File, IOException}
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter {
private var logDirPathString: String = _
before {
- testDir = Files.createTempDir()
+ testDir = Utils.createTempDir()
logDirPath = Utils.getFilePath(testDir, "test-file-logger")
logDirPathString = logDirPath.toString
}
@@ -75,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter {
test("Logging when directory already exists") {
// Create the logging directory multiple times
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
// If overwrite is not enabled, an exception should be thrown
intercept[IOException] {
- new FileLogger(logDirPathString, new SparkConf, overwrite = false).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start()
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 70d423ba8a04d..ea7ef0524d1e1 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -27,6 +27,8 @@ import com.google.common.base.Charsets
import com.google.common.io.Files
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+
class UtilsSuite extends FunSuite {
test("bytesToString") {
@@ -112,7 +114,7 @@ class UtilsSuite extends FunSuite {
}
test("reading offset bytes of a file") {
- val tmpDir2 = Files.createTempDir()
+ val tmpDir2 = Utils.createTempDir()
tmpDir2.deleteOnExit()
val f1Path = tmpDir2 + "/f1"
val f1 = new FileOutputStream(f1Path)
@@ -141,7 +143,7 @@ class UtilsSuite extends FunSuite {
}
test("reading offset bytes across multiple files") {
- val tmpDir = Files.createTempDir()
+ val tmpDir = Utils.createTempDir()
tmpDir.deleteOnExit()
val files = (1 to 3).map(i => new File(tmpDir, i.toString))
Files.write("0123456789", files(0), Charsets.UTF_8)
@@ -189,17 +191,28 @@ class UtilsSuite extends FunSuite {
assert(Utils.getIteratorSize(iterator) === 5L)
}
- test("findOldFiles") {
+ test("doesDirectoryContainFilesNewerThan") {
// create some temporary directories and files
val parent: File = Utils.createTempDir()
val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories
val child2: File = Utils.createTempDir(parent.getCanonicalPath)
- // set the last modified time of child1 to 10 secs old
- child1.setLastModified(System.currentTimeMillis() - (1000 * 10))
+ val child3: File = Utils.createTempDir(child1.getCanonicalPath)
+ // set the last modified time of child1 to 30 secs old
+ child1.setLastModified(System.currentTimeMillis() - (1000 * 30))
+
+ // although child1 is old, child2 is still new, so this should return true
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
+
+ child2.setLastModified(System.currentTimeMillis - (1000 * 30))
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
- val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs
- assert(result.size.equals(1))
- assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath))
+ parent.setLastModified(System.currentTimeMillis - (1000 * 30))
+ // although the parent and its immediate children are now old, child3 is still new,
+ // so a full recursive search is needed to find the remaining new file.
+ assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5))
+
+ child3.setLastModified(System.currentTimeMillis - (1000 * 30))
+ assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5))
}
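
For reference, a hypothetical sketch of the recursive check these assertions exercise; the real Utils.doesDirectoryContainAnyNewFiles may be implemented differently:

    import java.io.File

    // Returns true if the directory itself, or anything nested inside it, was
    // modified more recently than cutoffSeconds ago.
    def containsAnyNewFiles(dir: File, cutoffSeconds: Long): Boolean = {
      val cutoffTime = System.currentTimeMillis - cutoffSeconds * 1000
      def recurse(f: File): Boolean =
        f.lastModified > cutoffTime ||
          (f.isDirectory && Option(f.listFiles()).getOrElse(Array.empty[File]).exists(recurse))
      recurse(dir)
    }
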
test("resolveURI") {
@@ -297,4 +310,45 @@ class UtilsSuite extends FunSuite {
}
}
+ test("deleteRecursively") {
+ val tempDir1 = Utils.createTempDir()
+ assert(tempDir1.exists())
+ Utils.deleteRecursively(tempDir1)
+ assert(!tempDir1.exists())
+
+ val tempDir2 = Utils.createTempDir()
+ val tempFile1 = new File(tempDir2, "foo.txt")
+ Files.touch(tempFile1)
+ assert(tempFile1.exists())
+ Utils.deleteRecursively(tempFile1)
+ assert(!tempFile1.exists())
+
+ val tempDir3 = new File(tempDir2, "subdir")
+ assert(tempDir3.mkdir())
+ val tempFile2 = new File(tempDir3, "bar.txt")
+ Files.touch(tempFile2)
+ assert(tempFile2.exists())
+ Utils.deleteRecursively(tempDir2)
+ assert(!tempDir2.exists())
+ assert(!tempDir3.exists())
+ assert(!tempFile2.exists())
+ }
+
+ test("loading properties from file") {
+ val outFile = File.createTempFile("test-load-spark-properties", "test")
+ try {
+ System.setProperty("spark.test.fileNameLoadB", "2")
+ Files.write("spark.test.fileNameLoadA true\n" +
+ "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8)
+ val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath)
+ properties
+ .filter { case (k, v) => k.startsWith("spark.")}
+ .foreach { case (k, v) => sys.props.getOrElseUpdate(k, v)}
+ val sparkConf = new SparkConf
+ assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true)
+ assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2)
+ } finally {
+ outFile.delete()
+ }
+ }
}
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index a8e92e36fe0d8..02ac20984add9 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -73,11 +73,10 @@ def fail(msg):
def run_cmd(cmd):
+ print cmd
if isinstance(cmd, list):
- print " ".join(cmd)
return subprocess.check_output(cmd)
else:
- print cmd
return subprocess.check_output(cmd.split(" "))
diff --git a/dev/run-tests b/dev/run-tests
index c3d8f49cdd993..f47fcf66ff7e7 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -24,6 +24,16 @@ cd "$FWDIR"
# Remove work directory
rm -rf ./work
+source "$FWDIR/dev/run-tests-codes.sh"
+
+CURRENT_BLOCK=$BLOCK_GENERAL
+
+function handle_error () {
+ echo "[error] Got a return code of $? on line $1 of the run-tests script."
+ exit $CURRENT_BLOCK
+}
+
+
 # Build against the right version of Hadoop.
{
if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
@@ -32,7 +42,7 @@ rm -rf ./work
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then
export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1"
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0"
+ export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0"
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then
export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
fi
@@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then
fi
fi
-# Fail fast
-set -e
set -o pipefail
+trap 'handle_error $LINENO' ERR
echo ""
echo "========================================================================="
echo "Running Apache RAT checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_RAT
+
./dev/check-license
echo ""
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_SCALA_STYLE
+
./dev/lint-scala
echo ""
echo "========================================================================="
echo "Running Python style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYTHON_STYLE
+
./dev/lint-python
echo ""
@@ -118,6 +136,8 @@ echo "========================================================================="
echo "Building Spark"
echo "========================================================================="
+CURRENT_BLOCK=$BLOCK_BUILD
+
{
# We always build with Hive because the PySpark Spark SQL tests need it.
BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
@@ -141,6 +161,8 @@ echo "========================================================================="
echo "Running Spark unit tests"
echo "========================================================================="
+CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
+
{
# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
# This must be a single argument, as it is.
@@ -175,10 +197,16 @@ echo ""
echo "========================================================================="
echo "Running PySpark tests"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
+
./python/run-tests
echo ""
echo "========================================================================="
echo "Detecting binary incompatibilites with MiMa"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_MIMA
+
./dev/mima
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
new file mode 100644
index 0000000000000..1348e0609dda4
--- /dev/null
+++ b/dev/run-tests-codes.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+readonly BLOCK_GENERAL=10
+readonly BLOCK_RAT=11
+readonly BLOCK_SCALA_STYLE=12
+readonly BLOCK_PYTHON_STYLE=13
+readonly BLOCK_BUILD=14
+readonly BLOCK_SPARK_UNIT_TESTS=15
+readonly BLOCK_PYSPARK_UNIT_TESTS=16
+readonly BLOCK_MIMA=17
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 0b1e31b9413cf..451f3b771cc76 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -26,9 +26,23 @@
FWDIR="$(cd `dirname $0`/..; pwd)"
cd "$FWDIR"
+source "$FWDIR/dev/run-tests-codes.sh"
+
COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments"
PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId"
+# Important Environment Variables
+# ---
+# $ghprbActualCommit
+#+ This is the hash of the most recent commit in the PR.
+#+ The merge-base of this and master is the commit from which the PR was branched.
+# $sha1
+#+ If the patch merges cleanly, this is a reference to the merge commit hash
+#+ (e.g. "origin/pr/2606/merge").
+#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit.
+#+  In the case of a clean merge, the merge-base of this and master is the most recent
+#+  commit on master.
+
COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
@@ -84,42 +98,46 @@ function post_message () {
fi
}
+
+# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
+#+ and not anything else added to master since the PR was branched.
+
# check PR merge-ability and check for new public classes
{
if [ "$sha1" == "$ghprbActualCommit" ]; then
- merge_note=" * This patch **does not** merge cleanly!"
+ merge_note=" * This patch **does not merge cleanly**."
else
merge_note=" * This patch merges cleanly."
+ fi
+
+ source_files=$(
+ git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
+ | grep -v -e "\/test" `# ignore files in test directories` \
+ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
+ | tr "\n" " "
+ )
+ new_public_classes=$(
+ git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
+ | grep "^\+" `# filter in only added lines` \
+ | sed -r -e "s/^\+//g" `# remove the leading +` \
+ | grep -e "trait " -e "class " `# filter in lines with these key words` \
+ | grep -e "{" -e "(" `# filter in lines with these key words, too` \
+ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
+ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
+ | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
+ | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
+ | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
+ | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
+ | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
+ | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
+ | tr -d "\n" `# remove actual LF characters`
+ )
- source_files=$(
- git diff master... --name-only `# diff patch against master from branch point` \
- | grep -v -e "\/test" `# ignore files in test directories` \
- | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
- | tr "\n" " "
- )
- new_public_classes=$(
- git diff master... ${source_files} `# diff patch against master from branch point` \
- | grep "^\+" `# filter in only added lines` \
- | sed -r -e "s/^\+//g" `# remove the leading +` \
- | grep -e "trait " -e "class " `# filter in lines with these key words` \
- | grep -e "{" -e "(" `# filter in lines with these key words, too` \
- | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
- | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
- | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
- | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
- | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
- | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
- | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
- | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
- | tr -d "\n" `# remove actual LF characters`
- )
-
- if [ "$new_public_classes" == "" ]; then
- public_classes_note=" * This patch adds no public classes."
- else
- public_classes_note=" * This patch adds the following public classes _(experimental)_:"
- public_classes_note="${public_classes_note}\n${new_public_classes}"
- fi
+ if [ -z "$new_public_classes" ]; then
+ public_classes_note=" * This patch adds no public classes."
+ else
+ public_classes_note=" * This patch adds the following public classes _(experimental)_:"
+ public_classes_note="${public_classes_note}\n${new_public_classes}"
fi
}
@@ -147,12 +165,30 @@ function post_message () {
post_message "$fail_message"
exit $test_result
+ elif [ "$test_result" -eq "0" ]; then
+ test_result_note=" * This patch **passes all tests**."
else
- if [ "$test_result" -eq "0" ]; then
- test_result_note=" * This patch **passes** unit tests."
+ if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then
+ failing_test="some tests"
+ elif [ "$test_result" -eq "$BLOCK_RAT" ]; then
+ failing_test="RAT tests"
+ elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then
+ failing_test="Scala style tests"
+ elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then
+ failing_test="Python style tests"
+ elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then
+ failing_test="to build"
+ elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then
+ failing_test="Spark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
+ failing_test="PySpark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
+ failing_test="MiMa tests"
else
- test_result_note=" * This patch **fails** unit tests."
+ failing_test="some tests"
fi
+
+ test_result_note=" * This patch **fails $failing_test**."
fi
}
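
The shell pipeline above can be hard to parse; as a rough illustration of the same `master...$ghprbActualCommit` idea (a hypothetical helper, not part of dev/run-tests-jenkins), a Python sketch might look like:

{% highlight python %}
# Hypothetical sketch of the "new public classes" check; assumes git is on PATH
# and that the PR ref (e.g. "origin/pr/2606/merge") has been fetched locally.
import re
import subprocess

def new_public_classes(pr_commit, base="master"):
    # Three-dot diff: only the changes introduced on the PR side since it
    # branched from master, ignoring anything added to master afterwards.
    diff = subprocess.check_output(["git", "diff", "%s...%s" % (base, pr_commit)])
    classes = []
    for line in diff.splitlines():
        # keep added lines only, and skip the "+++ b/..." file headers
        if not line.startswith("+") or line.startswith("+++"):
            continue
        code = line[1:].strip()
        if ("class " in code or "trait " in code) and "private" not in code:
            match = re.match(r"(.+?)[({]", code)
            if match:
                classes.append(match.group(1).strip())
    return classes

if __name__ == "__main__":
    for cls in new_public_classes("origin/pr/2606/merge"):
        print " * `%s`" % cls
{% endhighlight %}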
diff --git a/dev/scalastyle b/dev/scalastyle
index efb5f291ea3b7..c3b356bcb3c06 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -26,6 +26,8 @@ echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalasty
>> scalastyle.txt
ERRORS=$(cat scalastyle.txt | grep -e "\")
+rm scalastyle.txt
+
if test ! -z "$ERRORS"; then
echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS"
exit 1
diff --git a/docs/README.md b/docs/README.md
index 79708c3df9106..d2d58e435d4c4 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output
varies between versions of Jekyll and its dependencies, we list specific versions here
in some cases:
- $ sudo gem install jekyll -v 1.4.3
- $ sudo gem uninstall kramdown -v 1.4.1
+ $ sudo gem install jekyll
$ sudo gem install jekyll-redirect-from
Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory
@@ -54,19 +53,19 @@ phase, use the following syntax:
// supported languages too.
{% endhighlight %}
-## API Docs (Scaladoc and Epydoc)
+## API Docs (Scaladoc and Sphinx)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
-Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the
-SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as
+Similarly, you can build just the PySpark docs by running `make html` from the
+SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
public in `__init__.py`.
When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it
may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
-PySpark docs using [epydoc](http://epydoc.sourceforge.net/).
+PySpark docs using [Sphinx](http://sphinx-doc.org/).
NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1
jekyll`.
diff --git a/docs/_config.yml b/docs/_config.yml
index 7bc3a78e2d265..f4bf242ac191b 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -8,6 +8,9 @@ gems:
kramdown:
entity_output: numeric
+include:
+ - _static
+
# These allow the documentation to be updated with new releases
# of Spark, Scala, and Mesos.
SPARK_VERSION: 1.0.0-SNAPSHOT
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 3b02e090aec28..4566a2fff562b 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -63,19 +63,20 @@
puts "cp -r " + source + "/. " + dest
cp_r(source + "/.", dest)
- # Build Epydoc for Python
- puts "Moving to python directory and building epydoc."
- cd("../python")
- puts `epydoc --config epydoc.conf`
+ # Build Sphinx docs for Python
- puts "Moving back into docs dir."
- cd("../docs")
+ puts "Moving to python/docs directory and building sphinx."
+ cd("../python/docs")
+ puts `make html`
+
+ puts "Moving back into home dir."
+ cd("../../")
puts "Making directory api/python"
- mkdir_p "api/python"
+ mkdir_p "docs/api/python"
- puts "cp -r ../python/docs/. api/python"
- cp_r("../python/docs/.", "api/python")
+ puts "cp -r python/docs/_build/html/. docs/api/python"
+ cp_r("python/docs/_build/html/.", "docs/api/python")
cd("..")
end
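
For readers who do not follow Ruby, the plugin change above boils down to "run Sphinx and copy the output"; a minimal Python sketch of the same steps (illustrative only, the real build stays in the Jekyll plugin) is:

{% highlight python %}
# Sketch of what the updated copy_api_dirs.rb does for the Python API docs.
# Assumes it is run from the docs/ directory of a Spark checkout with Sphinx installed.
import os
import shutil
import subprocess

spark_home = os.path.abspath("..")
python_docs = os.path.join(spark_home, "python", "docs")

# Equivalent of `cd python/docs && make html`
subprocess.check_call(["make", "html"], cwd=python_docs)

# Equivalent of `mkdir_p "docs/api/python"` followed by
# `cp -r python/docs/_build/html/. docs/api/python`
src = os.path.join(python_docs, "_build", "html")
dest = os.path.join(spark_home, "docs", "api", "python")
if os.path.exists(dest):
    shutil.rmtree(dest)
shutil.copytree(src, dest)
{% endhighlight %}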
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 2378092d4a1a8..b2940ee4029e8 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -169,7 +169,22 @@ compilation. More advanced developers may wish to use SBT.
The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables
can be set to control the SBT build. For example:
- sbt/sbt -Pyarn -Phadoop-2.3 compile
+ sbt/sbt -Pyarn -Phadoop-2.3 assembly
+
+# Testing with SBT
+
+Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive test
+
+To run only a specific test suite, use the following:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite"
+
+To run the test suites of a specific sub-project, use the following:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test
# Speeding up Compilation with Zinc
diff --git a/docs/configuration.md b/docs/configuration.md
index a6dd7245e1552..96fa1377ec399 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -103,6 +103,14 @@ of the most common options to set are:
(e.g. 512m, 2g).
+
+
spark.driver.memory
+
512m
+
+ Amount of memory to use for the driver process, i.e. where SparkContext is initialized.
+ (e.g. 512m, 2g).
+
+
spark.serializer
org.apache.spark.serializer. JavaSerializer
@@ -153,14 +161,6 @@ Apart from these, the following properties are also available, and may be useful
#### Runtime Environment
Property Name
Default
Meaning
-
-
spark.executor.memory
-
512m
-
- Amount of memory to use per executor process, in the same format as JVM memory strings
- (e.g. 512m, 2g).
-
-
spark.executor.extraJavaOptions
(none)
@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
+
+
spark.python.profile
+
false
+
+ Enable profiling in the Python worker. The profile results can be shown with `sc.show_profiles()`,
+ or they are displayed automatically before the driver exits. They can also be dumped to disk with
+ `sc.dump_profiles(path)`. If any profile results have already been displayed manually,
+ they will not be displayed again automatically before the driver exits.
+
+
+
+
spark.python.profile.dump
+
(none)
+
+ The directory in which to dump the profile results before the driver exits.
+ The results are dumped as a separate file for each RDD and can be loaded
+ with pstats.Stats(). If this is specified, the profile results will not be displayed
+ automatically.
+
spark.python.worker.reuse
true
@@ -234,6 +253,17 @@ Apart from these, the following properties are also available, and may be useful
spark.executor.uri.
+
+
spark.mesos.executor.memoryOverhead
+
executor memory * 0.07, with minimum of 384
+
+ The amount of memory, in MiB, added to spark.executor.memory when calculating
+ the total Mesos task memory. A value of 384 implies a 384MiB overhead.
+ Additionally, there is a hard-coded 7% minimum overhead. The final overhead is
+ the larger of `spark.mesos.executor.memoryOverhead` and 7% of `spark.executor.memory`.
+
+
#### Shuffle Behavior
@@ -327,7 +357,7 @@ Apart from these, the following properties are also available, and may be useful
spark.ui.port
4040
- Port for your application's dashboard, which shows memory and workload data
+ Port for your application's dashboard, which shows memory and workload data.
@@ -394,10 +424,11 @@ Apart from these, the following properties are also available, and may be useful
spark.io.compression.codec
snappy
- The codec used to compress internal data such as RDD partitions and shuffle outputs. By default,
- Spark provides three codecs: lz4, lzf, and snappy. You
- can also use fully qualified class names to specify the codec, e.g.
- org.apache.spark.io.LZ4CompressionCodec,
+ The codec used to compress internal data such as RDD partitions, broadcast variables and
+ shuffle outputs. By default, Spark provides three codecs: lz4, lzf,
+ and snappy. You can also use fully qualified class names to specify the codec,
+ e.g.
+ org.apache.spark.io.LZ4CompressionCodec,
org.apache.spark.io.LZFCompressionCodec,
and org.apache.spark.io.SnappyCompressionCodec.
@@ -588,6 +619,15 @@ Apart from these, the following properties are also available, and may be useful
output directories. We recommend that users do not disable this except if trying to achieve compatibility with
previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
+
+
spark.hadoop.cloneConf
+
false
+
If set to true, clones a new Hadoop Configuration object for each task. This
+ option should be enabled to work around Configuration thread-safety issues (see
+ SPARK-2546 for more details).
+ This is disabled by default in order to avoid unexpected performance regressions for jobs that
+ are not affected by these issues.
+
spark.executor.heartbeatInterval
10000
@@ -686,7 +726,7 @@ Apart from these, the following properties are also available, and may be useful
spark.akka.heartbeat.pauses
-
600
+
6000
This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause
@@ -841,8 +881,8 @@ Apart from these, the following properties are also available, and may be useful
spark.scheduler.revive.interval
1000
- The interval length for the scheduler to revive the worker resource offers to run tasks.
- (in milliseconds)
+ The interval length for the scheduler to revive the worker resource offers to run tasks
+ (in milliseconds).
@@ -854,7 +894,7 @@ Apart from these, the following properties are also available, and may be useful
to wait for before scheduling begins. Specified as a double between 0 and 1.
Regardless of whether the minimum ratio of resources has been reached,
the maximum amount of time it will wait before scheduling begins is controlled by config
- spark.scheduler.maxRegisteredResourcesWaitingTime
+ spark.scheduler.maxRegisteredResourcesWaitingTime.
@@ -1088,3 +1128,10 @@ compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface.
Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a
`log4j.properties` file in the `conf` directory. One way to start is to copy the existing
`log4j.properties.template` located there.
+
+# Overriding configuration directory
+
+To specify a configuration directory other than the default "SPARK_HOME/conf",
+you can set SPARK_CONF_DIR. Spark will use the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc.)
+from this directory.
+
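
As a quick illustration of the two new Python profiling properties described above (the app name and dump path are placeholders):

{% highlight python %}
# Minimal sketch of spark.python.profile / spark.python.profile.dump in use.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("ProfileDemo")                                 # placeholder name
        .set("spark.python.profile", "true")                       # enable the Python profiler
        .set("spark.python.profile.dump", "/tmp/spark-profiles"))  # optional dump directory

sc = SparkContext(conf=conf)
sc.parallelize(range(10000)).map(lambda x: x * x).count()

# Show the accumulated profile results on the driver ...
sc.show_profiles()
# ... or dump them to disk, one file per RDD, loadable with pstats.Stats().
sc.dump_profiles("/tmp/spark-profiles")

sc.stop()
{% endhighlight %}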
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d10bd63746629..7978e934fb36b 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
@@ -113,12 +113,6 @@ public class KMeansExample {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
@@ -153,3 +147,9 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
+
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+Quick Start guide. Be sure to also include *spark-mllib* in your build file as
+a dependency.
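
The Python snippet referenced above (the `WSSSE` line) comes from the k-means walk-through; a self-contained sketch of that flow, using the bundled `data/mllib/kmeans_data.txt` and illustrative parameters, is roughly:

{% highlight python %}
# Sketch of the PySpark k-means example; parameters are illustrative.
from math import sqrt

from numpy import array
from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans

sc = SparkContext(appName="PythonKMeansExample")

# Load and parse the data
data = sc.textFile("data/mllib/kmeans_data.txt")
parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))

# Build the model (cluster the data)
clusters = KMeans.train(parsedData, 2, maxIterations=10,
                        runs=10, initializationMode="random")

# Evaluate clustering by computing the Within Set Sum of Squared Errors
def error(point):
    center = clusters.centers[clusters.predict(point)]
    return sqrt(sum([x ** 2 for x in (point - center)]))

WSSSE = parsedData.map(lambda point: error(point)).reduce(lambda x, y: x + y)
print("Within Set Sum of Squared Error = " + str(WSSSE))

sc.stop()
{% endhighlight %}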
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index d5c539db791be..2094963392295 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
@@ -184,12 +184,6 @@ public class CollaborativeFiltering {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+Quick Start guide. Be sure to also include *spark-mllib* in your build file as
+a dependency.
+
## Tutorial
The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for
diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md
index 21cb35b4270ca..870fed6cc5024 100644
--- a/docs/mllib-dimensionality-reduction.md
+++ b/docs/mllib-dimensionality-reduction.md
@@ -121,9 +121,9 @@ public class SVD {
The same code applies to `IndexedRowMatrix` if `U` is defined as an
`IndexedRowMatrix`.
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
+In order to run the above application, follow the instructions
+provided in the [Self-Contained
+Applications](quick-start.html#self-contained-applications) section of the Spark
quick-start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
@@ -200,10 +200,11 @@ public class PCA {
}
{% endhighlight %}
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+quick-start guide. Be sure to also include *spark-mllib* in your build file as
+a dependency.
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index d31bec3e1bd01..bc914a1899801 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
@@ -323,9 +323,9 @@ svmAlg.optimizer()
final SVMModel modelL1 = svmAlg.run(training.rdd());
{% endhighlight %}
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
+In order to run the above application, follow the instructions
+provided in the [Self-Contained
+Applications](quick-start.html#self-contained-applications) section of the Spark
quick-start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
@@ -482,12 +482,6 @@ public class LinearRegression {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+quick-start guide. Be sure to also include *spark-mllib* in your build file as
+a dependency.
+
## Streaming linear regression
When data arrive in a streaming fashion, it is useful to fit regression models online,
diff --git a/docs/monitoring.md b/docs/monitoring.md
index d07ec4a57a2cc..e3f81a76acdbb 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -77,6 +77,13 @@ follows:
one implementation, provided by Spark, which looks for application logs stored in the
file system.
+
+
spark.history.fs.logDirectory
+
(none)
+
+ Directory that contains application event logs to be loaded by the history server
+
+
spark.history.fs.updateInterval
10
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 1d61a3c555eaf..18420afb27e3c 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes,
It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the
enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To
-use IPython, set the `IPYTHON` variable to `1` when running `bin/pyspark`:
+use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`:
{% highlight bash %}
-$ IPYTHON=1 ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark
{% endhighlight %}
-You can customize the `ipython` command by setting `IPYTHON_OPTS`. For example, to launch
+You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch
the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support:
{% highlight bash %}
-$ IPYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
{% endhighlight %}
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 23313d8aa6152..6236de0e1f2c4 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -8,7 +8,7 @@ title: Quick Start
This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's
interactive shell (in Python or Scala),
-then show how to write standalone applications in Java, Scala, and Python.
+then show how to write applications in Java, Scala, and Python.
See the [programming guide](programming-guide.html) for a more complete reference.
To follow along with this guide, first download a packaged release of Spark from the
@@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia
-# Standalone Applications
-Now say we wanted to write a standalone application using the Spark API. We will walk through a
+# Self-Contained Applications
+Now say we wanted to write a self-contained application using the Spark API. We will walk through a
simple application in both Scala (with SBT), Java (with Maven), and Python.
@@ -387,7 +387,7 @@ Lines with a: 46, Lines with b: 23
-Now we will show how to write a standalone application using the Python API (PySpark).
+Now we will show how to write an application using the Python API (PySpark).
As an example, we'll create a simple Spark application, `SimpleApp.py`:
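
A minimal sketch of such a self-contained PySpark application (with `YOUR_SPARK_HOME` as a placeholder for your Spark installation) would be:

{% highlight python %}
"""SimpleApp.py: a minimal self-contained PySpark application (sketch)."""
from pyspark import SparkContext

logFile = "YOUR_SPARK_HOME/README.md"  # should be some file on your system
sc = SparkContext("local", "Simple App")
logData = sc.textFile(logFile).cache()

numAs = logData.filter(lambda s: 'a' in s).count()
numBs = logData.filter(lambda s: 'b' in s).count()

print "Lines with a: %i, Lines with b: %i" % (numAs, numBs)

sc.stop()
{% endhighlight %}

Such an application is then submitted to Spark with `bin/spark-submit SimpleApp.py`.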
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 4b3a49eca7007..695813a2ba881 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -79,16 +79,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
spark.yarn.executor.memoryOverhead
-
384
+
executorMemory * 0.07, with minimum of 384
- The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc.
+ The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%).
spark.yarn.driver.memoryOverhead
-
384
+
driverMemory * 0.07, with minimum of 384
- The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc.
+ The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%).
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 818fd5ab80af8..368c3d0008b07 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -620,8 +620,8 @@ val people = sqlContext.jsonFile(path)
// The inferred schema can be visualized using the printSchema() method.
people.printSchema()
// root
-// |-- age: IntegerType
-// |-- name: StringType
+// |-- age: integer (nullable = true)
+// |-- name: string (nullable = true)
// Register this SchemaRDD as a table.
people.registerTempTable("people")
@@ -658,8 +658,8 @@ JavaSchemaRDD people = sqlContext.jsonFile(path);
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
// root
-// |-- age: IntegerType
-// |-- name: StringType
+// |-- age: integer (nullable = true)
+// |-- name: string (nullable = true)
// Register this JavaSchemaRDD as a table.
people.registerTempTable("people");
@@ -697,8 +697,8 @@ people = sqlContext.jsonFile(path)
# The inferred schema can be visualized using the printSchema() method.
people.printSchema()
# root
-# |-- age: IntegerType
-# |-- name: StringType
+# |-- age: integer (nullable = true)
+# |-- name: string (nullable = true)
# Register this SchemaRDD as a table.
people.registerTempTable("people")
@@ -1394,7 +1394,7 @@ please use factory methods provided in
StructType
-
org.apache.spark.sql.api.java
+
org.apache.spark.sql.api.java.Row
DataType.createStructType(fields) Note:fields is a List or an array of StructFields.
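
To put the corrected `printSchema()` output above in context, a short end-to-end PySpark sketch (using the bundled `people.json` and an illustrative app name):

{% highlight python %}
# Sketch: infer a schema from JSON and query it with SQL (Spark 1.x SchemaRDD API).
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext(appName="JsonSchemaDemo")  # illustrative app name
sqlContext = SQLContext(sc)

# Create a SchemaRDD from the JSON file; the schema is inferred automatically.
people = sqlContext.jsonFile("examples/src/main/resources/people.json")

people.printSchema()
# root
#  |-- age: integer (nullable = true)
#  |-- name: string (nullable = true)

# Register this SchemaRDD as a table and run SQL over it.
people.registerTempTable("people")
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
for row in teenagers.collect():
    print row[0]

sc.stop()
{% endhighlight %}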
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 5c21e912ea160..8bbba88b31978 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -212,6 +212,67 @@ The complete code can be found in the Spark Streaming example
[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
+
+
+First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads and a batch interval of 1 second.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+# Create a local StreamingContext with two working threads and a batch interval of 1 second
+sc = SparkContext("local[2]", "NetworkWordCount")
+ssc = StreamingContext(sc, 1)
+{% endhighlight %}
+
+Using this context, we can create a DStream that represents streaming data from a TCP
+source, specified as a hostname (e.g. `localhost`) and port (e.g. `9999`).
+
+{% highlight python %}
+# Create a DStream that will connect to hostname:port, like localhost:9999
+lines = ssc.socketTextStream("localhost", 9999)
+{% endhighlight %}
+
+This `lines` DStream represents the stream of data that will be received from the data
+server. Each record in this DStream is a line of text. Next, we want to split the lines by
+space into words.
+
+{% highlight python %}
+# Split each line into words
+words = lines.flatMap(lambda line: line.split(" "))
+{% endhighlight %}
+
+`flatMap` is a one-to-many DStream operation that creates a new DStream by
+generating multiple new records from each record in the source DStream. In this case,
+each line will be split into multiple words and the stream of words is represented as the
+`words` DStream. Next, we want to count these words.
+
+{% highlight python %}
+# Count each word in each batch
+pairs = words.map(lambda word: (word, 1))
+wordCounts = pairs.reduceByKey(lambda x, y: x + y)
+
+# Print the first ten elements of each RDD generated in this DStream to the console
+wordCounts.pprint()
+{% endhighlight %}
+
+The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word,
+1)` pairs, which is then reduced to get the frequency of words in each batch of data.
+Finally, `wordCounts.pprint()` will print a few of the counts generated every second.
+
+Note that when these lines are executed, Spark Streaming only sets up the computation it
+will perform when it is started, and no real processing has started yet. To start the processing
+after all the transformations have been setup, we finally call
+
+{% highlight python %}
+ssc.start() # Start the computation
+ssc.awaitTermination() # Wait for the computation to terminate
+{% endhighlight %}
+
+The complete code can be found in the Spark Streaming example
+[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py).
+
+
+
+A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+sc = SparkContext(master, appName)
+ssc = StreamingContext(sc, 1)
+{% endhighlight %}
+
+The `appName` parameter is a name for your application to show on the cluster UI.
+`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls),
+or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster,
+you will not want to hardcode `master` in the program,
+but rather [launch the application with `spark-submit`](submitting-applications.html) and
+receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming
+in-process (detects the number of cores in the local system).
+
+The batch interval must be set based on the latency requirements of your application
+and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
+section for more details.
+
After a context is defined, you have to do the following steps.
+
1. Define the input sources.
1. Setup the streaming computations.
1. Start the receiving and processing of data using `streamingContext.start()`.
@@ -483,6 +608,9 @@ methods for creating DStreams from files and Akka actors as input sources.
Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that
@@ -494,7 +622,7 @@ methods for creating DStreams from files and Akka actors as input sources.
For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence do not require allocating cores.
-- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details.
+- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details.
- **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream.
@@ -684,13 +812,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
JavaPairDStream<String, Integer> runningCounts = pairs.updateStateByKey(updateFunction);
{% endhighlight %}
+
+
+
+{% highlight python %}
+def updateFunction(newValues, runningCount):
+ if runningCount is None:
+ runningCount = 0
+ return sum(newValues, runningCount) # add the new values with the previous running count to get the new count
+{% endhighlight %}
+
+This is applied on a DStream containing words (say, the `pairs` DStream containing `(word,
+1)` pairs in the [earlier example](#a-quick-example)).
+
+{% highlight python %}
+runningCounts = pairs.updateStateByKey(updateFunction)
+{% endhighlight %}
+
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
Python code, take a look at the example
-[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
+[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
#### Transform Operation
{:.no_toc}
@@ -732,6 +877,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform(
});
{% endhighlight %}
+
+
+
+{% highlight python %}
+spamInfoRDD = sc.pickleFile(...) # RDD containing spam information
+
+# join data stream with spam information to do data cleaning
+cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...))
+{% endhighlight %}
@@ -793,6 +947,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000));
{% endhighlight %}
+
+
+
+{% highlight python %}
+# Reduce last 30 seconds of data, every 10 seconds
+windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10)
+{% endhighlight %}
+
@@ -860,6 +1022,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream)
and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions).
For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html)
and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html).
+For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
***
@@ -872,9 +1035,12 @@ Currently, the following output operations are defined:
Output Operation
Meaning
-
print()
+
print()
Prints first ten elements of every batch of data in a DStream on the driver.
- This is useful for development and debugging.
+ This is useful for development and debugging.
+
+ (Note: this operation is called pprint() in the Python API.)
+
saveAsObjectFiles(prefix, [suffix])
@@ -915,17 +1081,41 @@ For this purpose, a developer may inadvertently try creating a connection object
the Spark driver, but try to use it in a Spark worker to save records in the RDDs.
For example (in Scala),
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
val connection = createNewConnection() // executed at the driver
rdd.foreach(record => {
connection.send(record) // executed at the worker
})
})
+{% endhighlight %}
+
+
- This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.
+ This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.
- However, this can lead to another common mistake - creating a new connection for every record. For example,
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreach(record => {
val connection = createNewConnection()
@@ -933,9 +1123,28 @@ For example (in Scala),
connection.close()
})
})
+{% endhighlight %}
- Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
+
+
+ Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreachPartition(partitionOfRecords => {
val connection = createNewConnection()
@@ -943,13 +1152,31 @@ For example (in Scala),
connection.close()
})
})
+{% endhighlight %}
+
+
+
+{% highlight python %}
+def sendPartition(iter):
+ connection = createNewConnection()
+ for record in iter:
+ connection.send(record)
+ connection.close()
+
+dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition))
+{% endhighlight %}
+
+
- This amortizes the connection creation overheads over many records.
+ This amortizes the connection creation overheads over many records.
- Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches.
One can maintain a static pool of connection objects that can be reused as
RDDs of multiple batches are pushed to the external system, thus further reducing the overheads.
-
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreachPartition(partitionOfRecords => {
// ConnectionPool is a static, lazily initialized pool of connections
@@ -958,8 +1185,25 @@ For example (in Scala),
ConnectionPool.returnConnection(connection) // return to the pool for future reuse
})
})
+{% endhighlight %}
+
- Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems.
+
+{% highlight python %}
+def sendPartition(iter):
+ # ConnectionPool is a static, lazily initialized pool of connections
+ connection = ConnectionPool.getConnection()
+ for record in iter:
+ connection.send(record)
+ # return to the pool for future reuse
+ ConnectionPool.returnConnection(connection)
+
+dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition))
+{% endhighlight %}
+
+
+
+Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems.
##### Other points to remember:
@@ -1376,6 +1620,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data
the computation by using `new JavaStreamingContext(checkpointDirectory)`.
+
+
+This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows.
+
+{% highlight python %}
+# Function to create and setup a new StreamingContext
+def functionToCreateContext():
+ sc = SparkContext(...) # new context
+ ssc = StreamingContext(...)
+ lines = ssc.socketTextStream(...) # create DStreams
+ ...
+ ssc.checkpoint(checkpointDirectory) # set checkpoint directory
+ return ssc
+
+# Get StreamingContext from checkpoint data or create a new one
+context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext)
+
+# Do additional setup on context that needs to be done,
+# irrespective of whether it is being started or restarted
+context. ...
+
+# Start the context
+context.start()
+context.awaitTermination()
+{% endhighlight %}
+
+If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data.
+If the directory does not exist (i.e., running for the first time),
+then the function `functionToCreateContext` will be called to create a new
+context and set up the DStreams. See the Python example
+[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py).
+This example appends the word counts of network data into a file.
+
+You can also explicitly create a `StreamingContext` from the checkpoint data and start the
+ computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`.
+
+
+
**Note**: If Spark Streaming and/or the Spark Streaming program is recompiled,
@@ -1572,7 +1854,11 @@ package and renamed for better clarity.
[TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html),
[ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and
[MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html)
+ - Python docs
+ * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext)
+ * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
* More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming)
and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming)
+ and [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming)
* [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming.
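
Finally, the Python window example above is shown in isolation; when the inverse reduce function is used, checkpointing must be enabled, so a runnable sketch of a windowed network word count (host, port and checkpoint directory are placeholders) looks like:

{% highlight python %}
# Sketch: windowed network word count; reduce the last 30 seconds of data every 10 seconds.
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "WindowedNetworkWordCount")
ssc = StreamingContext(sc, 10)
ssc.checkpoint("checkpoint")  # required because an inverse reduce function is used below

lines = ssc.socketTextStream("localhost", 9999)
pairs = lines.flatMap(lambda line: line.split(" ")).map(lambda word: (word, 1))

# Reduce last 30 seconds of data, every 10 seconds
windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y,
                                                lambda x, y: x - y, 30, 10)
windowedWordCounts.pprint()

ssc.start()
ssc.awaitTermination()
{% endhighlight %}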
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 941dfb988b9fb..0d6b82b4944f3 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -32,6 +32,7 @@
import tempfile
import time
import urllib2
+import warnings
from optparse import OptionParser
from sys import stderr
import boto
@@ -61,8 +62,8 @@ def parse_args():
"-s", "--slaves", type="int", default=1,
help="Number of slaves to launch (default: %default)")
parser.add_option(
- "-w", "--wait", type="int", default=120,
- help="Seconds to wait for nodes to start (default: %default)")
+ "-w", "--wait", type="int",
+ help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start")
parser.add_option(
"-k", "--key-pair",
help="Key pair to use on instances")
@@ -195,18 +196,6 @@ def get_or_make_group(conn, name):
return conn.create_security_group(name, "Spark EC2 group")
-# Wait for a set of launched instances to exit the "pending" state
-# (i.e. either to start running or to fail and be terminated)
-def wait_for_instances(conn, instances):
- while True:
- for i in instances:
- i.update()
- if len([i for i in instances if i.state == 'pending']) > 0:
- time.sleep(5)
- else:
- return
-
-
# Check whether a given EC2 instance object is in a state we consider active,
# i.e. not terminating or terminated. We count both stopping and stopped as
# active since we can restart stopped clusters.
@@ -594,7 +583,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
- ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3")
+ ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4")
print "Deploying files to master..."
deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules)
@@ -619,14 +608,64 @@ def setup_spark_cluster(master, opts):
print "Ganglia started at http://%s:5080/ganglia" % master
-# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up
-def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes):
- print "Waiting for instances to start up..."
- time.sleep(5)
- wait_for_instances(conn, master_nodes)
- wait_for_instances(conn, slave_nodes)
- print "Waiting %d more seconds..." % wait_secs
- time.sleep(wait_secs)
+def is_ssh_available(host, opts):
+ "Checks if SSH is available on the host."
+ try:
+ with open(os.devnull, 'w') as devnull:
+ ret = subprocess.check_call(
+ ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3',
+ '%s@%s' % (opts.user, host), stringify_command('true')],
+ stdout=devnull,
+ stderr=devnull
+ )
+ return ret == 0
+ except subprocess.CalledProcessError as e:
+ return False
+
+
+def is_cluster_ssh_available(cluster_instances, opts):
+ for i in cluster_instances:
+ if not is_ssh_available(host=i.ip_address, opts=opts):
+ return False
+ else:
+ return True
+
+
+def wait_for_cluster_state(cluster_instances, cluster_state, opts):
+ """
+ cluster_instances: a list of boto.ec2.instance.Instance
+ cluster_state: a string representing the desired state of all the instances in the cluster
+ value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as
+ 'running', 'terminated', etc.
+ (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250)
+ """
+ sys.stdout.write(
+ "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state)
+ )
+ sys.stdout.flush()
+
+ num_attempts = 0
+
+ while True:
+ time.sleep(3 * num_attempts)
+
+ for i in cluster_instances:
+ s = i.update() # capture output to suppress print to screen in newer versions of boto
+
+ if cluster_state == 'ssh-ready':
+ if all(i.state == 'running' for i in cluster_instances) and \
+ is_cluster_ssh_available(cluster_instances, opts):
+ break
+ else:
+ if all(i.state == cluster_state for i in cluster_instances):
+ break
+
+ num_attempts += 1
+
+ sys.stdout.write(".")
+ sys.stdout.flush()
+
+ sys.stdout.write("\n")
# Get number of local disks available for a given EC2 instance type.
@@ -868,6 +907,16 @@ def real_main():
(opts, action, cluster_name) = parse_args()
# Input parameter validation
+ if opts.wait is not None:
+ # NOTE: DeprecationWarnings are silent in 2.7+ by default.
+ # To show them, run Python with the -Wdefault switch.
+ # See: https://docs.python.org/3.5/whatsnew/2.7.html
+ warnings.warn(
+ "This option is deprecated and has no effect. "
+ "spark-ec2 automatically waits as long as necessary for clusters to startup.",
+ DeprecationWarning
+ )
+
if opts.ebs_vol_num > 8:
print >> stderr, "ebs-vol-num cannot be greater than 8"
sys.exit(1)
@@ -890,7 +939,11 @@ def real_main():
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
else:
(master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name)
- wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes)
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='ssh-ready',
+ opts=opts
+ )
setup_cluster(conn, master_nodes, slave_nodes, opts, True)
elif action == "destroy":
@@ -919,7 +972,11 @@ def real_main():
else:
group_names = [opts.security_group_prefix + "-master",
opts.security_group_prefix + "-slaves"]
-
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='terminated',
+ opts=opts
+ )
attempt = 1
while attempt <= 3:
print "Attempt %d" % attempt
@@ -1019,7 +1076,11 @@ def real_main():
for inst in master_nodes:
if inst.state not in ["shutting-down", "terminated"]:
inst.start()
- wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes)
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='ssh-ready',
+ opts=opts
+ )
setup_cluster(conn, master_nodes, slave_nodes, opts, False)
else:
diff --git a/examples/pom.xml b/examples/pom.xml
index 2b561857f9f33..eb49a0e5af22d 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -43,6 +43,11 @@
spark-streaming-kinesis-asl_${scala.binary.version}${project.version}
+
+ org.apache.httpcomponents
+ httpclient
+ ${commons.httpclient.version}
+
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
index 11157d7573fae..0f07cb4098325 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java
@@ -31,7 +31,6 @@
* Usage: JavaSparkPi [slices]
*/
public final class JavaSparkPi {
-
public static void main(String[] args) throws Exception {
SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
@@ -61,5 +60,7 @@ public Integer call(Integer integer, Integer integer2) {
});
System.out.println("Pi is roughly " + 4.0 * count / n);
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index 898297dc658ba..01c77bd44337e 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -61,7 +61,8 @@ public static void main(String[] args) throws Exception {
// Load a text file and convert each line to a Java Bean.
JavaRDD<Person> people = ctx.textFile("examples/src/main/resources/people.txt").map(
new Function<String, Person>() {
- public Person call(String line) throws Exception {
+ @Override
+ public Person call(String line) {
String[] parts = line.split(",");
Person person = new Person();
@@ -82,6 +83,7 @@ public Person call(String line) throws Exception {
// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
+ @Override
public String call(Row row) {
return "Name: " + row.getString(0);
}
@@ -104,6 +106,7 @@ public String call(Row row) {
JavaSchemaRDD teenagers2 =
sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
teenagerNames = teenagers2.map(new Function<Row, String>() {
+ @Override
public String call(Row row) {
return "Name: " + row.getString(0);
}
@@ -136,6 +139,7 @@ public String call(Row row) {
// The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
teenagerNames = teenagers3.map(new Function<Row, String>() {
+ @Override
public String call(Row row) { return "Name: " + row.getString(0); }
}).collect();
for (String name: teenagerNames) {
@@ -162,6 +166,7 @@ public String call(Row row) {
JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2");
List<String> nameAndCity = peopleWithCity.map(new Function<Row, String>() {
+ @Override
public String call(Row row) {
return "Name: " + row.getString(0) + ", City: " + row.getString(1);
}
@@ -169,5 +174,7 @@ public String call(Row row) {
for (String name: nameAndCity) {
System.out.println(name);
}
+
+ ctx.stop();
}
}
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index cfda8d8327aa3..4626bbb7e3b02 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -78,3 +78,5 @@
output = avro_rdd.map(lambda x: x[0]).collect()
for k in output:
print k
+
+ sc.stop()
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index c9b08f878a1e6..fa4c20ab20281 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -57,3 +57,5 @@
output = parquet_rdd.map(lambda x: x[1]).collect()
for k in output:
print k
+
+ sc.stop()
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
new file mode 100644
index 0000000000000..d2c5ca48c6cb8
--- /dev/null
+++ b/examples/src/main/python/sql.py
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.sql import Row, StructField, StructType, StringType, IntegerType
+
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="PythonSQL")
+ sqlContext = SQLContext(sc)
+
+ # RDD is created from a list of rows
+ some_rdd = sc.parallelize([Row(name="John", age=19),
+ Row(name="Smith", age=23),
+ Row(name="Sarah", age=18)])
+ # Infer schema from the first row, create a SchemaRDD and print the schema
+ some_schemardd = sqlContext.inferSchema(some_rdd)
+ some_schemardd.printSchema()
+
+ # Another RDD is created from a list of tuples
+ another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)])
+ # Schema with two fields - person_name and person_age
+ schema = StructType([StructField("person_name", StringType(), False),
+ StructField("person_age", IntegerType(), False)])
+ # Create a SchemaRDD by applying the schema to the RDD and print the schema
+ another_schemardd = sqlContext.applySchema(another_rdd, schema)
+ another_schemardd.printSchema()
+ # root
+ # |-- person_name: string (nullable = false)
+ # |-- person_age: integer (nullable = false)
+
+ # A JSON dataset is pointed to by path.
+ # The path can be either a single text file or a directory storing text files.
+ path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
+ # Create a SchemaRDD from the file(s) pointed to by path
+ people = sqlContext.jsonFile(path)
+ # root
+ # |-- age: integer (nullable = true)
+ # |-- name: string (nullable = true)
+
+ # The inferred schema can be visualized using the printSchema() method.
+ people.printSchema()
+ # root
+ # |-- age: integer (nullable = true)
+ # |-- name: string (nullable = true)
+
+ # Register this SchemaRDD as a table.
+ people.registerAsTable("people")
+
+ # SQL statements can be run by using the sql methods provided by sqlContext
+ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+
+ for each in teenagers.collect():
+ print each[0]
+
+ sc.stop()
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
new file mode 100644
index 0000000000000..40faff0ccc7db
--- /dev/null
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in new text files created in the given directory
+ Usage: hdfs_wordcount.py <directory>
+ <directory> is the directory that Spark Streaming will use to find and read new text files.
+
+ To run this on your local machine on directory `localdir`, run this example
+ $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir
+
+ Then create a text file in `localdir` and the words in the file will get counted.
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, "Usage: hdfs_wordcount.py "
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingHDFSWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.textFileStream(sys.argv[1])
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda x: (x, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
new file mode 100644
index 0000000000000..cfa9c1ff5bfbc
--- /dev/null
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ Usage: network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
new file mode 100644
index 0000000000000..fc6827c82bf9b
--- /dev/null
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in text encoded with UTF8 received from the network every second.
+
+ Usage: recoverable_network_wordcount.py <hostname> <port> <checkpoint-directory> <output-file>
+ <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+ data. <checkpoint-directory> is a directory on an HDFS-compatible file system where checkpoint
+ data will be stored, and <output-file> is the file to which the word counts will be appended.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \
+ localhost 9999 ~/checkpoint/ ~/out`
+
+ If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ the checkpoint data.
+"""
+
+import os
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+
+def createContext(host, port, outputPath):
+ # If you do not see this printed, that means the StreamingContext has been loaded
+ # from the existing checkpoint
+ print "Creating new context"
+ if os.path.exists(outputPath):
+ os.remove(outputPath)
+ sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ # Create a socket stream on target ip:port and count the
+ # words in input stream of \n delimited text (eg. generated by 'nc')
+ lines = ssc.socketTextStream(host, port)
+ words = lines.flatMap(lambda line: line.split(" "))
+ wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+
+ def echo(time, rdd):
+ counts = "Counts at time %s %s" % (time, rdd.collect())
+ print counts
+ print "Appending to " + os.path.abspath(outputPath)
+ with open(outputPath, 'a') as f:
+ f.write(counts + "\n")
+
+ wordCounts.foreachRDD(echo)
+ return ssc
+
+if __name__ == "__main__":
+ if len(sys.argv) != 5:
+ print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\
+ ""
+ exit(-1)
+ host, port, checkpoint, output = sys.argv[1:]
+ ssc = StreamingContext.getOrCreate(checkpoint,
+ lambda: createContext(host, int(port), output))
+ ssc.start()
+ ssc.awaitTermination()
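For reference, here is the same checkpoint-recovery pattern in Scala (an editorial sketch, not part of this patch; the checkpoint path, host, and port are placeholders). StreamingContext.getOrCreate only invokes the factory function when no checkpoint exists under the given path.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object RecoverableSketch {
  def createContext(checkpointDir: String): StreamingContext = {
    // Only reached when no checkpoint exists under checkpointDir.
    println("Creating new context")
    val conf = new SparkConf().setMaster("local[2]").setAppName("RecoverableSketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint(checkpointDir)
    ssc.socketTextStream("localhost", 9999).count().print()
    ssc
  }

  def main(args: Array[String]): Unit = {
    val checkpointDir = "/tmp/recoverable-sketch-checkpoint"  // placeholder path
    // Rebuilds the context from checkpoint data when it exists, otherwise calls the factory.
    val ssc = StreamingContext.getOrCreate(checkpointDir, () => createContext(checkpointDir))
    ssc.start()
    ssc.awaitTermination()
  }
}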
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
new file mode 100644
index 0000000000000..18a9a5a452ffb
--- /dev/null
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the
+ network every second.
+
+ Usage: stateful_network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
+ localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+ ssc.checkpoint("checkpoint")
+
+ def updateFunc(new_values, last_sum):
+ return sum(new_values) + (last_sum or 0)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ running_counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .updateStateByKey(updateFunc)
+
+ running_counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
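The Scala counterpart of the stateful count (again an editorial sketch, not part of this patch; host and port are placeholders) makes the updateStateByKey contract explicit: the update function receives the batch's new values plus the previous state and returns the new state.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object StatefulSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("StatefulSketch")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("checkpoint")  // stateful operations require checkpointing

    // Fold the batch's new counts into the previous running count.
    val updateFunc = (newValues: Seq[Int], runningCount: Option[Int]) =>
      Some(newValues.sum + runningCount.getOrElse(0))

    val lines = ssc.socketTextStream("localhost", 9999)
    val runningCounts = lines.flatMap(_.split(" ")).map((_, 1)).updateStateByKey[Int](updateFunc)
    runningCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
}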
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
index 71f53af68f4d3..11d5c92c5952d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
@@ -136,5 +136,7 @@ object CassandraCQLTest {
classOf[CqlOutputFormat],
job.getConfiguration()
)
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
index 91ba364a346a5..ec689474aecb0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
@@ -126,6 +126,8 @@ object CassandraTest {
}
}.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
classOf[ColumnFamilyOutputFormat], job.getConfiguration)
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index efd91bb054981..15f6678648b29 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -44,11 +44,11 @@ object GroupByTest {
arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
}
arr1
- }.cache
+ }.cache()
// Enforce that everything has been calculated and in cache
- pairs1.count
+ pairs1.count()
- println(pairs1.groupByKey(numReducers).count)
+ println(pairs1.groupByKey(numReducers).count())
sc.stop()
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index 4c655b84fde2e..74620ad007d83 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -79,5 +79,7 @@ object LogQuery {
.reduceByKey((a, b) => a.merge(b))
.collect().foreach{
case (user, query) => println("%s\t%s".format(user, query))}
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index 235c3bf820244..e4db3ec51313d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -21,7 +21,6 @@ import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
import scala.xml.{XML,NodeSeq}
@@ -78,9 +77,9 @@ object WikipediaPageRank {
(id, new PRVertex(1.0 / numVertices, outEdges))
})
if (usePartitioner) {
- vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache
+ vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache()
} else {
- vertices = vertices.cache
+ vertices = vertices.cache()
}
println("Done parsing input file.")
@@ -100,7 +99,9 @@ object WikipediaPageRank {
(result
.filter { case (id, vertex) => vertex.value >= threshold }
.map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) }
- .collect.mkString)
+ .collect().mkString)
println(top)
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index c4317a6aec798..45527d9382fd0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,17 +46,6 @@ object Analytics extends Logging {
}
val options = mutable.Map(optionsList: _*)
- def pickPartitioner(v: String): PartitionStrategy = {
- // TODO: Use reflection rather than listing all the partitioning strategies here.
- v match {
- case "RandomVertexCut" => RandomVertexCut
- case "EdgePartition1D" => EdgePartition1D
- case "EdgePartition2D" => EdgePartition2D
- case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
- case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
- }
- }
-
val conf = new SparkConf()
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
@@ -67,7 +56,7 @@ object Analytics extends Logging {
sys.exit(1)
}
val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy")
- .map(pickPartitioner(_))
+ .map(PartitionStrategy.fromString(_))
val edgeStorageLevel = options.remove("edgeStorageLevel")
.map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY)
val vertexStorageLevel = options.remove("vertexStorageLevel")
@@ -107,7 +96,7 @@ object Analytics extends Logging {
if (!outFname.isEmpty) {
logWarning("Saving pageranks of pages to " + outFname)
- pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+ pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname)
}
sc.stop()
@@ -129,7 +118,7 @@ object Analytics extends Logging {
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
val cc = ConnectedComponents.run(graph)
- println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+ println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct())
sc.stop()
case "triangles" =>
@@ -147,7 +136,7 @@ object Analytics extends Logging {
minEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel)
- // TriangleCount requires the graph to be partitioned
+ // TriangleCount requires the graph to be partitioned
.partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache()
val triangles = TriangleCount.run(graph)
println("Triangles: " + triangles.vertices.map {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
new file mode 100644
index 0000000000000..ae6057758d6fc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scala.reflect.runtime.universe._
+
+/**
+ * Abstract class for parameter case classes.
+ * This overrides the [[toString]] method to print all case class fields by name and value.
+ * @tparam T Concrete parameter class.
+ */
+abstract class AbstractParams[T: TypeTag] {
+
+ private def tag: TypeTag[T] = typeTag[T]
+
+ /**
+ * Finds all case class fields in concrete class instance, and outputs them in JSON-style format:
+ * {
+ * [field name]:\t[field value]\n
+ * [field name]:\t[field value]\n
+ * ...
+ * }
+ */
+ override def toString: String = {
+ val tpe = tag.tpe
+ val allAccessors = tpe.declarations.collect {
+ case m: MethodSymbol if m.isCaseAccessor => m
+ }
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ val instanceMirror = mirror.reflect(this)
+ allAccessors.map { f =>
+ val paramName = f.name.toString
+ val fieldMirror = instanceMirror.reflectField(f)
+ val paramValue = fieldMirror.get
+ s" $paramName:\t$paramValue"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+}
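To show what the reflective toString above produces, a small editorial sketch (ExampleParams and AbstractParamsDemo are made-up names, not part of this patch):

package org.apache.spark.examples.mllib

case class ExampleParams(input: String = "data.txt", numIterations: Int = 10)
  extends AbstractParams[ExampleParams]

object AbstractParamsDemo {
  def main(args: Array[String]): Unit = {
    // Prints the case class fields in the brace-delimited format documented on
    // AbstractParams.toString, e.g.
    // {
    //   input:   data.txt,
    //   numIterations:   10
    // }
    println(ExampleParams())
  }
}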
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
index a6f78d2441db1..1edd2432a0352 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -55,7 +55,7 @@ object BinaryClassification {
stepSize: Double = 1.0,
algorithm: Algorithm = LR,
regType: RegType = L2,
- regParam: Double = 0.1)
+ regParam: Double = 0.1) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
index d6b2fe430e5a4..e49129c4e7844 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
@@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object Correlations {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
new file mode 100644
index 0000000000000..cb1abbd18fd4d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * Compute the similar columns of a matrix, using cosine similarity.
+ *
+ * The input matrix must be stored in row-oriented dense format, one line per row with its entries
+ * separated by space. For example,
+ * {{{
+ * 0.5 1.0
+ * 2.0 3.0
+ * 4.0 5.0
+ * }}}
+ * represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
+ *
+ * Example invocation:
+ *
+ * bin/run-example mllib.CosineSimilarity \
+ * --threshold 0.1 data/mllib/sample_svm_data.txt
+ */
+object CosineSimilarity {
+ case class Params(inputFile: String = null, threshold: Double = 0.1)
+ extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("CosineSimilarity") {
+ head("CosineSimilarity: an example app.")
+ opt[Double]("threshold")
+ .required()
+ .text(s"threshold similarity: to tradeoff computation vs quality estimate")
+ .action((x, c) => c.copy(threshold = x))
+ arg[String]("")
+ .required()
+ .text(s"input file, one row per line, space-separated")
+ .action((x, c) => c.copy(inputFile = x))
+ note(
+ """
+ |For example, the following command runs this app on a dataset:
+ |
+ | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \
+ | examplesjar.jar \
+ | --threshold 0.1 data/mllib/sample_svm_data.txt
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName("CosineSimilarity")
+ val sc = new SparkContext(conf)
+
+ // Load and parse the data file.
+ val rows = sc.textFile(params.inputFile).map { line =>
+ val values = line.split(' ').map(_.toDouble)
+ Vectors.dense(values)
+ }.cache()
+ val mat = new RowMatrix(rows)
+
+ // Compute similar columns perfectly, with brute force.
+ val exact = mat.columnSimilarities()
+
+ // Compute similar columns with estimation using DIMSUM
+ val approx = mat.columnSimilarities(params.threshold)
+
+ val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) }
+ val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) }
+ val MAE = exactEntries.leftOuterJoin(approxEntries).values.map {
+ case (u, Some(v)) =>
+ math.abs(u - v)
+ case (u, None) =>
+ math.abs(u)
+ }.mean()
+
+ println(s"Average absolute error in estimate is: $MAE")
+
+ sc.stop()
+ }
+}
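The same exact-versus-DIMSUM comparison can be sketched on the 3-by-2 matrix from the scaladoc without reading a file (an editorial sketch; the object name is made up):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

object CosineSimilaritySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("CosineSimilaritySketch"))
    val rows = sc.parallelize(Seq(
      Vectors.dense(0.5, 1.0),
      Vectors.dense(2.0, 3.0),
      Vectors.dense(4.0, 5.0)))
    val mat = new RowMatrix(rows)

    val exact = mat.columnSimilarities()      // brute force
    val approx = mat.columnSimilarities(0.1)  // DIMSUM sampling with threshold 0.1

    exact.entries.collect().foreach(println)
    approx.entries.collect().foreach(println)
    sc.stop()
  }
}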
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 96fb068e9e126..0890e6263e165 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -52,6 +52,7 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
+ testInput: String = "",
dataFormat: String = "libsvm",
algo: Algo = Classification,
maxDepth: Int = 5,
@@ -61,7 +62,7 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2)
+ fracTest: Double = 0.2) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -98,13 +99,18 @@ object DecisionTreeRunner {
s"default: ${defaultParams.featureSubsetStrategy}")
.action((x, c) => c.copy(featureSubsetStrategy = x))
opt[Double]("fracTest")
- .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
opt[String]("")
.text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
.action((x, c) => c.copy(dataFormat = x))
arg[String]("")
- .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+ .text("input path to labeled examples")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
@@ -132,16 +138,18 @@ object DecisionTreeRunner {
def run(params: Params) {
- val conf = new SparkConf().setAppName("DecisionTreeRunner")
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
val sc = new SparkContext(conf)
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
// Load training data and cache it.
val origExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
}
// For classification, re-index classes if needed.
- val (examples, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = params.algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -170,16 +178,41 @@ object DecisionTreeRunner {
val frac = classCounts(c) / numExamples.toDouble
println(s"$c\t$frac\t${classCounts(c)}")
}
- (examples, numClasses)
+ (examples, classIndexMap, numClasses)
}
case Regression =>
- (origExamples, 0)
+ (origExamples, null, 0)
case _ =>
throw new IllegalArgumentException("Algo ${params.algo} not supported.")
}
- // Split into training, test.
- val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ // Create training, test sets.
+ val splits = if (params.testInput != "") {
+ // Load testInput.
+ val numFeatures = examples.take(1)(0).features.size
+ val origTestExamples = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
+ }
+ params.algo match {
+ case Classification => {
+ // classCounts: class --> # examples in class
+ val testExamples = {
+ if (classIndexMap.isEmpty) {
+ origTestExamples
+ } else {
+ origTestExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+ }
+ }
+ Array(examples, testExamples)
+ }
+ case Regression =>
+ Array(examples, origTestExamples)
+ }
+ } else {
+ // Split input into training, test.
+ examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ }
val training = splits(0).cache()
val test = splits(1).cache()
val numTraining = training.count()
@@ -205,48 +238,72 @@ object DecisionTreeRunner {
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
if (params.numTrees == 1) {
+ val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
- println(model)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.numNodes < 20) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
if (params.algo == Classification) {
- val accuracy =
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
- println(s"Test accuracy = $accuracy")
+ println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
- val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse")
+ val trainMSE = meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
}
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
- println(model)
- val accuracy =
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
- println(s"Test accuracy = $accuracy")
+ println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
- println(model)
- val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse")
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
}
}
sc.stop()
}
- /**
- * Calculates the classifier accuracy.
- */
- private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
- val count = data.count()
- correctCount.toDouble / count
- }
-
/**
* Calculates the mean squared error for regression.
*/
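The "accuracy" printed by the runner above is the overall precision of MulticlassMetrics over (prediction, label) pairs. A small editorial sketch of that metric in isolation (the pairs are made up):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.evaluation.MulticlassMetrics

object AccuracySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("AccuracySketch"))
    // (prediction, label) pairs; 3 of the 4 predictions are correct, so precision = 0.75
    val predictionAndLabels = sc.parallelize(
      Seq((1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (0.0, 0.0)))
    val metrics = new MulticlassMetrics(predictionAndLabels)
    println(s"accuracy = ${metrics.precision}")
    sc.stop()
  }
}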
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
index 89dfa26c2299c..11e35598baf50 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -44,7 +44,7 @@ object DenseKMeans {
input: String = null,
k: Int = -1,
numIterations: Int = 10,
- initializationMode: InitializationMode = Parallel)
+ initializationMode: InitializationMode = Parallel) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
index 05b7d66f8dffd..e1f9622350135 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -47,7 +47,7 @@ object LinearRegression extends App {
numIterations: Int = 100,
stepSize: Double = 1.0,
regType: RegType = L2,
- regParam: Double = 0.1)
+ regParam: Double = 0.1) extends AbstractParams[Params]
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 98aaedb9d7dc9..fc6678013b932 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -55,7 +55,7 @@ object MovieLensALS {
rank: Int = 10,
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
- implicitPrefs: Boolean = false)
+ implicitPrefs: Boolean = false) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
index 4532512c01f84..6e4e2d07f284b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
@@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object MultivariateSummarizer {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
index f01b8266e3fe3..663c12734af68 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._
object SampledRDDs {
case class Params(input: String = "data/mllib/sample_binary_classification_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
index 952fa2a5109a4..f1ff4e6911f5e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -37,7 +37,7 @@ object SparseNaiveBayes {
input: String = null,
minPartitions: Int = 0,
numFeatures: Int = -1,
- lambda: Double = 1.0)
+ lambda: Double = 1.0) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index d56d64c564200..2e98b2dc30b80 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -51,7 +51,7 @@ object RDDRelation {
val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10")
println("Result of RDD.map:")
- rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect.foreach(println)
+ rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
// Queries can also be written using a LINQ-like Scala DSL.
rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println)
@@ -68,5 +68,7 @@ object RDDRelation {
// These files can also be registered as tables.
parquetFile.registerTempTable("parquetFile")
sql("SELECT * FROM parquetFile").collect().foreach(println)
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 3423fac0ad303..0c52ef8ed96ac 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -28,7 +28,7 @@ object HiveFromSpark {
val sparkConf = new SparkConf().setAppName("HiveFromSpark")
val sc = new SparkContext(sparkConf)
- // A local hive context creates an instance of the Hive Metastore in process, storing the
+ // A local hive context creates an instance of the Hive Metastore in process, storing
// the warehouse data in the current directory. This location can be overridden by
// specifying a second parameter to the constructor.
val hiveContext = new HiveContext(sc)
@@ -39,7 +39,7 @@ object HiveFromSpark {
// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
- sql("SELECT * FROM src").collect.foreach(println)
+ sql("SELECT * FROM src").collect().foreach(println)
// Aggregation queries are also supported.
val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0)
@@ -61,5 +61,7 @@ object HiveFromSpark {
// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
+
+ sc.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
index 566ba6f911e02..c9e1511278ede 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
@@ -53,8 +53,8 @@ object KafkaWordCount {
val ssc = new StreamingContext(sparkConf, Seconds(2))
ssc.checkpoint("checkpoint")
- val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicpMap).map(_._2)
+ val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap
+ val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1L))
.reduceByKeyAndWindow(_ + _, _ - _, Minutes(10), Seconds(2), 2)
diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties
index 45d2ec676df66..4411d6e20c52a 100644
--- a/external/flume/src/test/resources/log4j.properties
+++ b/external/flume/src/test/resources/log4j.properties
@@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 6ee7ac974b4a0..13943ed5442b9 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -17,98 +17,141 @@
package org.apache.spark.streaming.flume
-import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-
-import java.net.InetSocketAddress
+import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.nio.charset.Charset
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
+import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
-
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase}
-import org.apache.spark.streaming.util.ManualClock
-import org.apache.spark.streaming.api.java.JavaReceiverInputDStream
-
import org.jboss.netty.channel.ChannelPipeline
-import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.channel.socket.SocketChannel
+import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
+import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}
+import org.apache.spark.util.Utils
-class FlumeStreamSuite extends TestSuiteBase {
+class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+ val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
+
+ var ssc: StreamingContext = null
+ var transceiver: NettyTransceiver = null
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ }
+ if (transceiver != null) {
+ transceiver.close()
+ }
+ }
test("flume input stream") {
- runFlumeStreamTest(false, 9998)
+ testFlumeStream(testCompression = false)
}
test("flume input compressed stream") {
- runFlumeStreamTest(true, 9997)
+ testFlumeStream(testCompression = true)
+ }
+
+ /** Run test on flume stream */
+ private def testFlumeStream(testCompression: Boolean): Unit = {
+ val input = (1 to 100).map { _.toString }
+ val testPort = findFreePort()
+ val outputBuffer = startContext(testPort, testCompression)
+ writeAndVerify(input, testPort, outputBuffer, testCompression)
}
-
- def runFlumeStreamTest(enableDecompression: Boolean, testPort: Int) {
- // Set up the streaming context and input streams
- val ssc = new StreamingContext(conf, batchDuration)
- val flumeStream: JavaReceiverInputDStream[SparkFlumeEvent] =
- FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, enableDecompression)
+
+ /** Find a free port */
+ private def findFreePort(): Int = {
+ Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ })._2
+ }
+
+ /** Setup and start the streaming context */
+ private def startContext(
+ testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = {
+ ssc = new StreamingContext(conf, Milliseconds(200))
+ val flumeStream = FlumeUtils.createStream(
+ ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression)
val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
with SynchronizedBuffer[Seq[SparkFlumeEvent]]
- val outputStream = new TestOutputStream(flumeStream.receiverInputDStream, outputBuffer)
+ val outputStream = new TestOutputStream(flumeStream, outputBuffer)
outputStream.register()
ssc.start()
+ outputBuffer
+ }
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- val input = Seq(1, 2, 3, 4, 5)
- Thread.sleep(1000)
- val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort))
- var client: AvroSourceProtocol = null;
-
- if (enableDecompression) {
- client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol],
- new NettyTransceiver(new InetSocketAddress("localhost", testPort),
- new CompressionChannelFactory(6)));
- } else {
- client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol], transceiver)
- }
+ /** Send data to the flume receiver and verify whether the data was received */
+ private def writeAndVerify(
+ input: Seq[String],
+ testPort: Int,
+ outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]],
+ enableCompression: Boolean
+ ) {
+ val testAddress = new InetSocketAddress("localhost", testPort)
- for (i <- 0 until input.size) {
+ val inputEvents = input.map { item =>
val event = new AvroFlumeEvent
- event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8")))
+ event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8")))
event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
- client.append(event)
- Thread.sleep(500)
- clock.addToTime(batchDuration.milliseconds)
+ event
}
- Thread.sleep(1000)
-
- val startTime = System.currentTimeMillis()
- while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
- logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size)
- Thread.sleep(100)
+ eventually(timeout(10 seconds), interval(100 milliseconds)) {
+ // If a transceiver was created in a previous (failed) attempt, close it
+ if (transceiver != null) {
+ transceiver.close()
+ transceiver = null
+ }
+
+ // Create transceiver
+ transceiver = {
+ if (enableCompression) {
+ new NettyTransceiver(testAddress, new CompressionChannelFactory(6))
+ } else {
+ new NettyTransceiver(testAddress)
+ }
+ }
+
+ // Create Avro client with the transceiver
+ val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver)
+ client should not be null
+
+ // Send data
+ val status = client.appendBatch(inputEvents.toList)
+ status should be (avro.Status.OK)
}
- Thread.sleep(1000)
- val timeTaken = System.currentTimeMillis() - startTime
- assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
- logInfo("Stopping context")
- ssc.stop()
-
- val decoder = Charset.forName("UTF-8").newDecoder()
-
- assert(outputBuffer.size === input.length)
- for (i <- 0 until outputBuffer.size) {
- assert(outputBuffer(i).size === 1)
- val str = decoder.decode(outputBuffer(i).head.event.getBody)
- assert(str.toString === input(i).toString)
- assert(outputBuffer(i).head.event.getHeaders.get("test") === "header")
+
+ val decoder = Charset.forName("UTF-8").newDecoder()
+ eventually(timeout(10 seconds), interval(100 milliseconds)) {
+ val outputEvents = outputBuffer.flatten.map { _.event }
+ outputEvents.foreach {
+ event =>
+ event.getHeaders.get("test") should be("header")
+ }
+ val output = outputEvents.map(event => decoder.decode(event.getBody()).toString)
+ output should be (input)
}
}
- class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
+ /** Class to create socket channel with compression */
+ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
val encoder = new ZlibEncoder(compressionLevel)
pipeline.addFirst("deflater", encoder)
diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties
index 45d2ec676df66..4411d6e20c52a 100644
--- a/external/kafka/src/test/resources/log4j.properties
+++ b/external/kafka/src/test/resources/log4j.properties
@@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/external/mqtt/src/test/resources/log4j.properties b/external/mqtt/src/test/resources/log4j.properties
index 45d2ec676df66..4411d6e20c52a 100644
--- a/external/mqtt/src/test/resources/log4j.properties
+++ b/external/mqtt/src/test/resources/log4j.properties
@@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/external/twitter/src/test/resources/log4j.properties b/external/twitter/src/test/resources/log4j.properties
index 45d2ec676df66..4411d6e20c52a 100644
--- a/external/twitter/src/test/resources/log4j.properties
+++ b/external/twitter/src/test/resources/log4j.properties
@@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/external/zeromq/src/test/resources/log4j.properties b/external/zeromq/src/test/resources/log4j.properties
index 45d2ec676df66..4411d6e20c52a 100644
--- a/external/zeromq/src/test/resources/log4j.properties
+++ b/external/zeromq/src/test/resources/log4j.properties
@@ -22,7 +22,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties
index 180beaa8cc5a7..bb0ab319a0080 100644
--- a/extras/java8-tests/src/test/resources/log4j.properties
+++ b/extras/java8-tests/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties
index e01e049595475..d9d08f68687d3 100644
--- a/extras/kinesis-asl/src/test/resources/log4j.properties
+++ b/extras/kinesis-asl/src/test/resources/log4j.properties
@@ -20,7 +20,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties
index 26b73a1b39744..9dd05f17f012b 100644
--- a/graphx/src/test/resources/log4j.properties
+++ b/graphx/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/mllib/pom.xml b/mllib/pom.xml
index a5eeef88e9d62..696e9396f627c 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -57,7 +57,7 @@
<groupId>org.scalanlp</groupId>
<artifactId>breeze_${scala.binary.version}</artifactId>
- <version>0.9</version>
+ <version>0.10</version>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b7fe508f5120c..b478c21537c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
+import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -27,8 +28,11 @@ import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.feature.Word2Vec
+import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -42,9 +46,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
@@ -287,6 +291,59 @@ class PythonMLLibAPI extends Serializable {
ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
}
+ /**
+ * Java stub for Python mllib Word2Vec fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ * @param dataJRDD input JavaRDD
+ * @param vectorSize size of vector
+ * @param learningRate initial learning rate
+ * @param numPartitions number of partitions
+ * @param numIterations number of iterations
+ * @param seed initial seed for random generator
+ * @return a handle to the Java Word2VecModelWrapper instance used on the Python side
+ */
+ def trainWord2Vec(
+ dataJRDD: JavaRDD[java.util.ArrayList[String]],
+ vectorSize: Int,
+ learningRate: Double,
+ numPartitions: Int,
+ numIterations: Int,
+ seed: Long): Word2VecModelWrapper = {
+ val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)
+ val word2vec = new Word2Vec()
+ .setVectorSize(vectorSize)
+ .setLearningRate(learningRate)
+ .setNumPartitions(numPartitions)
+ .setNumIterations(numIterations)
+ .setSeed(seed)
+ val model = word2vec.fit(data)
+ data.unpersist()
+ new Word2VecModelWrapper(model)
+ }
+
+ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(words)
+ ret.add(similarity)
+ ret
+ }
+ }
+
/**
* Java stub for Python mllib DecisionTree.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
@@ -584,13 +641,24 @@ private[spark] object SerDe extends Serializable {
}
}
+ var initialized = false
+ // This should be called before trying to serialize any of the above classes
+ // In cluster mode, this call should be made inside the closure
def initialize(): Unit = {
- new DenseVectorPickler().register()
- new DenseMatrixPickler().register()
- new SparseVectorPickler().register()
- new LabeledPointPickler().register()
- new RatingPickler().register()
+ SerDeUtil.initialize()
+ synchronized {
+ if (!initialized) {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ new LabeledPointPickler().register()
+ new RatingPickler().register()
+ initialized = true
+ }
+ }
}
+ // Note: this will not be called automatically on executors
+ initialize()
def dumps(obj: AnyRef): Array[Byte] = {
new Pickler().dumps(obj)
@@ -610,4 +678,32 @@ private[spark] object SerDe extends Serializable {
rdd.map(x => Array(x._1, x._2))
}
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ initialize() // make sure SerDe is initialized on the executor
+ new PythonRDD.AutoBatchedPickler(iter)
+ }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize() // make sure SerDe is initialized on the executor
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 3afb47767281c..4734251127bb4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.feature
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm}
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -47,7 +47,7 @@ class Normalizer(p: Double) extends VectorTransformer {
* @return normalized vector. If the norm of the input is zero, it will return the input vector.
*/
override def transform(vector: Vector): Vector = {
- var norm = vector.toBreeze.norm(p)
+ var norm = brzNorm(vector.toBreeze, p)
if (norm != 0.0) {
// For dense vector, we've to allocate new memory for new output vector.
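The behaviour preserved by the brzNorm change above, as a small editorial sketch (the object name is made up): Normalizer(p) divides each vector by its p-norm.

import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

object NormalizerSketch {
  def main(args: Array[String]): Unit = {
    val l2 = new Normalizer(2.0)  // p = 2, the same norm brzNorm(v, 2) computes
    // ||(3, 4)||_2 = 5, so the result is [0.6, 0.8]
    println(l2.transform(Vectors.dense(3.0, 4.0)))
  }
}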
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index fc1444705364a..d321994c2a651 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -67,7 +67,7 @@ private case class VocabWord(
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
- private var startingAlpha = 0.025
+ private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
- this.startingAlpha = learningRate
+ this.learningRate = learningRate
this
}
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = startingAlpha
+ var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
lwc = wordCount
// TODO: discount by iteration?
alpha =
- startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
- if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+ learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
@@ -437,7 +437,7 @@ class Word2VecModel private[mllib] (
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
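The rename also makes the learning-rate schedule above easier to read: alpha decays linearly from `learningRate` as words are processed, and is floored at `learningRate * 0.0001`. Here is a small self-contained sketch with made-up numbers (`numPartitions` and `trainWordsCount` are illustrative, not taken from a real run).

```scala
object AlphaDecaySketch {
  def main(args: Array[String]): Unit = {
    val learningRate = 0.025
    val numPartitions = 4
    val trainWordsCount = 1000000L

    def alphaFor(wordCount: Long): Double = {
      val a = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
      if (a < learningRate * 0.0001) learningRate * 0.0001 else a
    }

    // alpha shrinks as more words are seen, but never drops below 0.0001 * learningRate.
    Seq(0L, 50000L, 200000L, 250000L).foreach { wc =>
      println(f"wordCount = $wc%7d  alpha = ${alphaFor(wc)}%.6f")
    }
  }
}
```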
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 4e87fe088ecc5..2cc52e94282ba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -85,7 +85,7 @@ sealed trait Matrix extends Serializable {
}
/**
- * Column-majored dense matrix.
+ * Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
* For example, the following matrix
* {{{
@@ -128,7 +128,7 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
}
/**
- * Column-majored sparse matrix.
+ * Column-major sparse matrix.
* The entry values are stored in Compressed Sparse Column (CSC) format.
* For example, the following matrix
* {{{
@@ -207,7 +207,7 @@ class SparseMatrix(
object Matrices {
/**
- * Creates a column-majored dense matrix.
+ * Creates a column-major dense matrix.
*
* @param numRows number of rows
* @param numCols number of columns
@@ -218,7 +218,7 @@ object Matrices {
}
/**
- * Creates a column-majored sparse matrix in Compressed Sparse Column (CSC) format.
+ * Creates a column-major sparse matrix in Compressed Sparse Column (CSC) format.
*
* @param numRows number of rows
* @param numCols number of columns
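Since the wording fix concerns storage order, a quick illustration of what column-major means for these matrices may help; the values are arbitrary.

```scala
import org.apache.spark.mllib.linalg.Matrices

object ColumnMajorSketch {
  def main(args: Array[String]): Unit = {
    // The 2x2 matrix
    //   1.0  2.0
    //   3.0  4.0
    // is passed column by column: first column (1.0, 3.0), then column (2.0, 4.0).
    val m = Matrices.dense(2, 2, Array(1.0, 3.0, 2.0, 4.0))
    println(m)                            // prints rows (1.0 2.0) and (3.0 4.0)
    println(m.toArray.mkString(", "))     // 1.0, 3.0, 2.0, 4.0 -- column-major storage
  }
}
```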
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 8380058cf9b41..ec2d481dccc22 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -111,7 +111,10 @@ class RowMatrix(
*/
def computeGramianMatrix(): Matrix = {
val n = numCols().toInt
- val nt: Int = n * (n + 1) / 2
+ checkNumColumns(n)
+ // Computes n*(n+1)/2, avoiding overflow in the multiplication.
+ // This succeeds when n <= 65535, which is checked above
+ val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))
// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
@@ -123,6 +126,16 @@ class RowMatrix(
RowMatrix.triuToFull(n, GU.data)
}
+ private def checkNumColumns(cols: Int): Unit = {
+ if (cols > 65535) {
+ throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols")
+ }
+ if (cols > 10000) {
+ val mem = cols * cols * 8
+ logWarning(s"$cols columns will require at least $mem bytes of memory!")
+ }
+ }
+
/**
* Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This
* will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k
@@ -301,12 +314,7 @@ class RowMatrix(
*/
def computeCovariance(): Matrix = {
val n = numCols().toInt
-
- if (n > 10000) {
- val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE
- logWarning(s"The number of columns $n is greater than 10000! " +
- s"We need at least $mem bytes of memory.")
- }
+ checkNumColumns(n)
val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
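A small standalone sketch of the two guards consolidated into `checkNumColumns` above: the 65535-column limit that keeps n * (n + 1) / 2 within an Int, and the memory warning past 10000 columns. The names and the Long widening in the memory estimate are this sketch's own choices.

```scala
object GramianSizeSketch {
  // Mirrors the intent of checkNumColumns: hard limit at 65535, warn above 10000.
  def checkNumColumns(cols: Int): Unit = {
    require(cols <= 65535, s"Argument with more than 65535 cols: $cols")
    if (cols > 10000) {
      val mem = cols.toLong * cols * 8   // widened to Long so the estimate cannot overflow
      println(s"$cols columns will require at least $mem bytes of memory!")
    }
  }

  // n * (n + 1) / 2 without overflowing Int: divide the even factor first.
  def upperTriangularSize(n: Int): Int =
    if (n % 2 == 0) (n / 2) * (n + 1) else n * ((n + 1) / 2)

  def main(args: Array[String]): Unit = {
    checkNumColumns(65535)
    println(upperTriangularSize(65535))   // 2147450880, still a valid Int
    // The naive form 65535 * 65536 / 2 would overflow before the division.
  }
}
```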
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index d0fe4179685ca..00dfc86c9e0bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
def predict(testData: Vector): Double = {
predictPoint(testData, weights, intercept)
}
+
+ override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b7dc373ebd9cc..03eeaa707715b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
-import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.SparkContext._
/**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
// Use all features
agg.metadata.numFeatures
}
- val nodeOffset = agg.getNodeOffset(nodeIndex)
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
- agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
- agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
}
splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
- instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
* For each feature, the sufficient statistics of one bin are updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
- val nodeOffset = agg.getNodeOffset(nodeIndex)
+
// Iterate over features.
if (featuresForNode.nonEmpty) {
// Use subsampled features
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, instanceWeight)
featureIndexIdx += 1
}
} else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
* @return agg
*/
def binSeqOp(
- agg: DTStatsAggregator,
- baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
- mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
@@ -518,46 +512,113 @@ object DecisionTree extends Serializable with Logging {
agg
}
- // Calculate bin aggregates.
- timer.start("aggregation")
- val binAggregates: DTStatsAggregator = {
- val initAgg = if (metadata.subsamplingFeatures) {
- new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
- } else {
- new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+ /**
+ * Get the map: node index in group --> feature indices.
+ * This is a shortcut for looking up the subsampled feature indices of a node,
+ * given its index within the group of nodes being split.
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo
+ * @return Map: node index in group --> feature indices, or None if features are not subsampled
+ */
+ def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
+ : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
+ }
+ Some(mutableNodeToFeatures.toMap)
+ }
+
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[Node](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
}
- input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
}
- timer.stop("aggregation")
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
+ // In each partition, iterate over all instances and compute aggregate stats for each node,
+ // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // the stats of each node are shuffled to a particular partition and combined there,
+ // and the best split for each node is found on that partition.
+ // Finally, only the best splits are collected to the driver to construct the decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+ val nodeToBestSplits =
+ input.mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats;
+ // each node in the group has its own DTStatsAggregator.
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterate over all instances in the current partition and update the aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+ // transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with those from other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }.reduceByKey((a, b) => a.merge(b))
+ .map { case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
+ (nodeIndex, (split, stats, predict))
+ }.collectAsMap()
+
+ timer.stop("chooseSplits")
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+ nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
assert(node.id == nodeIndex)
- node.predict = predict.predict
+ node.predict = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
+ node.impurity = stats.impurity
logDebug("Node = " + node)
if (!isLeaf) {
node.split = Some(split)
- node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
- node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
- nodeQueue.enqueue((treeIndex, node.leftNode.get))
- nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+ val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+ stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+ stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.leftNode.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ }
+
logDebug("leftChildIndex = " + node.leftNode.get.id +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -565,7 +626,7 @@ object DecisionTree extends Serializable with Logging {
}
}
}
- timer.stop("chooseSplits")
+
}
/**
@@ -577,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata): InformationGainStats = {
+ metadata: DecisionTreeMetadata,
+ impurity: Double): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
@@ -590,11 +652,6 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftCount + rightCount
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
-
- val impurity = parentNodeAgg.calculate()
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -609,7 +666,18 @@ object DecisionTree extends Serializable with Logging {
return InformationGainStats.invalidInformationGainStats
}
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+ // calculate left and right predict
+ val leftPredict = calculatePredict(leftImpurityCalculator)
+ val rightPredict = calculatePredict(rightImpurityCalculator)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+ leftPredict, rightPredict)
+ }
+
+ private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+ val predict = impurityCalculator.predict
+ val prob = impurityCalculator.prob(predict)
+ new Predict(predict, prob)
}
/**
@@ -617,52 +685,55 @@ object DecisionTree extends Serializable with Logging {
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
- * @return predict value for current node
+ * @return predict value and impurity for current node
*/
- private def calculatePredict(
+ private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): Predict = {
+ rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
- val predict = parentNodeAgg.predict
- val prob = parentNodeAgg.prob(predict)
+ val predict = calculatePredict(parentNodeAgg)
+ val impurity = parentNodeAgg.calculate()
- new Predict(predict, prob)
+ (predict, impurity)
}
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
- * @param nodeIndex Index into aggregates for node to split in this group.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
- nodeIndex: Int,
splits: Array[Array[Split]],
- featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
-
- val metadata: DecisionTreeMetadata = binAggregates.metadata
+ featuresForNode: Option[Array[Int]],
+ node: Node): (Split, InformationGainStats, Predict) = {
- // calculate predict only once
- var predict: Option[Predict] = None
+ // calculate predict and impurity if current node is top node
+ val level = Node.indexToLevel(node.id)
+ var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+ None
+ } else {
+ Some((node.predict, node.impurity))
+ }
// For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ val (bestSplit, bestSplitStats) =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
- val numSplits = metadata.numSplits(featureIndex)
- if (metadata.isContinuous(featureIndex)) {
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
// Find best split.
@@ -671,28 +742,32 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (metadata.isUnordered(featureIndex)) {
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
- val numBins = metadata.numBins(featureIndex)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +775,7 @@ object DecisionTree extends Serializable with Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
- val centroidForCategories = if (metadata.isMulticlass) {
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
@@ -741,7 +816,7 @@ object DecisionTree extends Serializable with Logging {
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
@@ -755,8 +830,10 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
@@ -767,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- assert(predict.isDefined, "must calculate predict for each node")
-
- (bestSplit, bestSplitStats, predict.get)
+ (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}
/**
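To make the new control flow easier to follow, here is a stripped-down sketch (not MLlib's actual API) of the aggregation pattern `findBestSplits` now uses: one mutable aggregator per node is built inside each partition, (nodeIndex, aggregator) pairs are merged with `reduceByKey`, the per-node decision is made where the merged stats live, and only the small results are collected to the driver. `Agg`, `nodeOf`, and `bestSplitOf` are placeholders invented for this sketch; the `SparkContext._` import is what enables `reduceByKey` in this Spark version, as in the patch itself.

```scala
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

// Toy per-node sufficient statistics: stats(0) = sum of values, stats(1) = count.
class Agg(val stats: Array[Double]) extends Serializable {
  def update(p: Double): Agg = { stats(0) += p; stats(1) += 1; this }
  def merge(other: Agg): Agg = {
    var i = 0
    while (i < stats.length) { stats(i) += other.stats(i); i += 1 }
    this
  }
}

object PerNodeAggregationSketch {
  def bestPerNode(
      points: RDD[Double],
      numNodes: Int,
      nodeOf: Double => Int,
      bestSplitOf: Agg => Double): Map[Int, Double] = {
    points.mapPartitions { iter =>
      // One aggregator per node in the group, filled from this partition's points only.
      val aggs = Array.fill(numNodes)(new Agg(new Array[Double](2)))
      iter.foreach(p => aggs(nodeOf(p)).update(p))
      aggs.view.zipWithIndex.map(_.swap).iterator        // (nodeIndex, partial stats)
    }.reduceByKey((a, b) => a.merge(b))                  // combine partials of the same node
      .mapValues(bestSplitOf)                            // decide where the merged stats live
      .collectAsMap().toMap                              // only tiny results reach the driver
  }
}
```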
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 7fa7725e79e46..ebbd8e0257209 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -171,11 +171,13 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
- DecisionTree.findBestSplits(baggedInput,
- metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
timer.stop("findBestSplits")
}
+ baggedInput.unpersist()
+
timer.stop("total")
logInfo("Internal timing for DecisionTree:")
@@ -382,6 +384,7 @@ object RandomForest extends Serializable with Logging {
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
* treeToNodeToIndexInfo holds indices selected features for each node:
* treeIndex --> (global) node index --> (node index in group, feature indices).
* The (global) node index is the index in the tree; the node index in group is the
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d49df7a016375..ce8825cc03229 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,17 +17,19 @@
package org.apache.spark.mllib.tree.impl
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.impurity._
+
+
/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
* and helps with indexing.
* This class is abstract to support learning with and without feature subsampling.
*/
-private[tree] abstract class DTStatsAggregator(
- val metadata: DecisionTreeMetadata) extends Serializable {
+private[tree] class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,116 +44,108 @@ private[tree] abstract class DTStatsAggregator(
/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
- val statsSize: Int = impurityAggregator.statsSize
+ private val statsSize: Int = impurityAggregator.statsSize
/**
- * Indicator for each feature of whether that feature is an unordered feature.
- * TODO: Is Array[Boolean] any faster?
+ * Number of bins for each feature. This is indexed by the feature index.
*/
- def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
+ private val numBins: Array[Int] = {
+ if (featureSubset.isDefined) {
+ featureSubset.get.map(metadata.numBins(_))
+ } else {
+ metadata.numBins
+ }
+ }
/**
- * Total number of elements stored in this aggregator.
+ * Offset for each feature for calculating indices into the [[allStats]] array.
*/
- def allStatsSize: Int
+ private val featureOffsets: Array[Int] = {
+ numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+ }
/**
- * Get flat array of elements stored in this aggregator.
+ * Total number of elements stored in this aggregator
*/
- protected def allStats: Array[Double]
+ private val allStatsSize: Int = featureOffsets.last
+
+ /**
+ * Flat array of elements.
+ * Index for start of stats for a (feature, bin) is:
+ * index = featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features,
+ * the left child stats have binIndex in [0, numBins(featureIndex) / 2)
+ * and the right child stats have binIndex in [numBins(featureIndex) / 2, numBins(featureIndex))
+ */
+ private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * [[getLeftRightFeatureOffsets]].
*/
- def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
- impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
- def update(
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+ val i = featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
- /**
- * Pre-compute node offset for use with [[nodeUpdate]].
- */
- def getNodeOffset(nodeIndex: Int): Int
-
/**
* Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ * Update the stats for a given (feature, bin), using the given label.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
*/
- def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
+ def featureUpdate(
+ featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit
+ instanceWeight: Double): Unit = {
+ impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+ label, instanceWeight)
+ }
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
- def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
+ def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
- def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
- require(isUnordered(featureIndex),
- s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
- s" but was called for ordered feature $featureIndex.")
- val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
- (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
- }
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin), using the given label.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
- */
- def nodeFeatureUpdate(
- nodeFeatureOffset: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
- instanceWeight)
+ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
+ val baseOffset = featureOffsets(featureIndex)
+ (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
/**
- * For a given (node, feature), merge the stats for two bins.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+ * For a given feature, merge the stats for two bins.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
- def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
- impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
- nodeFeatureOffset + otherBinIndex * statsSize)
+ def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+ featureOffset + otherBinIndex * statsSize)
}
/**
@@ -161,7 +155,7 @@ private[tree] abstract class DTStatsAggregator(
def merge(other: DTStatsAggregator): DTStatsAggregator = {
require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
- + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
@@ -171,149 +165,3 @@ private[tree] abstract class DTStatsAggregator(
this
}
}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when not subsampling features.
- *
- * @param numNodes Number of nodes to collect statistics for.
- */
-private[tree] class DTStatsAggregatorFixedFeatures(
- metadata: DecisionTreeMetadata,
- numNodes: Int) extends DTStatsAggregator(metadata) {
-
- /**
- * Offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: featureIndex --> offset
- */
- private val featureOffsets: Array[Int] = {
- metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
-
- /**
- * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
- */
- private val nodeStride: Int = featureOffsets.last
-
- override val allStatsSize: Int = numNodes * nodeStride
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
-
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeIndex * nodeStride + featureOffsets(featureIndex)
- }
-}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when subsampling features.
- *
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- * where nodeIndexInfo stores the index in the group and the
- * feature subsets (if using feature subsets).
- */
-private[tree] class DTStatsAggregatorSubsampledFeatures(
- metadata: DecisionTreeMetadata,
- treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: nodeIndex --> featureIndex --> offset
- */
- private val featureOffsets: Array[Array[Int]] = {
- val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
- val offsets = new Array[Array[Int]](numNodes)
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
- nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
- offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
- .scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
- }
- offsets
- }
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- */
- protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
-
- override val allStatsSize: Int = nodeOffsets.last
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- /**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
- * For ordered features only.
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
- }
-}
-
-private[tree] object DTStatsAggregator extends Serializable {
-
- /**
- * Combines two aggregates (modifying the first) and returns the combination.
- */
- def binCombOp(
- agg1: DTStatsAggregator,
- agg2: DTStatsAggregator): DTStatsAggregator = {
- agg1.merge(agg2)
- }
-
-}
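A worked sketch of the flat indexing the per-node aggregator above relies on: each feature's block starts at a prefix-sum offset of `statsSize * numBins`, and a (feature, bin) slot starts at `featureOffsets(featureIndex) + binIndex * statsSize`. The numbers below are made up for illustration.

```scala
object FlatStatsIndexSketch {
  def main(args: Array[String]): Unit = {
    val statsSize = 3                      // e.g. per-class counts for 3 classes
    val numBins = Array(4, 2, 5)           // bins per feature

    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    // featureOffsets = Array(0, 12, 18, 33); total array size = featureOffsets.last = 33

    def slot(featureIndex: Int, binIndex: Int): Int =
      featureOffsets(featureIndex) + binIndex * statsSize

    println(featureOffsets.mkString(", "))  // 0, 12, 18, 33
    println(slot(1, 1))                     // 15: feature 1 starts at 12, bin 1 adds statsSize
    println(slot(2, 4))                     // 30: last bin of the last feature ends at 33
  }
}
```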
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 212dce25236e0..772c02670e541 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl
import scala.collection.mutable
+import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata(
}
-private[tree] object DecisionTreeMetadata {
+private[tree] object DecisionTreeMetadata extends Logging {
/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
@@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata {
}
val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+ if (maxPossibleBins < strategy.maxBins) {
+ logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+ s" (= number of training instances)")
+ }
// We check the number of bins here against maxPossibleBins.
// This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 271b2c4ad813e..ec1d99ab26f9c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -68,15 +68,23 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
}
/**
- * Print full model.
+ * Print a summary of the model.
*/
override def toString: String = algo match {
case Classification =>
- s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel classifier of depth $depth with $numNodes nodes"
case Regression =>
- s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
+ s"DecisionTreeModel regressor of depth $depth with $numNodes nodes"
case _ => throw new IllegalArgumentException(
s"DecisionTreeModel given unknown algo parameter: $algo.")
}
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + topNode.subtreeToString(2)
+ }
+
}
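A hedged usage sketch of the new split between the two methods: `toString` now gives a one-line summary while the full if/else tree moves to `toDebugString`. The training call and its parameter values are just an assumed setup for illustration.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.rdd.RDD

object ModelPrintingSketch {
  def describe(data: RDD[LabeledPoint]): Unit = {
    // Assumed setup: binary classification, Gini impurity, no categorical features.
    val model = DecisionTree.trainClassifier(data, 2, Map[Int, Int](), "gini", 5, 32)
    println(model)                 // e.g. "DecisionTreeModel classifier of depth 5 with 63 nodes"
    println(model.toDebugString)   // the summary line followed by the full tree
  }
}
```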
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f3e2619bd8ba0..9a50ecb550c38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,18 +26,33 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
- val rightImpurity: Double) extends Serializable {
+ val rightImpurity: Double,
+ val leftPredict: Predict,
+ val rightPredict: Predict) extends Serializable {
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
+
+ override def equals(o: Any) =
+ o match {
+ case other: InformationGainStats => {
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity
+ }
+ case _ => false
+ }
}
@@ -47,5 +62,6 @@ private[tree] object InformationGainStats {
* denotes that the current split doesn't satisfy the minimum info gain or
* minimum number of instances per node.
*/
- val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+ new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 56c3e25d9285f..2179da8dbe03e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
*
* @param id integer node id, from 1
* @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
@@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector
@DeveloperApi
class Node (
val id: Int,
- var predict: Double,
+ var predict: Predict,
+ var impurity: Double,
var isLeaf: Boolean,
var split: Option[Split],
var leftNode: Option[Node],
@@ -49,7 +51,7 @@ class Node (
var stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "split = " + split + ", stats = " + stats
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
/**
* build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
+ logDebug("impurity = " + impurity)
if (!isLeaf) {
leftNode = Some(nodes(Node.leftChildIndex(id)))
rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
*/
def predict(features: Vector) : Double = {
if (isLeaf) {
- predict
+ predict.predict
} else{
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
} else {
Some(rightNode.get.deepCopy())
}
- new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+ new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}
/**
@@ -154,7 +157,7 @@ class Node (
}
val prefix: String = " " * indentFactor
if (isLeaf) {
- prefix + s"Predict: $predict\n"
+ prefix + s"Predict: ${predict.predict}\n"
} else {
prefix + s"If ${splitToString(split.get, left=true)}\n" +
leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
/**
* Return a node with the given node id (but nothing else set).
*/
- def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+ def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+ false, None, None, None, None)
+
+ /**
+ * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+ * This is used in `DecisionTree.findBestSplits` to construct child nodes
+ * after finding the best splits for parent nodes.
+ * Other fields are set at next level.
+ * @param nodeIndex integer node id, from 1
+ * @param predict predicted value at the node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
+ * @return new node instance
+ */
+ def apply(
+ nodeIndex: Int,
+ predict: Predict,
+ impurity: Double,
+ isLeaf: Boolean): Node = {
+ new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+ }
/**
* Return the index of the left child of this node.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index d8476b5cd7bc7..004838ee5ba0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,12 +17,15 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
+
/**
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-private[tree] class Predict(
+@DeveloperApi
+class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
index 538c0e233202a..6a22e2abe59bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -73,17 +73,27 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
def numTrees: Int = trees.size
/**
- * Print full model.
+ * Get total number of nodes, summed over all trees in the forest.
*/
- override def toString: String = {
- val header = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees\n"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees\n"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
+ def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = algo match {
+ case Classification =>
+ s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
+ case Regression =>
+ s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
+ case _ => throw new IllegalArgumentException(
+ s"RandomForestModel given unknown algo parameter: $algo.")
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
header + trees.zipWithIndex.map { case (tree, treeIndex) =>
s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
}.fold("")(_ + _)
diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties
index ddfc4ac6b23ed..a469badf603c6 100644
--- a/mllib/src/test/resources/log4j.properties
+++ b/mllib/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
index fb76dccfdf79e..2bf9d9816ae45 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature
import org.scalatest.FunSuite
+import breeze.linalg.{norm => brzNorm}
+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext {
assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
- assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5)
- assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5)
- assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5)
- assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5)
assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5)
assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
@@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext {
assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
- assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5)
- assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5)
- assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5)
- assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5)
+ assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5)
assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5)
assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a48ed71a1c5fc..98a72b0c4d750 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
assert(stats.impurity > 0.2)
}
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = rootNode.stats.get
assert(stats.gain > 0)
- assert(rootNode.predict === 0.6)
+ assert(rootNode.predict.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Binary classification stump with fixed label 0 for Entropy") {
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 0)
+ assert(rootNode.predict.predict === 0)
}
test("Binary classification stump with fixed label 1 for Entropy") {
@@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.gain === 0)
assert(stats.leftImpurity === 0)
assert(stats.rightImpurity === 0)
- assert(rootNode.predict === 1)
+ assert(rootNode.predict.predict === 1)
}
test("Second level node building with vs. without groups") {
@@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats1.impurity === stats2.impurity)
assert(stats1.leftImpurity === stats2.leftImpurity)
assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict === children2(i).predict)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
}
}
@@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = rdd.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
- assert(model.topNode.predict == 0.0)
+ assert(model.topNode.predict.predict == 0.0)
val predicts = input.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
@@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
}
+
+ test("Avoid aggregation on the last level") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
}
object DecisionTreeSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 30669fcd1c75b..fb44ceb0f57ee 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+
if (numFeaturesPerNode == numFeatures) {
// featureSubset values should all be None
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
@@ -172,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
+ test("alternating categorical and continuous features with multiclass labels to test indexing") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
+ arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0))
+ arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+ val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4)
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
+ featureSubsetStrategy = "sqrt", seed = 12345)
+ RandomForestSuite.validateClassifier(model, arr, 1.0)
+ }
+
}
object RandomForestSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 8ef2bb1bf6a78..0dbe766b4d917 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
@@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03))
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "output")
MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString)
val lines = outputDir.listFiles()
@@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Vectors.sparse(2, Array(1), Array(-1.0)),
Vectors.dense(0.0, 1.0)
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "vectors")
val path = outputDir.toURI.toString
vectors.saveAsTextFile(path)
@@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))),
LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "points")
val path = outputDir.toURI.toString
points.saveAsTextFile(path)
diff --git a/pom.xml b/pom.xml
index 70cb9729ff6d3..288bbf1114bea 100644
--- a/pom.xml
+++ b/pom.xml
@@ -118,7 +118,7 @@
<mesos.version>0.18.1</mesos.version>
<mesos.classifier>shaded-protobuf</mesos.classifier>
<akka.group>org.spark-project.akka</akka.group>
- <akka.version>2.2.3-shaded-protobuf</akka.version>
+ <akka.version>2.3.4-spark</akka.version>
<slf4j.version>1.7.5</slf4j.version>
<log4j.version>1.2.17</log4j.version>
<hadoop.version>1.0.4</hadoop.version>
@@ -127,7 +127,7 @@
<hbase.version>0.94.6</hbase.version>
<flume.version>1.4.0</flume.version>
<zookeeper.version>3.4.5</zookeeper.version>
- <hive.version>0.12.0</hive.version>
+ <hive.version>0.12.0-protobuf-2.5</hive.version>
<parquet.version>1.4.3</parquet.version>
<jblas.version>1.2.3</jblas.version>
<jetty.version>8.1.14.v20131031</jetty.version>
@@ -138,6 +138,7 @@
<jets3t.version>0.7.1</jets3t.version>
<aws.java.sdk.version>1.8.3</aws.java.sdk.version>
<aws.kinesis.client.version>1.1.0</aws.kinesis.client.version>
+ <commons.httpclient.version>4.2.6</commons.httpclient.version>
<PermGen>64m</PermGen>
<MaxPermGen>512m</MaxPermGen>
@@ -222,6 +223,18 @@
false
+
+ <repository>
+ <id>spark-staging</id>
+ <name>Spring Staging Repository</name>
+ <url>https://oss.sonatype.org/content/repositories/orgspark-project-1085</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 39f8ba4745737..d919b18e09855 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -32,7 +32,7 @@ object MimaBuild {
ProblemFilters.exclude[MissingMethodProblem](fullName),
// Sometimes excluded methods have default arguments and
// they are translated into public methods/fields($default$) in generated
- // bytecode. It is not possible to exhustively list everything.
+ // bytecode. It is not possible to exhaustively list everything.
// But this should be okay.
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"),
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4076ebc6fc8d5..c58666af84f24 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -41,6 +41,8 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
Seq(
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.scheduler.TaskLocation"),
// Added normL1 and normL2 to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),
@@ -48,7 +50,22 @@ object MimaExcludes {
"org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"),
// MapStatus should be private[spark]
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
- "org.apache.spark.scheduler.MapStatus")
+ "org.apache.spark.scheduler.MapStatus"),
+ // TaskContext was promoted to Abstract class
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.TaskContext")
+ ) ++ Seq(
+ // Adding new methods to the JavaRDDLike trait:
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.takeAsync"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.countAsync"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.collectAsync")
)
case v if v.startsWith("1.1") =>
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 8096c61414660..678f5ed1ba610 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -17,7 +17,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4")
-addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0")
+addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala
deleted file mode 100644
index 80d3faa3fe749..0000000000000
--- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.spark.scalastyle
-
-import java.util.regex.Pattern
-
-import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError}
-import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token}
-import scalariform.parser.CompilationUnit
-
-class SparkSpaceAfterCommentStartChecker extends ScalariformChecker {
- val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end"
-
- private def multiLineCommentRegex(comment: Token) =
- Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() ||
- Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches()
-
- private def scalaDocPatternRegex(comment: Token) =
- Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() ||
- Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches()
-
- private def singleLineCommentRegex(comment: Token): Boolean =
- comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""")
-
- override def verify(ast: CompilationUnit): List[ScalastyleError] = {
- ast.tokens
- .filter(hasComment)
- .map {
- _.associatedWhitespaceAndComments.comments.map {
- case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset)
- case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset)
- case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset)
- case _ => None
- }.flatten
- }.flatten.map(PositionError(_))
- }
-
-
- private def hasComment(x: Token) =
- x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty
-
-}
diff --git a/python/.gitignore b/python/.gitignore
index 80b361ffbd51c..52128cf844a79 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -1,5 +1,5 @@
*.pyc
-docs/
+docs/_build/
pyspark.egg-info
build/
dist/
diff --git a/python/docs/conf.py b/python/docs/conf.py
index c368cf81a003b..e58d97ae6a746 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -55,9 +55,9 @@
# built documents.
#
# The short X.Y version.
-version = '1.1'
+version = '1.2-SNAPSHOT'
# The full version, including alpha/beta/rc tags.
-release = ''
+release = '1.2-SNAPSHOT'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -102,7 +102,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'default'
+html_theme = 'nature'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -121,7 +121,7 @@
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+html_logo = "../../docs/img/spark-logo-hd.png"
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
@@ -131,7 +131,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+#html_static_path = ['_static']
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
@@ -154,10 +154,10 @@
#html_additional_pages = {}
# If false, no module index is generated.
-#html_domain_indices = True
+html_domain_indices = False
# If false, no index is generated.
-#html_use_index = True
+html_use_index = False
# If true, the index is split into individual pages for each letter.
#html_split_index = False
diff --git a/python/docs/epytext.py b/python/docs/epytext.py
index 61d731bff570d..19fefbfc057a4 100644
--- a/python/docs/epytext.py
+++ b/python/docs/epytext.py
@@ -5,7 +5,7 @@
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
- (r"[IBCM]{(.+)}", r"`\1`"),
+ (r"[IBCM]{([^}]+)}", r"`\1`"),
('pyspark.rdd.RDD', 'RDD'),
)
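
The tightened pattern above matters when a line contains more than one epytext marker. A minimal standalone sketch of the difference (illustrative only, not part of the patch):

import re

text = "Use C{foo} and C{bar} here."
# Greedy '.+' runs past the first closing brace and swallows both markers:
print re.sub(r"[IBCM]{(.+)}", r"`\1`", text)     # Use `foo} and C{bar` here.
# '[^}]+' stops at the first closing brace, so each marker converts separately:
print re.sub(r"[IBCM]{([^}]+)}", r"`\1`", text)  # Use `foo` and `bar` here.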
diff --git a/python/docs/index.rst b/python/docs/index.rst
index 25b3f9bd93e63..703bef644de28 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -3,7 +3,7 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
-Welcome to PySpark API reference!
+Welcome to Spark Python API Docs!
===================================
Contents:
@@ -13,6 +13,7 @@ Contents:
pyspark
pyspark.sql
+ pyspark.streaming
pyspark.mllib
@@ -24,14 +25,12 @@ Core classes:
Main entry point for Spark functionality.
:class:`pyspark.RDD`
-
+
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
Indices and tables
==================
-* :ref:`genindex`
-* :ref:`modindex`
* :ref:`search`
diff --git a/python/docs/make.bat b/python/docs/make.bat
index adad44fd7536a..c011e82b4a35a 100644
--- a/python/docs/make.bat
+++ b/python/docs/make.bat
@@ -1,242 +1,6 @@
@ECHO OFF
-REM Command file for Sphinx documentation
+rem This is the entry point for running Sphinx documentation. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set BUILDDIR=_build
-set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
-set I18NSPHINXOPTS=%SPHINXOPTS% .
-if NOT "%PAPER%" == "" (
- set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
- set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
-)
-
-if "%1" == "" goto help
-
-if "%1" == "help" (
- :help
- echo.Please use `make ^<target^>` where ^<target^> is one of
- echo. html to make standalone HTML files
- echo. dirhtml to make HTML files named index.html in directories
- echo. singlehtml to make a single large HTML file
- echo. pickle to make pickle files
- echo. json to make JSON files
- echo. htmlhelp to make HTML files and a HTML help project
- echo. qthelp to make HTML files and a qthelp project
- echo. devhelp to make HTML files and a Devhelp project
- echo. epub to make an epub
- echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
- echo. text to make text files
- echo. man to make manual pages
- echo. texinfo to make Texinfo files
- echo. gettext to make PO message catalogs
- echo. changes to make an overview over all changed/added/deprecated items
- echo. xml to make Docutils-native XML files
- echo. pseudoxml to make pseudoxml-XML files for display purposes
- echo. linkcheck to check all external links for integrity
- echo. doctest to run all doctests embedded in the documentation if enabled
- goto end
-)
-
-if "%1" == "clean" (
- for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
- del /q /s %BUILDDIR%\*
- goto end
-)
-
-
-%SPHINXBUILD% 2> nul
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.If you don't have Sphinx installed, grab it from
- echo.http://sphinx-doc.org/
- exit /b 1
-)
-
-if "%1" == "html" (
- %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/html.
- goto end
-)
-
-if "%1" == "dirhtml" (
- %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
- goto end
-)
-
-if "%1" == "singlehtml" (
- %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
- goto end
-)
-
-if "%1" == "pickle" (
- %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the pickle files.
- goto end
-)
-
-if "%1" == "json" (
- %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the JSON files.
- goto end
-)
-
-if "%1" == "htmlhelp" (
- %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run HTML Help Workshop with the ^
-.hhp project file in %BUILDDIR%/htmlhelp.
- goto end
-)
-
-if "%1" == "qthelp" (
- %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run "qcollectiongenerator" with the ^
-.qhcp project file in %BUILDDIR%/qthelp, like this:
- echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp
- echo.To view the help file:
- echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc
- goto end
-)
-
-if "%1" == "devhelp" (
- %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished.
- goto end
-)
-
-if "%1" == "epub" (
- %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The epub file is in %BUILDDIR%/epub.
- goto end
-)
-
-if "%1" == "latex" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdf" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdfja" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf-ja
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "text" (
- %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The text files are in %BUILDDIR%/text.
- goto end
-)
-
-if "%1" == "man" (
- %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The manual pages are in %BUILDDIR%/man.
- goto end
-)
-
-if "%1" == "texinfo" (
- %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
- goto end
-)
-
-if "%1" == "gettext" (
- %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
- goto end
-)
-
-if "%1" == "changes" (
- %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
- if errorlevel 1 exit /b 1
- echo.
- echo.The overview file is in %BUILDDIR%/changes.
- goto end
-)
-
-if "%1" == "linkcheck" (
- %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
- if errorlevel 1 exit /b 1
- echo.
- echo.Link check complete; look for any errors in the above output ^
-or in %BUILDDIR%/linkcheck/output.txt.
- goto end
-)
-
-if "%1" == "doctest" (
- %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
- if errorlevel 1 exit /b 1
- echo.
- echo.Testing of doctests in the sources finished, look at the ^
-results in %BUILDDIR%/doctest/output.txt.
- goto end
-)
-
-if "%1" == "xml" (
- %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The XML files are in %BUILDDIR%/xml.
- goto end
-)
-
-if "%1" == "pseudoxml" (
- %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
- goto end
-)
-
-:end
+cmd /V /E /C %~dp0make2.bat %*
diff --git a/python/docs/make2.bat b/python/docs/make2.bat
new file mode 100644
index 0000000000000..7bcaeafad13d7
--- /dev/null
+++ b/python/docs/make2.bat
@@ -0,0 +1,243 @@
+@ECHO OFF
+
+REM Command file for Sphinx documentation
+
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set BUILDDIR=_build
+set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
+set I18NSPHINXOPTS=%SPHINXOPTS% .
+if NOT "%PAPER%" == "" (
+ set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
+ set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
+)
+
+if "%1" == "" goto help
+
+if "%1" == "help" (
+ :help
+ echo.Please use `make ^<target^>` where ^<target^> is one of
+ echo. html to make standalone HTML files
+ echo. dirhtml to make HTML files named index.html in directories
+ echo. singlehtml to make a single large HTML file
+ echo. pickle to make pickle files
+ echo. json to make JSON files
+ echo. htmlhelp to make HTML files and a HTML help project
+ echo. qthelp to make HTML files and a qthelp project
+ echo. devhelp to make HTML files and a Devhelp project
+ echo. epub to make an epub
+ echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
+ echo. text to make text files
+ echo. man to make manual pages
+ echo. texinfo to make Texinfo files
+ echo. gettext to make PO message catalogs
+ echo. changes to make an overview over all changed/added/deprecated items
+ echo. xml to make Docutils-native XML files
+ echo. pseudoxml to make pseudoxml-XML files for display purposes
+ echo. linkcheck to check all external links for integrity
+ echo. doctest to run all doctests embedded in the documentation if enabled
+ goto end
+)
+
+if "%1" == "clean" (
+ for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
+ del /q /s %BUILDDIR%\*
+ goto end
+)
+
+
+%SPHINXBUILD% 2> nul
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "html" (
+ %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/html.
+ goto end
+)
+
+if "%1" == "dirhtml" (
+ %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
+ goto end
+)
+
+if "%1" == "singlehtml" (
+ %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
+ goto end
+)
+
+if "%1" == "pickle" (
+ %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the pickle files.
+ goto end
+)
+
+if "%1" == "json" (
+ %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the JSON files.
+ goto end
+)
+
+if "%1" == "htmlhelp" (
+ %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run HTML Help Workshop with the ^
+.hhp project file in %BUILDDIR%/htmlhelp.
+ goto end
+)
+
+if "%1" == "qthelp" (
+ %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run "qcollectiongenerator" with the ^
+.qhcp project file in %BUILDDIR%/qthelp, like this:
+ echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp
+ echo.To view the help file:
+ echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc
+ goto end
+)
+
+if "%1" == "devhelp" (
+ %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished.
+ goto end
+)
+
+if "%1" == "epub" (
+ %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The epub file is in %BUILDDIR%/epub.
+ goto end
+)
+
+if "%1" == "latex" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdf" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf
+ cd %BUILDDIR%/..
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdfja" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf-ja
+ cd %BUILDDIR%/..
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "text" (
+ %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The text files are in %BUILDDIR%/text.
+ goto end
+)
+
+if "%1" == "man" (
+ %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The manual pages are in %BUILDDIR%/man.
+ goto end
+)
+
+if "%1" == "texinfo" (
+ %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
+ goto end
+)
+
+if "%1" == "gettext" (
+ %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
+ goto end
+)
+
+if "%1" == "changes" (
+ %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.The overview file is in %BUILDDIR%/changes.
+ goto end
+)
+
+if "%1" == "linkcheck" (
+ %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Link check complete; look for any errors in the above output ^
+or in %BUILDDIR%/linkcheck/output.txt.
+ goto end
+)
+
+if "%1" == "doctest" (
+ %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Testing of doctests in the sources finished, look at the ^
+results in %BUILDDIR%/doctest/output.txt.
+ goto end
+)
+
+if "%1" == "xml" (
+ %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The XML files are in %BUILDDIR%/xml.
+ goto end
+)
+
+if "%1" == "pseudoxml" (
+ %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
+ goto end
+)
+
+:end
diff --git a/python/docs/modules.rst b/python/docs/modules.rst
deleted file mode 100644
index 183564659fbcf..0000000000000
--- a/python/docs/modules.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-.
-=
-
-.. toctree::
- :maxdepth: 4
-
- pyspark
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index e95d19e97f151..4548b8739ed91 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -20,6 +20,14 @@ pyspark.mllib.clustering module
:undoc-members:
:show-inheritance:
+pyspark.mllib.feature module
+-------------------------------
+
+.. automodule:: pyspark.mllib.feature
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
pyspark.mllib.linalg module
---------------------------
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index a68bd62433085..e81be3b6cb796 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -7,8 +7,9 @@ Subpackages
.. toctree::
:maxdepth: 1
- pyspark.mllib
pyspark.sql
+ pyspark.streaming
+ pyspark.mllib
Contents
--------
diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst
new file mode 100644
index 0000000000000..5024d694b668f
--- /dev/null
+++ b/python/docs/pyspark.streaming.rst
@@ -0,0 +1,10 @@
+pyspark.streaming module
+========================
+
+Module contents
+---------------
+
+.. automodule:: pyspark.streaming
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 1a2e774738fe7..e39e6514d77a1 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -20,33 +20,21 @@
Public classes:
- - L{SparkContext}
+ - :class:`SparkContext`:
Main entry point for Spark functionality.
- - L{RDD}
+ - L{RDD}
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- - L{Broadcast}
+ - L{Broadcast}
A broadcast variable that gets reused across tasks.
- - L{Accumulator}
+ - L{Accumulator}
An "add-only" shared variable that tasks can only add values to.
- - L{SparkConf}
+ - L{SparkConf}
For configuring Spark.
- - L{SparkFiles}
+ - L{SparkFiles}
Access files shipped with jobs.
- - L{StorageLevel}
+ - L{StorageLevel}
Finer-grained cache persistence levels.
-Spark SQL:
- - L{SQLContext}
- Main entry point for SQL functionality.
- - L{SchemaRDD}
- A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
- - L{Row}
- A Row of data returned by a Spark SQL query.
-
-Hive:
- - L{HiveContext}
- Main entry point for accessing data stored in Apache Hive..
"""
# The following block allows us to import python's random instead of mllib.random for scripts in
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index ccbca67656c8d..b8cdbbe3cf2b6 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -215,6 +215,21 @@ def addInPlace(self, value1, value2):
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
+class PStatsParam(AccumulatorParam):
+ """PStatsParam is used to merge pstats.Stats"""
+
+ @staticmethod
+ def zero(value):
+ return None
+
+ @staticmethod
+ def addInPlace(value1, value2):
+ if value1 is None:
+ return value2
+ value1.add(value2)
+ return value1
+
+
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
"""
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index b64875a3f495a..dc7cd0bce56f3 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None):
"""
Create a new Spark configuration.
- @param loadDefaults: whether to load values from Java system
+ :param loadDefaults: whether to load values from Java system
properties (True by default)
- @param _jvm: internal parameter used to pass a handle to the
+ :param _jvm: internal parameter used to pass a handle to the
Java VM; does not need to be set by users
- @param _jconf: Optionally pass in an existing SparkConf handle
+ :param _jconf: Optionally pass in an existing SparkConf handle
to use its parameters
"""
if _jconf:
@@ -139,7 +139,7 @@ def setAll(self, pairs):
"""
Set multiple parameters, passed as a list of key-value pairs.
- @param pairs: list of key-value pairs to set
+ :param pairs: list of key-value pairs to set
"""
for (k, v) in pairs:
self._jconf.set(k, v)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8e7b00469e246..8d27ccb95f82c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,6 +20,7 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
+import atexit
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -28,9 +29,8 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, CompressedSerializer
+ PairDeserializer, CompressedSerializer, AutoBatchedSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -67,27 +67,28 @@ class SparkContext(object):
_default_batch_size_for_serialized_input = 10
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
- gateway=None):
+ environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
+ gateway=None, jsc=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
- @param master: Cluster URL to connect to
+ :param master: Cluster URL to connect to
(e.g. mesos://host:port, spark://host:port, local[4]).
- @param appName: A name for your job, to display on the cluster web UI.
- @param sparkHome: Location where Spark is installed on cluster nodes.
- @param pyFiles: Collection of .zip or .py files to send to the cluster
+ :param appName: A name for your job, to display on the cluster web UI.
+ :param sparkHome: Location where Spark is installed on cluster nodes.
+ :param pyFiles: Collection of .zip or .py files to send to the cluster
and add to PYTHONPATH. These can be paths on the local file
system or HDFS, HTTP, HTTPS, or FTP URLs.
- @param environment: A dictionary of environment variables to set on
+ :param environment: A dictionary of environment variables to set on
worker nodes.
- @param batchSize: The number of Python objects represented as a single
- Java object. Set 1 to disable batching or -1 to use an
- unlimited batch size.
- @param serializer: The serializer for RDDs.
- @param conf: A L{SparkConf} object setting Spark properties.
- @param gateway: Use an existing gateway and JVM, otherwise a new JVM
+ :param batchSize: The number of Python objects represented as a single
+ Java object. Set 1 to disable batching, 0 to automatically choose
+ the batch size based on object sizes, or -1 to use an unlimited
+ batch size
+ :param serializer: The serializer for RDDs.
+ :param conf: A L{SparkConf} object setting Spark properties.
+ :param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
@@ -103,20 +104,22 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf)
+ conf, jsc)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf):
+ conf, jsc):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
if batchSize == 1:
self.serializer = self._unbatched_serializer
+ elif batchSize == 0:
+ self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
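
The batchSize handling above boils down to a three-way choice of serializer. A small sketch of that selection logic (a paraphrase of the code above, not a public API):

from pyspark.serializers import PickleSerializer, BatchedSerializer, AutoBatchedSerializer

def choose_serializer(batchSize, base=PickleSerializer()):
    # illustrative helper, not part of PySpark
    if batchSize == 1:
        return base                                # batching disabled
    elif batchSize == 0:
        return AutoBatchedSerializer(base)         # new default: batch by object size
    else:
        return BatchedSerializer(base, batchSize)  # fixed batch size; -1 means unlimited

print choose_serializer(0)  # AutoBatchedSerializer wrapping PickleSerializer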
@@ -151,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self.environment[varName] = v
# Create the Java SparkContext through Py4J
- self._jsc = self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
@@ -192,6 +195,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ # profiling stats collected for each PythonRDD
+ self._profile_stats = []
+
def _initialize_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
@@ -209,8 +215,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
- SparkContext._jvm.SerDeUtil.initialize()
- SparkContext._jvm.SerDe.initialize()
if instance:
if (SparkContext._active_spark_context and
@@ -407,22 +411,23 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
The mechanism is as follows:
+
1. A Java RDD is created from the SequenceFile or other InputFormat, and the key
and value Writable classes
2. Serialization is attempted via Pyrolite pickling
3. If this fails, the fallback is to call 'toString' on each key and value
4. C{PickleSerializer} is used to deserialize pickled objects on the Python side
- @param path: path to sequncefile
- @param keyClass: fully qualified classname of key Writable class
+ :param path: path to sequence file
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.Text")
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.LongWritable")
- @param keyConverter:
- @param valueConverter:
- @param minSplits: minimum splits in dataset
+ :param keyConverter:
+ :param valueConverter:
+ :param minSplits: minimum splits in dataset
(default min(2, sc.defaultParallelism))
- @param batchSize: The number of Python objects represented as a single
+ :param batchSize: The number of Python objects represented as a single
Java object. (default sc._default_batch_size_for_serialized_input)
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
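
As a quick illustration of the mechanism described in the docstring above (the input path and data here are hypothetical):

rdd = sc.sequenceFile(
    "hdfs:///data/pairs.seq",                      # hypothetical input path
    keyClass="org.apache.hadoop.io.Text",
    valueClass="org.apache.hadoop.io.LongWritable")
print rdd.take(2)  # e.g. [(u'a', 1), (u'b', 2)] via Pyrolite pickling or the toString fallback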
@@ -442,18 +447,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv
A Hadoop configuration can be passed in as a Python dict. This will be converted into a
Configuration in Java
- @param path: path to Hadoop file
- @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ :param path: path to Hadoop file
+ :param inputFormatClass: fully qualified classname of Hadoop InputFormat
(e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.Text")
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.LongWritable")
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: Hadoop configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: Hadoop configuration, passed in as a dict
(None by default)
- @param batchSize: The number of Python objects represented as a single
+ :param batchSize: The number of Python objects represented as a single
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
@@ -472,17 +477,17 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N
This will be converted into a Configuration in Java.
The mechanism is the same as for sc.sequenceFile.
- @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ :param inputFormatClass: fully qualified classname of Hadoop InputFormat
(e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.Text")
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.LongWritable")
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: Hadoop configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: Hadoop configuration, passed in as a dict
(None by default)
- @param batchSize: The number of Python objects represented as a single
+ :param batchSize: The number of Python objects represented as a single
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
@@ -503,18 +508,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=
A Hadoop configuration can be passed in as a Python dict. This will be converted into a
Configuration in Java.
- @param path: path to Hadoop file
- @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ :param path: path to Hadoop file
+ :param inputFormatClass: fully qualified classname of Hadoop InputFormat
(e.g. "org.apache.hadoop.mapred.TextInputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.Text")
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.LongWritable")
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: Hadoop configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: Hadoop configuration, passed in as a dict
(None by default)
- @param batchSize: The number of Python objects represented as a single
+ :param batchSize: The number of Python objects represented as a single
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
@@ -533,17 +538,17 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
This will be converted into a Configuration in Java.
The mechanism is the same as for sc.sequenceFile.
- @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ :param inputFormatClass: fully qualified classname of Hadoop InputFormat
(e.g. "org.apache.hadoop.mapred.TextInputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.Text")
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.LongWritable")
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: Hadoop configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: Hadoop configuration, passed in as a dict
(None by default)
- @param batchSize: The number of Python objects represented as a single
+ :param batchSize: The number of Python objects represented as a single
Java object. (default sc._default_batch_size_for_serialized_input)
"""
jconf = self._dictToJavaMap(conf)
@@ -792,6 +797,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
+ def _add_profile(self, id, profileAcc):
+ if not self._profile_stats:
+ dump_path = self._conf.get("spark.python.profile.dump")
+ if dump_path:
+ atexit.register(self.dump_profiles, dump_path)
+ else:
+ atexit.register(self.show_profiles)
+
+ self._profile_stats.append([id, profileAcc, False])
+
+ def show_profiles(self):
+ """ Print the profile stats to stdout """
+ for i, (id, acc, showed) in enumerate(self._profile_stats):
+ stats = acc.value
+ if not showed and stats:
+ print "=" * 60
+ print "Profile of RDD" % id
+ print "=" * 60
+ stats.sort_stats("time", "cumulative").print_stats()
+ # mark it as showed
+ self._profile_stats[i][2] = True
+
+ def dump_profiles(self, path):
+ """ Dump the profile stats into directory `path`
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+ for id, acc, _ in self._profile_stats:
+ stats = acc.value
+ if stats:
+ p = os.path.join(path, "rdd_%d.pstats" % id)
+ stats.dump_stats(p)
+ self._profile_stats = []
+
def _test():
import atexit
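
A usage sketch for the new profiling hooks, assuming a build that includes this patch; the enable flag is an assumption here (taken to be spark.python.profile) and is not shown in this hunk:

from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.python.profile", "true")               # assumed flag that turns collection on
        .set("spark.python.profile.dump", "/tmp/pyprof"))  # optional: dump at exit instead of printing
sc = SparkContext("local[2]", "profile-demo", conf=conf)

sc.parallelize(range(10000)).map(lambda x: x * x).count()
sc.show_profiles()               # print collected pstats per RDD to stdout
sc.dump_profiles("/tmp/pyprof")  # or write rdd_<id>.pstats files for later analysis
# dumped files can be reloaded with the standard library:
# import pstats; pstats.Stats("/tmp/pyprof/rdd_1.pstats").print_stats()
sc.stop()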
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index ac142fb49a90c..e295c9d0954d9 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,7 +21,7 @@
from numpy import array
from pyspark import SparkContext, PickleSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
@@ -79,21 +79,24 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
"""
Train a logistic regression model on the given data.
- @param data: The training data.
- @param iterations: The number of iterations (default: 100).
- @param step: The step parameter used in SGD
+ :param data: The training data.
+ :param iterations: The number of iterations (default: 100).
+ :param step: The step parameter used in SGD
(default: 1.0).
- @param miniBatchFraction: Fraction of data to be used for each SGD
+ :param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
- @param initialWeights: The initial weights (default: None).
- @param regParam: The regularizer parameter (default: 1.0).
- @param regType: The type of regularizer used for training
+ :param initialWeights: The initial weights (default: None).
+ :param regParam: The regularizer parameter (default: 1.0).
+ :param regType: The type of regularizer used for training
our model.
- Allowed values: "l1" for using L1Updater,
- "l2" for using
- SquaredL2Updater,
- "none" for no regularizer.
- (default: "none")
+
+ :Allowed values:
+ - "l1" for using L1Updater
+ - "l2" for using SquaredL2Updater
+ - "none" for no regularizer
+
+ (default: "none")
+
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
@@ -148,21 +151,24 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
"""
Train a support vector machine on the given data.
- @param data: The training data.
- @param iterations: The number of iterations (default: 100).
- @param step: The step parameter used in SGD
+ :param data: The training data.
+ :param iterations: The number of iterations (default: 100).
+ :param step: The step parameter used in SGD
(default: 1.0).
- @param regParam: The regularizer parameter (default: 1.0).
- @param miniBatchFraction: Fraction of data to be used for each SGD
+ :param regParam: The regularizer parameter (default: 1.0).
+ :param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
- @param initialWeights: The initial weights (default: None).
- @param regType: The type of regularizer used for training
+ :param initialWeights: The initial weights (default: None).
+ :param regType: The type of regularizer used for training
our model.
- Allowed values: "l1" for using L1Updater,
- "l2" for using
- SquaredL2Updater,
- "none" for no regularizer.
- (default: "none")
+
+ :Allowed values:
+ - "l1" for using L1Updater
+ - "l2" for using SquaredL2Updater,
+ - "none" for no regularizer.
+
+ (default: "none")
+
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
@@ -232,13 +238,13 @@ def train(cls, data, lambda_=1.0):
classification. By making every vector a 0-1 vector, it can also be
used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
- @param data: RDD of NumPy vectors, one per element, where the first
+ :param data: RDD of NumPy vectors, one per element, where the first
coordinate is the label and the rest is the feature vector
(e.g. a count vector).
- @param lambda_: The smoothing parameter
+ :param lambda_: The smoothing parameter
"""
sc = data.context
- jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_)
+ jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_)
labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist)))
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 12c56022717a5..5ee7997104d21 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -17,7 +17,7 @@
from pyspark import SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
__all__ = ['KMeansModel', 'KMeans']
@@ -85,7 +85,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"
# cache serialized data to avoid objects over head in JVM
cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache()
model = sc._jvm.PythonMLLibAPI().trainKMeansModel(
- cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode)
+ _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode)
bytes = sc._jvm.SerDe.dumps(model.clusterCenters())
centers = ser.loads(str(bytes))
return KMeansModel([c.toArray() for c in centers])
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
new file mode 100644
index 0000000000000..b5a3f22c6907e
--- /dev/null
+++ b/python/pyspark/mllib/feature.py
@@ -0,0 +1,194 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for feature in MLlib.
+"""
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
+
+__all__ = ['Word2Vec', 'Word2VecModel']
+
+
+class Word2VecModel(object):
+ """
+ class for Word2Vec model
+ """
+ def __init__(self, sc, java_model):
+ """
+ :param sc: Spark context
+ :param java_model: Handle to Java model object
+ """
+ self._sc = sc
+ self._java_model = java_model
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_model)
+
+ def transform(self, word):
+ """
+ :param word: a word
+ :return: vector representation of word
+
+ Transforms a word to its vector representation
+
+ Note: local use only
+ """
+ # TODO: make transform usable in RDD operations from python side
+ result = self._java_model.transform(word)
+ return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
+
+ def findSynonyms(self, x, num):
+ """
+ :param x: a word or a vector representation of word
+ :param num: number of synonyms to find
+ :return: array of (word, cosineSimilarity)
+
+ Find synonyms of a word
+
+ Note: local use only
+ """
+ # TODO: make findSynonyms usable in RDD operations from python side
+ ser = PickleSerializer()
+ if type(x) == str:
+ jlist = self._java_model.findSynonyms(x, num)
+ else:
+ bytes = bytearray(ser.dumps(_convert_to_vector(x)))
+ vec = self._sc._jvm.SerDe.loads(bytes)
+ jlist = self._java_model.findSynonyms(vec, num)
+ words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
+ return zip(words, similarity)
+
+
+class Word2Vec(object):
+ """
+ Word2Vec creates vector representation of words in a text corpus.
+ The algorithm first constructs a vocabulary from the corpus
+ and then learns vector representation of words in the vocabulary.
+ The vector representation can be used as features in
+ natural language processing and machine learning algorithms.
+
+ We used the skip-gram model in our implementation and the hierarchical softmax
+ method to train the model. The variable names in the implementation
+ match the original C implementation.
+ For original C implementation, see https://code.google.com/p/word2vec/
+ For research papers, see
+ Efficient Estimation of Word Representations in Vector Space
+ and
+ Distributed Representations of Words and Phrases and their Compositionality.
+
+ >>> sentence = "a b " * 100 + "a c " * 10
+ >>> localDoc = [sentence, sentence]
+ >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
+ >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
+ >>> syms = model.findSynonyms("a", 2)
+ >>> str(syms[0][0])
+ 'b'
+ >>> str(syms[1][0])
+ 'c'
+ >>> len(syms)
+ 2
+ >>> vec = model.transform("a")
+ >>> len(vec)
+ 10
+ >>> syms = model.findSynonyms(vec, 2)
+ >>> str(syms[0][0])
+ 'b'
+ >>> str(syms[1][0])
+ 'c'
+ >>> len(syms)
+ 2
+ """
+ def __init__(self):
+ """
+ Construct Word2Vec instance
+ """
+ self.vectorSize = 100
+ self.learningRate = 0.025
+ self.numPartitions = 1
+ self.numIterations = 1
+ self.seed = 42L
+
+ def setVectorSize(self, vectorSize):
+ """
+ Sets vector size (default: 100).
+ """
+ self.vectorSize = vectorSize
+ return self
+
+ def setLearningRate(self, learningRate):
+ """
+ Sets initial learning rate (default: 0.025).
+ """
+ self.learningRate = learningRate
+ return self
+
+ def setNumPartitions(self, numPartitions):
+ """
+ Sets number of partitions (default: 1). Use a small number for accuracy.
+ """
+ self.numPartitions = numPartitions
+ return self
+
+ def setNumIterations(self, numIterations):
+ """
+        Sets number of iterations (default: 1), which should be smaller than
+        or equal to the number of partitions.
+ """
+ self.numIterations = numIterations
+ return self
+
+ def setSeed(self, seed):
+ """
+ Sets random seed.
+ """
+ self.seed = seed
+ return self
+
+ def fit(self, data):
+ """
+        Computes the vector representation of each word in the vocabulary.
+
+        :param data: training data, an RDD whose elements are iterables of
+                     strings (e.g. lists of words)
+        :return: Python Word2VecModel instance
+ """
+ sc = data.context
+ ser = PickleSerializer()
+ vectorSize = self.vectorSize
+ learningRate = self.learningRate
+ numPartitions = self.numPartitions
+ numIterations = self.numIterations
+ seed = self.seed
+
+ model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
+ _to_java_object_rdd(data), vectorSize,
+ learningRate, numPartitions, numIterations, seed)
+ return Word2VecModel(sc, model)
+
+
+def _test():
+ import doctest
+ from pyspark import SparkContext
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+if __name__ == "__main__":
+ _test()
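For orientation, here is a minimal end-to-end sketch of the new Python Word2Vec API added above. It mirrors the doctest and assumes an already running local SparkContext named `sc`, which is not created here.

# Sketch only: mirrors the doctest above; `sc` is assumed to exist.
from pyspark.mllib.feature import Word2Vec

corpus = sc.parallelize([("a b " * 100 + "a c " * 10)] * 2) \
           .map(lambda line: line.split(" "))

model = (Word2Vec()
         .setVectorSize(10)      # dimensionality of the learned vectors
         .setNumIterations(1)    # should not exceed the number of partitions
         .setSeed(42L)
         .fit(corpus))

vec = model.transform("a")                    # local lookup, returns a vector of size 10
for word, sim in model.findSynonyms("a", 2):  # local query, returns (word, cosineSimilarity)
    print word, sim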
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 0a5dcaac55e46..773d8d393805d 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,6 +29,8 @@
import numpy as np
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
@@ -50,6 +52,17 @@ def fast_pickle_array(ar):
_have_scipy = False
+# this will call the MLlib version of pythonToJava()
+def _to_java_object_rdd(rdd):
+    """ Return a JavaRDD of Object by unpickling
+
+    It will convert each Python object into a Java object via Pyrolite,
+    whether or not the RDD is batch-serialized.
+ """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
+
+
def _convert_to_vector(l):
if isinstance(l, Vector):
return l
@@ -63,6 +76,41 @@ def _convert_to_vector(l):
raise TypeError("Cannot convert type %s into Vector" % type(l))
+def _vector_size(v):
+ """
+ Returns the size of the vector.
+
+ >>> _vector_size([1., 2., 3.])
+ 3
+ >>> _vector_size((1., 2., 3.))
+ 3
+ >>> _vector_size(array.array('d', [1., 2., 3.]))
+ 3
+ >>> _vector_size(np.zeros(3))
+ 3
+ >>> _vector_size(np.zeros((3, 1)))
+ 3
+ >>> _vector_size(np.zeros((1, 3)))
+ Traceback (most recent call last):
+ ...
+ ValueError: Cannot treat an ndarray of shape (1, 3) as a vector
+ """
+ if isinstance(v, Vector):
+ return len(v)
+ elif type(v) in (array.array, list, tuple):
+ return len(v)
+ elif type(v) == np.ndarray:
+ if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
+ return len(v)
+ else:
+ raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape))
+ elif _have_scipy and scipy.sparse.issparse(v):
+ assert v.shape[1] == 1, "Expected column vector"
+ return v.shape[0]
+ else:
+ raise TypeError("Cannot treat type %s as a vector" % type(v))
+
+
class Vector(object):
"""
Abstract class for DenseVector and SparseVector
@@ -76,6 +124,9 @@ def toArray(self):
class DenseVector(Vector):
+ """
+ A dense vector represented by a value array.
+ """
def __init__(self, ar):
if not isinstance(ar, array.array):
ar = array.array('d', ar)
@@ -100,15 +151,31 @@ def dot(self, other):
5.0
>>> dense.dot(np.array(range(1, 3)))
5.0
+ >>> dense.dot([1.,])
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))
+ array([ 5., 11.])
+ >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F'))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
"""
- if isinstance(other, SparseVector):
- return other.dot(self)
+ if type(other) == np.ndarray and other.ndim > 1:
+ assert len(self) == other.shape[0], "dimension mismatch"
+ return np.dot(self.toArray(), other)
elif _have_scipy and scipy.sparse.issparse(other):
- return other.transpose().dot(self.toArray())[0]
- elif isinstance(other, Vector):
- return np.dot(self.toArray(), other.toArray())
+ assert len(self) == other.shape[0], "dimension mismatch"
+ return other.transpose().dot(self.toArray())
else:
- return np.dot(self.toArray(), other)
+ assert len(self) == _vector_size(other), "dimension mismatch"
+ if isinstance(other, SparseVector):
+ return other.dot(self)
+ elif isinstance(other, Vector):
+ return np.dot(self.toArray(), other.toArray())
+ else:
+ return np.dot(self.toArray(), other)
def squared_distance(self, other):
"""
@@ -126,7 +193,16 @@ def squared_distance(self, other):
>>> sparse1 = SparseVector(2, [0, 1], [2., 1.])
>>> dense1.squared_distance(sparse1)
2.0
+ >>> dense1.squared_distance([1.,])
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> dense1.squared_distance(SparseVector(1, [0,], [1.,]))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
"""
+ assert len(self) == _vector_size(other), "dimension mismatch"
if isinstance(other, SparseVector):
return other.squared_distance(self)
elif _have_scipy and scipy.sparse.issparse(other):
@@ -165,20 +241,18 @@ def __getattr__(self, item):
class SparseVector(Vector):
-
"""
A simple sparse vector class for passing data to MLlib. Users may
alternatively pass SciPy's {scipy.sparse} data types.
"""
-
def __init__(self, size, *args):
"""
Create a sparse vector, using either a dictionary, a list of
(index, value) pairs, or two separate arrays of indices and
values (sorted by index).
- @param size: Size of the vector.
- @param args: Non-zero entries, as a dictionary, list of tupes,
+ :param size: Size of the vector.
+        :param args: Non-zero entries, as a dictionary, list of tuples,
or two sorted lists containing indices and values.
>>> print SparseVector(4, {1: 1.0, 3: 5.5})
@@ -222,20 +296,33 @@ def dot(self, other):
0.0
>>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
array([ 22., 22.])
+ >>> a.dot([1., 2., 3.])
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> a.dot(np.array([1., 2.]))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> a.dot(DenseVector([1., 2.]))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> a.dot(np.zeros((3, 2)))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
"""
if type(other) == np.ndarray:
- if other.ndim == 1:
- result = 0.0
- for i in xrange(len(self.indices)):
- result += self.values[i] * other[self.indices[i]]
- return result
- elif other.ndim == 2:
+ if other.ndim == 2:
results = [self.dot(other[:, i]) for i in xrange(other.shape[1])]
return np.array(results)
- else:
- raise Exception("Cannot call dot with %d-dimensional array" % other.ndim)
+ elif other.ndim > 2:
+ raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim)
+
+ assert len(self) == _vector_size(other), "dimension mismatch"
- elif type(other) in (array.array, DenseVector):
+ if type(other) in (np.ndarray, array.array, DenseVector):
result = 0.0
for i in xrange(len(self.indices)):
result += self.values[i] * other[self.indices[i]]
@@ -254,6 +341,7 @@ def dot(self, other):
else:
j += 1
return result
+
else:
return self.dot(_convert_to_vector(other))
@@ -273,7 +361,16 @@ def squared_distance(self, other):
30.0
>>> b.squared_distance(a)
30.0
+ >>> b.squared_distance([1., 2.])
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
+ >>> b.squared_distance(SparseVector(3, [1,], [1.0,]))
+ Traceback (most recent call last):
+ ...
+ AssertionError: dimension mismatch
"""
+ assert len(self) == _vector_size(other), "dimension mismatch"
if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
if type(other) is np.array and other.ndim != 1:
raise Exception("Cannot call squared_distance with %d-dimensional array" %
@@ -348,7 +445,6 @@ def __eq__(self, other):
>>> v1 != v2
False
"""
-
return (isinstance(other, self.__class__)
and other.size == self.size
and other.indices == self.indices
@@ -375,8 +471,8 @@ def sparse(size, *args):
(index, value) pairs, or two separate arrays of indices and
values (sorted by index).
- @param size: Size of the vector.
- @param args: Non-zero entries, as a dictionary, list of tupes,
+ :param size: Size of the vector.
+        :param args: Non-zero entries, as a dictionary, list of tuples,
or two sorted lists containing indices and values.
>>> print Vectors.sparse(4, {1: 1.0, 3: 5.5})
@@ -414,23 +510,32 @@ def stringify(vector):
class Matrix(object):
- """ the Matrix """
- def __init__(self, nRow, nCol):
- self.nRow = nRow
- self.nCol = nCol
+ """
+ Represents a local matrix.
+ """
+
+ def __init__(self, numRows, numCols):
+ self.numRows = numRows
+ self.numCols = numCols
def toArray(self):
+ """
+ Returns its elements in a NumPy ndarray.
+ """
raise NotImplementedError
class DenseMatrix(Matrix):
- def __init__(self, nRow, nCol, values):
- Matrix.__init__(self, nRow, nCol)
- assert len(values) == nRow * nCol
+ """
+ Column-major dense matrix.
+ """
+ def __init__(self, numRows, numCols, values):
+ Matrix.__init__(self, numRows, numCols)
+ assert len(values) == numRows * numCols
self.values = values
def __reduce__(self):
- return DenseMatrix, (self.nRow, self.nCol, self.values)
+ return DenseMatrix, (self.numRows, self.numCols, self.values)
def toArray(self):
"""
@@ -439,10 +544,10 @@ def toArray(self):
>>> arr = array.array('d', [float(i) for i in range(4)])
>>> m = DenseMatrix(2, 2, arr)
>>> m.toArray()
- array([[ 0., 1.],
- [ 2., 3.]])
+ array([[ 0., 2.],
+ [ 1., 3.]])
"""
- return np.ndarray((self.nRow, self.nCol), np.float64, buffer=self.values.tostring())
+ return np.reshape(self.values, (self.numRows, self.numCols), order='F')
def _test():
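The linalg changes above tighten dimension checking (via `_vector_size`) and document the column-major layout of `DenseMatrix`. A small hedged sketch of the resulting behavior, based only on the doctests in this hunk:

# Sketch only: exercises the dimension checks and column-major layout shown above.
import array
import numpy as np
from pyspark.mllib.linalg import DenseVector, SparseVector, DenseMatrix

dense = DenseVector([1., 2.])
print dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))  # [  5.  11.]

try:
    dense.dot([1.])                # wrong length now fails fast
except AssertionError as e:
    print e                        # dimension mismatch

sparse = SparseVector(4, [1, 3], [3., 4.])
print sparse.dot(np.array([1., 2., 3., 4.]))   # 22.0

m = DenseMatrix(2, 2, array.array('d', [0., 1., 2., 3.]))
print m.toArray()                  # [[ 0.  2.]
                                   #  [ 1.  3.]]  (column-major)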
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index a787e4dea2c55..73baba4ace5f6 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -32,7 +32,7 @@ def serialize(f):
@wraps(f)
def func(sc, *a, **kw):
jrdd = f(sc, *a, **kw)
- return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc,
+ return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc,
BatchedSerializer(PickleSerializer(), 1024))
return func
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index a880a4bd7eab6..f99d8acf030e6 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -18,6 +18,7 @@
from pyspark import SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.rdd import RDD
+from pyspark.mllib.linalg import _to_java_object_rdd
__all__ = ['MatrixFactorizationModel', 'ALS']
@@ -77,9 +78,9 @@ def predictAll(self, user_product):
first = tuple(map(int, first))
assert all(type(x) is int for x in first), "user and product in user_product should be int"
sc = self._context
- tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd())
+ tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd())
jresult = self._java_model.predict(tuplerdd).toJavaRDD()
- return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc,
+ return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
AutoBatchedSerializer(PickleSerializer()))
def userFeatures(self):
@@ -111,7 +112,7 @@ def _prepare(cls, ratings):
# serialize them by AutoBatchedSerializer before cache to reduce the
# objects overhead in JVM
cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache()
- return cached._to_java_object_rdd()
+ return _to_java_object_rdd(cached)
@classmethod
def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index cbdbc09858013..93e17faf5cd51 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -19,10 +19,10 @@
from numpy import array
from pyspark import SparkContext
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
-__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel'
+__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
@@ -31,8 +31,8 @@ class LabeledPoint(object):
"""
The features and labels of a data point.
- @param label: Label for this data point.
- @param features: Vector of features for this point (NumPy array, list,
+ :param label: Label for this data point.
+ :param features: Vector of features for this point (NumPy array, list,
pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix)
"""
@@ -66,6 +66,9 @@ def weights(self):
def intercept(self):
return self._intercept
+ def __repr__(self):
+ return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept)
+
class LinearRegressionModelBase(LinearModel):
@@ -128,7 +131,7 @@ def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights)
# use AutoBatchedSerializer before cache to reduce the memory
# overhead in JVM
cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
- ans = train_func(cached._to_java_object_rdd(), initial_bytes)
+ ans = train_func(_to_java_object_rdd(cached), initial_bytes)
assert len(ans) == 2, "JVM call result had unexpected length"
weights = ser.loads(str(ans[0]))
return modelClass(weights, ans[1])
@@ -142,21 +145,24 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
"""
Train a linear regression model on the given data.
- @param data: The training data.
- @param iterations: The number of iterations (default: 100).
- @param step: The step parameter used in SGD
+ :param data: The training data.
+ :param iterations: The number of iterations (default: 100).
+ :param step: The step parameter used in SGD
(default: 1.0).
- @param miniBatchFraction: Fraction of data to be used for each SGD
+ :param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
- @param initialWeights: The initial weights (default: None).
- @param regParam: The regularizer parameter (default: 1.0).
- @param regType: The type of regularizer used for training
+ :param initialWeights: The initial weights (default: None).
+ :param regParam: The regularizer parameter (default: 1.0).
+ :param regType: The type of regularizer used for training
our model.
- Allowed values: "l1" for using L1Updater,
- "l2" for using
- SquaredL2Updater,
- "none" for no regularizer.
- (default: "none")
+
+ :Allowed values:
+ - "l1" for using L1Updater,
+ - "l2" for using SquaredL2Updater,
+ - "none" for no regularizer.
+
+ (default: "none")
+
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index b9de0909a6fb1..a6019dadf781c 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,6 +22,7 @@
from functools import wraps
from pyspark import PickleSerializer
+from pyspark.mllib.linalg import _to_java_object_rdd
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
@@ -106,7 +107,7 @@ def colStats(rdd):
array([ 2., 0., 0., -2.])
"""
sc = rdd.ctx
- jrdd = rdd._to_java_object_rdd()
+ jrdd = _to_java_object_rdd(rdd)
cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
return MultivariateStatisticalSummary(sc, cStats)
@@ -162,14 +163,14 @@ def corr(x, y=None, method=None):
if type(y) == str:
raise TypeError("Use 'method=' to specify method name.")
- jx = x._to_java_object_rdd()
+ jx = _to_java_object_rdd(x)
if not y:
resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
bytes = sc._jvm.SerDe.dumps(resultMat)
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
- jy = y._to_java_object_rdd()
+ jy = _to_java_object_rdd(y)
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f72e88ba6e2ba..463faf7b6f520 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -25,14 +25,18 @@
from numpy import array, array_equal
if sys.version_info[:2] <= (2, 6):
- import unittest2 as unittest
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
else:
import unittest
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
-from pyspark.tests import PySparkTestCase
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
_have_scipy = False
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index f59a818a6e74d..0938eebd3a548 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -19,7 +19,7 @@
from pyspark import SparkContext, RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer
-from pyspark.mllib.linalg import Vector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd
from pyspark.mllib.regression import LabeledPoint
__all__ = ['DecisionTreeModel', 'DecisionTree']
@@ -48,6 +48,7 @@ def __del__(self):
def predict(self, x):
"""
Predict the label of one or more examples.
+
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
@@ -60,8 +61,8 @@ def predict(self, x):
return self._sc.parallelize([])
if not isinstance(first[0], Vector):
x = x.map(_convert_to_vector)
- jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD()
- jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred)
+ jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
+ jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred)
return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
else:
@@ -77,8 +78,13 @@ def depth(self):
return self._java_model.depth()
def __repr__(self):
+ """ Print summary of model. """
return self._java_model.toString()
+ def toDebugString(self):
+ """ Print full model. """
+ return self._java_model.toDebugString()
+
class DecisionTree(object):
@@ -98,7 +104,7 @@ def _train(data, type, numClasses, categoricalFeaturesInfo,
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
sc = data.context
- jrdd = data._to_java_object_rdd()
+ jrdd = _to_java_object_rdd(data)
cfiMap = MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
@@ -135,7 +141,6 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
- >>> from pyspark.mllib.linalg import SparseVector
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
@@ -145,7 +150,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print model, # it already has newline
- DecisionTreeModel classifier
+ DecisionTreeModel classifier of depth 1 with 3 nodes
+ >>> print model.toDebugString(), # it already has newline
+ DecisionTreeModel classifier of depth 1 with 3 nodes
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)
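The `toDebugString()` method added above complements the shortened `__repr__`; a hedged sketch mirroring the trainClassifier doctest (it assumes an active SparkContext `sc`):

# Sketch only: mirrors the doctest above; `sc` is assumed to exist.
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

data = [LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
        LabeledPoint(1.0, [2.0]), LabeledPoint(1.0, [3.0])]
model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})

print model,                  # one-line summary, e.g. "... of depth 1 with 3 nodes"
print model.toDebugString(),  # full tree with split conditions and predictions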
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 8233d4e81f1ca..84b39a48619d2 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -19,7 +19,7 @@
import warnings
from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
@@ -77,10 +77,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
method parses each line into a LabeledPoint, where the feature
indices are converted to zero-based.
- @param sc: Spark context
- @param path: file or directory path in any Hadoop-supported file
+ :param sc: Spark context
+ :param path: file or directory path in any Hadoop-supported file
system URI
- @param numFeatures: number of features, which will be determined
+ :param numFeatures: number of features, which will be determined
from the input data if a nonpositive value
is given. This is useful when the dataset is
already split into multiple files and you
@@ -88,7 +88,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
features may not present in certain files,
which leads to inconsistent feature
dimensions.
- @param minPartitions: min number of partitions
+ :param minPartitions: min number of partitions
@return: labeled data stored as an RDD of LabeledPoint
>>> from tempfile import NamedTemporaryFile
@@ -126,8 +126,8 @@ def saveAsLibSVMFile(data, dir):
"""
Save labeled data in LIBSVM format.
- @param data: an RDD of LabeledPoint to be saved
- @param dir: directory to save the data
+ :param data: an RDD of LabeledPoint to be saved
+ :param dir: directory to save the data
>>> from tempfile import NamedTemporaryFile
>>> from fileinput import input
@@ -149,10 +149,10 @@ def loadLabeledPoints(sc, path, minPartitions=None):
"""
Load labeled points saved using RDD.saveAsTextFile.
- @param sc: Spark context
- @param path: file or directory path in any Hadoop-supported file
+ :param sc: Spark context
+ :param path: file or directory path in any Hadoop-supported file
system URI
- @param minPartitions: min number of partitions
+ :param minPartitions: min number of partitions
@return: labeled data stored as an RDD of LabeledPoint
>>> from tempfile import NamedTemporaryFile
@@ -174,8 +174,8 @@ def loadLabeledPoints(sc, path, minPartitions=None):
"""
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions)
- jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd)
- return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer()))
+ jpyrdd = sc._jvm.SerDe.javaToPython(jrdd)
+ return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer()))
def _test():
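The util.py changes keep the LIBSVM helpers' behavior but switch the docstring style and the return serializer. A hedged round-trip sketch follows; it assumes these helpers are exposed as static methods of `MLUtils`, as in the rest of this module, and that `sc` is an active SparkContext.

# Sketch only: LIBSVM round trip following the docstrings above.
from tempfile import NamedTemporaryFile
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import MLUtils

points = sc.parallelize([LabeledPoint(1.0, [1.23, 4.56]),
                         LabeledPoint(0.0, [1.01, 2.02])])

temp = NamedTemporaryFile(delete=True)
temp.close()                                   # reuse the now-free path as an output directory
MLUtils.saveAsLibSVMFile(points, temp.name)

reloaded = MLUtils.loadLibSVMFile(sc, temp.name).collect()
print reloaded[0].label, reloaded[0].features  # features come back as a SparseVector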
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 680140d72d03c..15be4bfec92f9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
@@ -32,6 +31,7 @@
from random import Random
from math import sqrt, log, isinf, isnan
+from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -752,7 +752,7 @@ def max(self, key=None):
"""
Find the maximum item in this RDD.
- @param key: A function used to generate key for comparing
+ :param key: A function used to generate key for comparing
>>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
>>> rdd.max()
@@ -768,7 +768,7 @@ def min(self, key=None):
"""
Find the minimum item in this RDD.
- @param key: A function used to generate key for comparing
+ :param key: A function used to generate key for comparing
>>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
>>> rdd.min()
@@ -1070,10 +1070,13 @@ def take(self, num):
# If we didn't find any rows after the previous iteration,
# quadruple and retry. Otherwise, interpolate the number of
# partitions we need to try, but overestimate it by 50%.
+ # We also cap the estimation in the end.
if len(items) == 0:
numPartsToTry = partsScanned * 4
else:
- numPartsToTry = int(1.5 * num * partsScanned / len(items))
+            # the first parameter of max is >= 1 whenever partsScanned >= 2
+ numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
+ numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)
left = num - len(items)
@@ -1115,9 +1118,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None
converted for output using either user specified converters or, by default,
L{org.apache.spark.api.python.JavaToWritableConverter}.
- @param conf: Hadoop job configuration, passed in as a dict
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
+ :param conf: Hadoop job configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
pickledRDD = self._toPickleSerialization()
@@ -1135,16 +1138,16 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl
C{conf} is applied on top of the base Hadoop conf associated with the SparkContext
of this RDD to create a merged Hadoop MapReduce job configuration for saving the data.
- @param path: path to Hadoop file
- @param outputFormatClass: fully qualified classname of Hadoop OutputFormat
+ :param path: path to Hadoop file
+ :param outputFormatClass: fully qualified classname of Hadoop OutputFormat
(e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.IntWritable", None by default)
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.Text", None by default)
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: Hadoop job configuration, passed in as a dict (None by default)
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: Hadoop job configuration, passed in as a dict (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
pickledRDD = self._toPickleSerialization()
@@ -1161,9 +1164,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
converted for output using either user specified converters or, by default,
L{org.apache.spark.api.python.JavaToWritableConverter}.
- @param conf: Hadoop job configuration, passed in as a dict
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
+ :param conf: Hadoop job configuration, passed in as a dict
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
pickledRDD = self._toPickleSerialization()
@@ -1182,17 +1185,17 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No
C{conf} is applied on top of the base Hadoop conf associated with the SparkContext
of this RDD to create a merged Hadoop MapReduce job configuration for saving the data.
- @param path: path to Hadoop file
- @param outputFormatClass: fully qualified classname of Hadoop OutputFormat
+ :param path: path to Hadoop file
+ :param outputFormatClass: fully qualified classname of Hadoop OutputFormat
(e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat")
- @param keyClass: fully qualified classname of key Writable class
+ :param keyClass: fully qualified classname of key Writable class
(e.g. "org.apache.hadoop.io.IntWritable", None by default)
- @param valueClass: fully qualified classname of value Writable class
+ :param valueClass: fully qualified classname of value Writable class
(e.g. "org.apache.hadoop.io.Text", None by default)
- @param keyConverter: (None by default)
- @param valueConverter: (None by default)
- @param conf: (None by default)
- @param compressionCodecClass: (None by default)
+ :param keyConverter: (None by default)
+ :param valueConverter: (None by default)
+ :param conf: (None by default)
+ :param compressionCodecClass: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
pickledRDD = self._toPickleSerialization()
@@ -1208,11 +1211,12 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None):
Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file
system, using the L{org.apache.hadoop.io.Writable} types that we convert from the
RDD's key and value types. The mechanism is as follows:
+
1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects.
2. Keys and values of this Java RDD are converted to Writables and written out.
- @param path: path to sequence file
- @param compressionCodecClass: (None by default)
+ :param path: path to sequence file
+ :param compressionCodecClass: (None by default)
"""
pickledRDD = self._toPickleSerialization()
batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
@@ -2008,7 +2012,7 @@ def countApproxDistinct(self, relativeSD=0.05):
of The Art Cardinality Estimation Algorithm", available
here.
- @param relativeSD Relative accuracy. Smaller values create
+ :param relativeSD: Relative accuracy. Smaller values create
counters that require more space.
It must be greater than 0.000017.
@@ -2073,6 +2077,12 @@ def pipeline_func(split, iterator):
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self._broadcast = None
+
+ def __del__(self):
+ if self._broadcast:
+ self._broadcast.unpersist()
+ self._broadcast = None
@property
def _jrdd(self):
@@ -2080,14 +2090,16 @@ def _jrdd(self):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
- command = (self.func, self._prev_jrdd_deserializer,
+ enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
+ profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
+ command = (self.func, profileStats, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
- if pickled_command > (1 << 20): # 1M
- broadcast = self.ctx.broadcast(pickled_command)
- pickled_command = ser.dumps(broadcast)
+ if len(pickled_command) > (1 << 20): # 1M
+ self._broadcast = self.ctx.broadcast(pickled_command)
+ pickled_command = ser.dumps(self._broadcast)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
@@ -2102,6 +2114,10 @@ def _jrdd(self):
self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
+
+ if enable_profile:
+ self._id = self._jrdd_val.id()
+ self.ctx._add_profile(self._id, profileStats)
return self._jrdd_val
def id(self):
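The take() change above both caps and floors the number of partitions scanned per round. A standalone sketch of that arithmetic (the helper name is hypothetical; the logic is copied from the hunk above):

# Sketch only: the partition-count estimation used by take(), with the new cap.
def estimate_parts_to_try(num, parts_scanned, items_found):
    if items_found == 0:
        return parts_scanned * 4                  # nothing found yet: quadruple
    # interpolate, overestimate by 50%, then clamp to [1, 4 * parts_scanned]
    estimate = int(1.5 * num * parts_scanned / items_found) - parts_scanned
    return min(max(estimate, 1), parts_scanned * 4)

print estimate_parts_to_try(10, 1, 0)   # 4
print estimate_parts_to_try(10, 4, 1)   # 16 (56 before clamping)
print estimate_parts_to_try(10, 4, 8)   # 3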
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2672da36c1f50..08a0f0d8ffb3e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -114,6 +114,9 @@ def __ne__(self, other):
def __repr__(self):
return "<%s object>" % self.__class__.__name__
+ def __hash__(self):
+ return hash(str(self))
+
class FramedSerializer(Serializer):
@@ -211,7 +214,7 @@ def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
other.serializer == self.serializer)
- def __str__(self):
+ def __repr__(self):
return "BatchedSerializer<%s>" % str(self.serializer)
@@ -220,7 +223,7 @@ class AutoBatchedSerializer(BatchedSerializer):
Choose the size of batch automatically based on the size of object
"""
- def __init__(self, serializer, bestSize=1 << 20):
+ def __init__(self, serializer, bestSize=1 << 16):
BatchedSerializer.__init__(self, serializer, -1)
self.bestSize = bestSize
@@ -247,7 +250,7 @@ def __eq__(self, other):
other.serializer == self.serializer)
def __str__(self):
- return "BatchedSerializer<%s>" % str(self.serializer)
+ return "AutoBatchedSerializer<%s>" % str(self.serializer)
class CartesianDeserializer(FramedSerializer):
@@ -279,7 +282,7 @@ def __eq__(self, other):
return (isinstance(other, CartesianDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
- def __str__(self):
+ def __repr__(self):
return "CartesianDeserializer<%s, %s>" % \
(str(self.key_ser), str(self.val_ser))
@@ -306,7 +309,7 @@ def __eq__(self, other):
return (isinstance(other, PairDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
- def __str__(self):
+ def __repr__(self):
return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index ce597cbe91e15..d57a802e4734a 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -396,7 +396,6 @@ def _external_items(self):
for v in self.data.iteritems():
yield v
self.data.clear()
- gc.collect()
# remove the merged partition
for j in range(self.spills):
@@ -428,7 +427,7 @@ def _recursive_merged_items(self, start):
subdirs = [os.path.join(d, "parts", str(i))
for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
- subdirs, self.scale * self.partitions)
+ subdirs, self.scale * self.partitions, self.partitions)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()
@@ -486,7 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
- batch = 10
+ batch = 100
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f71d24c470dc9..b31a82f9b19ac 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -15,28 +15,38 @@
# limitations under the License.
#
+"""
+Public classes of Spark SQL:
+
+ - L{SQLContext}
+ Main entry point for SQL functionality.
+ - L{SchemaRDD}
+ A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
+ addition to normal RDD operations, SchemaRDDs also support SQL.
+ - L{Row}
+ A Row of data returned by a Spark SQL query.
+ - L{HiveContext}
+      Main entry point for accessing data stored in Apache Hive.
+"""
-import sys
-import types
import itertools
-import warnings
import decimal
import datetime
import keyword
import warnings
+import json
from array import array
from operator import itemgetter
+from itertools import imap
+
+from py4j.protocol import Py4JError
+from py4j.java_collections import ListConverter, MapConverter
from pyspark.rdd import RDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
-from itertools import chain, ifilter, imap
-
-from py4j.protocol import Py4JError
-from py4j.java_collections import ListConverter, MapConverter
-
__all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
@@ -62,6 +72,18 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)
+ @classmethod
+ def typeName(cls):
+ return cls.__name__[:-4].lower()
+
+ def jsonValue(self):
+ return self.typeName()
+
+ def json(self):
+ return json.dumps(self.jsonValue(),
+ separators=(',', ':'),
+ sort_keys=True)
+
class PrimitiveTypeSingleton(type):
@@ -201,10 +223,20 @@ def __init__(self, elementType, containsNull=True):
self.elementType = elementType
self.containsNull = containsNull
- def __str__(self):
+ def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return ArrayType(_parse_datatype_json_value(json["elementType"]),
+ json["containsNull"])
+
class MapType(DataType):
@@ -245,6 +277,18 @@ def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return MapType(_parse_datatype_json_value(json["keyType"]),
+ _parse_datatype_json_value(json["valueType"]),
+ json["valueContainsNull"])
+
class StructField(DataType):
@@ -283,6 +327,17 @@ def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
+ def jsonValue(self):
+ return {"name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructField(json["name"],
+ _parse_datatype_json_value(json["type"]),
+ json["nullable"])
+
class StructType(DataType):
@@ -312,42 +367,30 @@ def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "fields": [f.jsonValue() for f in self.fields]}
-def _parse_datatype_list(datatype_list_string):
- """Parses a list of comma separated data types."""
- index = 0
- datatype_list = []
- start = 0
- depth = 0
- while index < len(datatype_list_string):
- if depth == 0 and datatype_list_string[index] == ",":
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- start = index + 1
- elif datatype_list_string[index] == "(":
- depth += 1
- elif datatype_list_string[index] == ")":
- depth -= 1
+ @classmethod
+ def fromJson(cls, json):
+ return StructType([StructField.fromJson(f) for f in json["fields"]])
- index += 1
- # Handle the last data type
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- return datatype_list
+_all_primitive_types = dict((v.typeName(), v)
+ for v in globals().itervalues()
+ if type(v) is PrimitiveTypeSingleton and
+ v.__base__ == PrimitiveType)
-_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
- if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
+_all_complex_types = dict((v.typeName(), v)
+ for v in [ArrayType, MapType, StructType])
-def _parse_datatype_string(datatype_string):
- """Parses the given data type string.
-
+def _parse_datatype_json_string(json_string):
+ """Parses the given data type JSON string.
>>> def check_datatype(datatype):
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
- ... python_datatype = _parse_datatype_string(
- ... scala_datatype.toString())
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -385,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
- index = datatype_string.find("(")
- if index == -1:
- # It is a primitive type.
- index = len(datatype_string)
- type_or_field = datatype_string[:index]
- rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
-
- if type_or_field in _all_primitive_types:
- return _all_primitive_types[type_or_field]()
-
- elif type_or_field == "ArrayType":
- last_comma_index = rest_part.rfind(",")
- containsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- containsNull = False
- elementType = _parse_datatype_string(
- rest_part[:last_comma_index].strip())
- return ArrayType(elementType, containsNull)
-
- elif type_or_field == "MapType":
- last_comma_index = rest_part.rfind(",")
- valueContainsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- valueContainsNull = False
- keyType, valueType = _parse_datatype_list(
- rest_part[:last_comma_index].strip())
- return MapType(keyType, valueType, valueContainsNull)
-
- elif type_or_field == "StructField":
- first_comma_index = rest_part.find(",")
- name = rest_part[:first_comma_index].strip()
- last_comma_index = rest_part.rfind(",")
- nullable = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- nullable = False
- dataType = _parse_datatype_string(
- rest_part[first_comma_index + 1:last_comma_index].strip())
- return StructField(name, dataType, nullable)
-
- elif type_or_field == "StructType":
- # rest_part should be in the format like
- # List(StructField(field1,IntegerType,false)).
- field_list_string = rest_part[rest_part.find("(") + 1:-1]
- fields = _parse_datatype_list(field_list_string)
- return StructType(fields)
+ return _parse_datatype_json_value(json.loads(json_string))
+
+
+def _parse_datatype_json_value(json_value):
+ if type(json_value) is unicode and json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ else:
+ return _all_complex_types[json_value["type"]].fromJson(json_value)
# Mapping Python types to Spark SQL DateType
@@ -899,8 +905,8 @@ class SQLContext(object):
def __init__(self, sparkContext, sqlContext=None):
"""Create a new SQLContext.
- @param sparkContext: The SparkContext to wrap.
- @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
+ :param sparkContext: The SparkContext to wrap.
+        :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
>>> srdd = sqlCtx.inferSchema(rdd)
@@ -960,12 +966,12 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
- command = (func,
+ command = (func, None,
BatchedSerializer(PickleSerializer(), 1024),
BatchedSerializer(PickleSerializer(), 1024))
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
- if pickled_command > (1 << 20): # 1M
+ if len(pickled_command) > (1 << 20): # 1M
broadcast = self._sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
broadcast_vars = ListConverter().convert(
@@ -983,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
- str(returnType))
+ returnType.json())
def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1119,7 +1125,7 @@ def applySchema(self, rdd, schema):
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def registerRDDAsTable(self, rdd, tableName):
@@ -1209,7 +1215,7 @@ def jsonFile(self, path, schema=None):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1279,7 +1285,7 @@ def func(iterator):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1325,8 +1331,8 @@ class HiveContext(SQLContext):
def __init__(self, sparkContext, hiveContext=None):
"""Create a new HiveContext.
- @param sparkContext: The SparkContext to wrap.
- @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new
+ :param sparkContext: The SparkContext to wrap.
+        :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
HiveContext in the JVM, instead we make all calls to this object.
"""
SQLContext.__init__(self, sparkContext)
@@ -1614,7 +1620,7 @@ def saveAsTable(self, tableName):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
+ return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
def schemaString(self):
"""Returns the output schema in the tree format."""
diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py
new file mode 100644
index 0000000000000..d2644a1d4ffab
--- /dev/null
+++ b/python/pyspark/streaming/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.streaming.context import StreamingContext
+from pyspark.streaming.dstream import DStream
+
+__all__ = ['StreamingContext', 'DStream']
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
new file mode 100644
index 0000000000000..dc9dc41121935
--- /dev/null
+++ b/python/pyspark/streaming/context.py
@@ -0,0 +1,325 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+
+from py4j.java_collections import ListConverter
+from py4j.java_gateway import java_import, JavaObject
+
+from pyspark import RDD, SparkConf
+from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
+from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.dstream import DStream
+from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
+
+__all__ = ["StreamingContext"]
+
+
+def _daemonize_callback_server():
+ """
+ Hack Py4J to daemonize callback server
+
+    The callback server thread has daemon=False, so it will block the driver
+    from exiting if it is not shut down. The following code replaces `start()`
+    of CallbackServer with a new version that sets daemon=True for this
+    thread.
+
+    It also updates the port number (0) with the real port.
+ """
+ # TODO: create a patch for Py4J
+ import socket
+ import py4j.java_gateway
+ logger = py4j.java_gateway.logger
+ from py4j.java_gateway import Py4JNetworkError
+ from threading import Thread
+
+ def start(self):
+ """Starts the CallbackServer. This method should be called by the
+ client instead of run()."""
+ self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+ 1)
+ try:
+ self.server_socket.bind((self.address, self.port))
+ if not self.port:
+ # update port with real port
+ self.port = self.server_socket.getsockname()[1]
+ except Exception as e:
+ msg = 'An error occurred while trying to start the callback server: %s' % e
+ logger.exception(msg)
+ raise Py4JNetworkError(msg)
+
+        # Maybe the thread needs to be cleaned up?
+ self.thread = Thread(target=self.run)
+ self.thread.daemon = True
+ self.thread.start()
+
+ py4j.java_gateway.CallbackServer.start = start
+
+
+class StreamingContext(object):
+ """
+ Main entry point for Spark Streaming functionality. A StreamingContext
+    represents the connection to a Spark cluster, and can be used to create
+    L{DStream}s from various input sources. It can be created from an existing
+    L{SparkContext}. After creating and transforming DStreams, the streaming
+    computation can be started and stopped using `context.start()` and
+    `context.stop()`, respectively. `context.awaitTermination()` allows the
+    current thread to wait for the termination of the context by `stop()` or
+    by an exception.
+ """
+ _transformerSerializer = None
+
+ def __init__(self, sparkContext, batchDuration=None, jssc=None):
+ """
+ Create a new StreamingContext.
+
+ @param sparkContext: L{SparkContext} object.
+ @param batchDuration: the time interval (in seconds) at which streaming
+ data will be divided into batches
+ """
+
+ self._sc = sparkContext
+ self._jvm = self._sc._jvm
+ self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
+
+ def _initialize_context(self, sc, duration):
+ self._ensure_initialized()
+ return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+
+ def _jduration(self, seconds):
+ """
+ Create Duration object given number of seconds
+ """
+ return self._jvm.Duration(int(seconds * 1000))
+
+ @classmethod
+ def _ensure_initialized(cls):
+ SparkContext._ensure_initialized()
+ gw = SparkContext._gateway
+
+ java_import(gw.jvm, "org.apache.spark.streaming.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
+
+ # start callback server
+ # getattr will fallback to JVM, so we cannot test by hasattr()
+ if "_callback_server" not in gw.__dict__:
+ _daemonize_callback_server()
+ # use random port
+ gw._start_callback_server(0)
+ # gateway with real port
+ gw._python_proxy_port = gw._callback_server.port
+ # get the GatewayServer object in JVM by ID
+ jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+ # update the port of CallbackClient with real port
+ gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
+
+ # register serializer for TransformFunction
+ # it happens before creating SparkContext when loading from checkpointing
+ cls._transformerSerializer = TransformFunctionSerializer(
+ SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+
+ @classmethod
+ def getOrCreate(cls, checkpointPath, setupFunc):
+ """
+ Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+        If checkpoint data exists in the provided `checkpointPath`, then the StreamingContext will
+        be recreated from the checkpoint data. If the data does not exist, then the provided
+        setupFunc will be used to create a new StreamingContext.
+
+        @param checkpointPath: Checkpoint directory used in an earlier streaming program
+        @param setupFunc: Function to create a new StreamingContext and set up DStreams
+ """
+ # TODO: support checkpoint in HDFS
+ if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
+ ssc = setupFunc()
+ ssc.checkpoint(checkpointPath)
+ return ssc
+
+ cls._ensure_initialized()
+ gw = SparkContext._gateway
+
+ try:
+ jssc = gw.jvm.JavaStreamingContext(checkpointPath)
+ except Exception:
+ print >>sys.stderr, "failed to load StreamingContext from checkpoint"
+ raise
+
+ jsc = jssc.sparkContext()
+ conf = SparkConf(_jconf=jsc.getConf())
+ sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
+ # update ctx in serializer
+ SparkContext._active_spark_context = sc
+ cls._transformerSerializer.ctx = sc
+ return StreamingContext(sc, None, jssc)
+
+ @property
+ def sparkContext(self):
+ """
+ Return SparkContext which is associated with this StreamingContext.
+ """
+ return self._sc
+
+ def start(self):
+ """
+ Start the execution of the streams.
+ """
+ self._jssc.start()
+
+ def awaitTermination(self, timeout=None):
+ """
+ Wait for the execution to stop.
+ @param timeout: time to wait in seconds
+ """
+ if timeout is None:
+ self._jssc.awaitTermination()
+ else:
+ self._jssc.awaitTermination(int(timeout * 1000))
+
+ def stop(self, stopSparkContext=True, stopGraceFully=False):
+ """
+ Stop the execution of the streams, with option of ensuring all
+ received data has been processed.
+
+ @param stopSparkContext: Stop the associated SparkContext or not
+        @param stopGraceFully: Stop gracefully by waiting for the processing
+ of all received data to be completed
+ """
+ self._jssc.stop(stopSparkContext, stopGraceFully)
+ if stopSparkContext:
+ self._sc.stop()
+
+ def remember(self, duration):
+ """
+        Set each DStream in this context to remember the RDDs it generated
+        in the last given duration. DStreams remember RDDs only for a
+        limited duration of time and release them for garbage collection.
+        This method allows the developer to specify how long to remember
+        the RDDs (if the developer wishes to query old data outside the
+ DStream computation).
+
+ @param duration: Minimum duration (in seconds) that each DStream
+ should remember its RDDs
+ """
+ self._jssc.remember(self._jduration(duration))
+
+ def checkpoint(self, directory):
+ """
+ Sets the context to periodically checkpoint the DStream operations for master
+ fault-tolerance. The graph will be checkpointed every batch interval.
+
+ @param directory: HDFS-compatible directory where the checkpoint data
+ will be reliably stored
+ """
+ self._jssc.checkpoint(directory)
+
+ def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
+ """
+        Create an input stream from a TCP source hostname:port. Data is received
+        using a TCP socket and the received bytes are interpreted as UTF8 encoded ``\\n`` delimited
+ lines.
+
+ @param hostname: Hostname to connect to for receiving data
+ @param port: Port to connect to for receiving data
+ @param storageLevel: Storage level to use for storing the received objects
+ """
+ jlevel = self._sc._getJavaStorageLevel(storageLevel)
+ return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+ UTF8Deserializer())
+
+ def textFileStream(self, directory):
+ """
+ Create an input stream that monitors a Hadoop-compatible file system
+        for new files and reads them as text files. Files must be written to the
+ monitored directory by "moving" them from another location within the same
+ file system. File names starting with . are ignored.
+ """
+ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
+
+ def _check_serializers(self, rdds):
+ # make sure they have same serializer
+ if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+ for i in range(len(rdds)):
+ # reset them to sc.serializer
+ rdds[i] = rdds[i]._reserialize()
+
+ def queueStream(self, rdds, oneAtATime=True, default=None):
+ """
+        Create an input stream from a queue of RDDs or lists. In each batch,
+ it will process either one or all of the RDDs returned by the queue.
+
+ NOTE: changes to the queue after the stream is created will not be recognized.
+
+ @param rdds: Queue of RDDs
+        @param oneAtATime: pick one RDD each time, or pick all of them at once
+        @param default: The default RDD to process if no more RDDs are left in the queue
+ """
+ if default and not isinstance(default, RDD):
+ default = self._sc.parallelize(default)
+
+ if not rdds and default:
+ rdds = [rdds]
+
+ if rdds and not isinstance(rdds[0], RDD):
+ rdds = [self._sc.parallelize(input) for input in rdds]
+ self._check_serializers(rdds)
+
+ jrdds = ListConverter().convert([r._jrdd for r in rdds],
+ SparkContext._gateway._gateway_client)
+ queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+ if default:
+ default = default._reserialize(rdds[0]._jrdd_deserializer)
+ jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
+ else:
+ jdstream = self._jssc.queueStream(queue, oneAtATime)
+ return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+
+ def transform(self, dstreams, transformFunc):
+ """
+ Create a new DStream in which each RDD is generated by applying
+ a function on the RDDs of the DStreams. The order of the RDDs passed
+ to the transform function will be the same as the order of the
+ corresponding DStreams in the list.
+ """
+ jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
+ SparkContext._gateway._gateway_client)
+ # change the final serializer to sc.serializer
+ func = TransformFunction(self._sc,
+ lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+ *[d._jrdd_deserializer for d in dstreams])
+ jfunc = self._jvm.TransformFunction(func)
+ jdstream = self._jssc.transform(jdstreams, jfunc)
+ return DStream(jdstream, self, self._sc.serializer)
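+
+ # Example (illustrative sketch; `d1` and `d2` are assumed to be DStreams of
+ # key-value pairs): join the per-batch RDDs of two DStreams:
+ #
+ # joined = ssc.transform([d1, d2], lambda rdds: rdds[0].join(rdds[1]))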
+
+ def union(self, *dstreams):
+ """
+ Create a unified DStream from multiple DStreams of the same
+ type and same slide duration.
+ """
+ if not dstreams:
+ raise ValueError("should have at least one DStream to union")
+ if len(dstreams) == 1:
+ return dstreams[0]
+ if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+ raise ValueError("All DStreams should have same serializer")
+ if len(set(s._slideDuration for s in dstreams)) > 1:
+ raise ValueError("All DStreams should have same slide duration")
+ first = dstreams[0]
+ jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+ SparkContext._gateway._gateway_client)
+ return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
new file mode 100644
index 0000000000000..0826ddc56e844
--- /dev/null
+++ b/python/pyspark/streaming/dstream.py
@@ -0,0 +1,623 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from itertools import chain, ifilter, imap
+import operator
+import time
+from datetime import datetime
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import RDD
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.util import rddToFileName, TransformFunction
+from pyspark.rdd import portable_hash
+from pyspark.resultiterable import ResultIterable
+
+__all__ = ["DStream"]
+
+
+class DStream(object):
+ """
+ A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
+ is a continuous sequence of RDDs (of the same type) representing a
+ continuous stream of data (see L{RDD} in the Spark core documentation
+ for more details on RDDs).
+
+ DStreams can either be created from live data (such as data from TCP
+ sockets, Kafka, Flume, etc.) using a L{StreamingContext}, or they can be
+ generated by transforming existing DStreams using operations such as
+ `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming
+ program is running, each DStream periodically generates an RDD, either
+ from live data or by transforming the RDD generated by a parent DStream.
+
+ Internally, a DStream is characterized by a few basic properties:
+ - A list of other DStreams that the DStream depends on
+ - A time interval at which the DStream generates an RDD
+ - A function that is used to generate an RDD after each time interval
+ """
+ def __init__(self, jdstream, ssc, jrdd_deserializer):
+ self._jdstream = jdstream
+ self._ssc = ssc
+ self._sc = ssc._sc
+ self._jrdd_deserializer = jrdd_deserializer
+ self.is_cached = False
+ self.is_checkpointed = False
+
+ def context(self):
+ """
+ Return the StreamingContext associated with this DStream
+ """
+ return self._ssc
+
+ def count(self):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by counting each RDD of this DStream.
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add)
+
+ def filter(self, f):
+ """
+ Return a new DStream containing only the elements that satisfy the predicate.
+ """
+ def func(iterator):
+ return ifilter(f, iterator)
+ return self.mapPartitions(func, True)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream by applying a function to all elements of
+ this DStream, and then flattening the results
+ """
+ def func(s, iterator):
+ return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream by applying a function to each element of DStream.
+ """
+ def func(iterator):
+ return imap(f, iterator)
+ return self.mapPartitions(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+ mapPartitions() to each RDD of this DStream.
+ """
+ def func(s, iterator):
+ return f(iterator)
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+ mapPartitionsWithIndex() to each RDD of this DStream.
+ """
+ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning))
+
+ def reduce(self, func):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by reducing each RDD of this DStream.
+ """
+ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1])
+
+ def reduceByKey(self, func, numPartitions=None):
+ """
+ Return a new DStream by applying reduceByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.combineByKey(lambda x: x, func, func, numPartitions)
+
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numPartitions=None):
+ """
+ Return a new DStream by applying combineByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def func(rdd):
+ return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions)
+ return self.transform(func)
+
+ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+ """
+ Return a copy of the DStream in which each RDD is partitioned
+ using the specified partitioner.
+ """
+ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
+
+ def foreachRDD(self, func):
+ """
+ Apply a function to each RDD in this DStream.
+
+ `func` can take one argument of `rdd`, or two arguments of (`time`, `rdd`)
+ """
+ if func.func_code.co_argcount == 1:
+ old_func = func
+ func = lambda t, rdd: old_func(rdd)
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
+ api = self._ssc._jvm.PythonDStream
+ api.callForeachRDD(self._jdstream, jfunc)
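+
+ # Example (illustrative sketch; `save_record` is a hypothetical sink
+ # function): push every batch to external storage:
+ #
+ # def save_batch(time, rdd):
+ # for record in rdd.collect():
+ # save_record(record)
+ # dstream.foreachRDD(save_batch)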
+
+ def pprint(self):
+ """
+ Print the first ten elements of each RDD generated in this DStream.
+ """
+ def takeAndPrint(time, rdd):
+ taken = rdd.take(11)
+ print "-------------------------------------------"
+ print "Time: %s" % time
+ print "-------------------------------------------"
+ for record in taken[:10]:
+ print record
+ if len(taken) > 10:
+ print "..."
+ print
+
+ self.foreachRDD(takeAndPrint)
+
+ def mapValues(self, f):
+ """
+ Return a new DStream by applying a map function to the value of
+ each key-value pair in this DStream without changing the key.
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ def flatMapValues(self, f):
+ """
+ Return a new DStream by applying a flatmap function to the value
+ of each key-value pair in this DStream without changing the key.
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def glom(self):
+ """
+ Return a new DStream in which each RDD is generated by applying
+ glom() to each RDD of this DStream.
+ """
+ def func(iterator):
+ yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cache(self):
+ """
+ Persist the RDDs of this DStream with the default storage level
+ (C{MEMORY_ONLY_SER}).
+ """
+ self.is_cached = True
+ self.persist(StorageLevel.MEMORY_ONLY_SER)
+ return self
+
+ def persist(self, storageLevel):
+ """
+ Persist the RDDs of this DStream with the given storage level
+ """
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdstream.persist(javaStorageLevel)
+ return self
+
+ def checkpoint(self, interval):
+ """
+ Enable periodic checkpointing of RDDs of this DStream
+
+ @param interval: time in seconds, after which the generated RDDs
+ will be checkpointed
+ """
+ self.is_checkpointed = True
+ self._jdstream.checkpoint(self._ssc._jduration(interval))
+ return self
+
+ def groupByKey(self, numPartitions=None):
+ """
+ Return a new DStream by applying groupByKey on each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
+
+ def countByValue(self):
+ """
+ Return a new DStream in which each RDD contains the count of
+ distinct elements in the corresponding RDD of this DStream.
+ """
+ return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count()
+
+ def saveAsTextFiles(self, prefix, suffix=None):
+ """
+ Save each RDD in this DStream as a text file, using the string
+ representation of its elements.
+ """
+ def saveAsTextFile(t, rdd):
+ path = rddToFileName(prefix, suffix, t)
+ try:
+ rdd.saveAsTextFile(path)
+ except Py4JJavaError as e:
+ # after recovered from checkpointing, the foreachRDD may
+ # be called twice
+ if 'FileAlreadyExistsException' not in str(e):
+ raise
+ return self.foreachRDD(saveAsTextFile)
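+
+ # Example (illustrative sketch; "/tmp/out" is a hypothetical prefix): each
+ # batch is written to a directory named /tmp/out-<timestamp>.txt:
+ #
+ # dstream.saveAsTextFiles("/tmp/out", "txt")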
+
+ # TODO: uncomment this once we have ssc.pickleFileStream()
+ # def saveAsPickleFiles(self, prefix, suffix=None):
+ # """
+ # Save each RDD in this DStream as a binary file in which the elements
+ # are serialized by pickle.
+ # """
+ # def saveAsPickleFile(t, rdd):
+ # path = rddToFileName(prefix, suffix, t)
+ # try:
+ # rdd.saveAsPickleFile(path)
+ # except Py4JJavaError as e:
+ # # after recovered from checkpointing, the foreachRDD may
+ # # be called twice
+ # if 'FileAlreadyExistsException' not in str(e):
+ # raise
+ # return self.foreachRDD(saveAsPickleFile)
+
+ def transform(self, func):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream.
+
+ `func` can have one argument of `rdd`, or have two arguments of
+ (`time`, `rdd`)
+ """
+ if func.func_code.co_argcount == 1:
+ oldfunc = func
+ func = lambda t, rdd: oldfunc(rdd)
+ assert func.func_code.co_argcount == 2, "func should take one or two arguments"
+ return TransformedDStream(self, func)
+
+ def transformWith(self, func, other, keepSerializer=False):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream and 'other' DStream.
+
+ `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three
+ arguments of (`time`, `rdd_a`, `rdd_b`)
+ """
+ if func.func_code.co_argcount == 2:
+ oldfunc = func
+ func = lambda t, a, b: oldfunc(a, b)
+ assert func.func_code.co_argcount == 3, "func should take two or three arguments"
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
+ other._jdstream.dstream(), jfunc)
+ jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer
+ return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
+
+ def repartition(self, numPartitions):
+ """
+ Return a new DStream with an increased or decreased level of parallelism.
+ """
+ return self.transform(lambda rdd: rdd.repartition(numPartitions))
+
+ @property
+ def _slideDuration(self):
+ """
+ Return the slideDuration in seconds of this DStream
+ """
+ return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
+
+ def union(self, other):
+ """
+ Return a new DStream by unifying data of another DStream with this DStream.
+
+ @param other: Another DStream having the same interval (i.e., slideDuration)
+ as this DStream.
+ """
+ if self._slideDuration != other._slideDuration:
+ raise ValueError("the two DStream should have same slide duration")
+ return self.transformWith(lambda a, b: a.union(b), other, True)
+
+ def cogroup(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'cogroup' between RDDs of this
+ DStream and `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other)
+
+ def join(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
+
+ def leftOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'left outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
+
+ def rightOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'right outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
+
+ def fullOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'full outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
+
+ def _jtime(self, timestamp):
+ """ Convert datetime or unix_timestamp into Time
+ """
+ if isinstance(timestamp, datetime):
+ timestamp = time.mktime(timestamp.timetuple())
+ return self._sc._jvm.Time(long(timestamp * 1000))
+
+ def slice(self, begin, end):
+ """
+ Return all the RDDs between 'begin' and 'end' (both included)
+
+ `begin`, `end` could be datetime.datetime() or unix_timestamp
+ """
+ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
+ return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
+
+ def _validate_window_param(self, window, slide):
+ duration = self._jdstream.dstream().slideDuration().milliseconds()
+ if int(window * 1000) % duration != 0:
+ raise ValueError("windowDuration must be multiple of the slide duration (%d ms)"
+ % duration)
+ if slide and int(slide * 1000) % duration != 0:
+ raise ValueError("slideDuration must be multiple of the slide duration (%d ms)"
+ % duration)
+
+ def window(self, windowDuration, slideDuration=None):
+ """
+ Return a new DStream in which each RDD contains all the elements seen in a
+ sliding window of time over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ d = self._ssc._jduration(windowDuration)
+ if slideDuration is None:
+ return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
+ s = self._ssc._jduration(slideDuration)
+ return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
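+
+ # Example (illustrative sketch; assumes a batch interval that divides both
+ # durations): count the elements seen in the last 30 seconds, every 10 seconds:
+ #
+ # dstream.window(30, 10).count().pprint()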
+
+ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated by reducing all
+ elements in a sliding window over this DStream.
+
+ If `invReduceFunc` is not None, the reduction is done incrementally
+ using the old window's reduced value:
+
+ 1. reduce the new values that entered the window (e.g., adding new counts)
+
+ 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+
+ This is more efficient than recomputing the whole window when
+ `invReduceFunc` is None.
+
+ @param reduceFunc: associative reduce function
+ @param invReduceFunc: inverse reduce function of `reduceFunc`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ keyed = self.map(lambda x: (1, x))
+ reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
+ windowDuration, slideDuration, 1)
+ return reduced.map(lambda (k, v): v)
+
+ def countByWindow(self, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated
+ by counting the number of elements in a window over this DStream.
+ windowDuration and slideDuration are as defined in the window() operation.
+
+ This is equivalent to window(windowDuration, slideDuration).count(),
+ but will be more efficient if the window is large.
+ """
+ return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
+ windowDuration, slideDuration)
+
+ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream in which each RDD contains the count of distinct elements in
+ RDDs in a sliding window over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ """
+ keyed = self.map(lambda x: (x, 1))
+ counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
+ windowDuration, slideDuration, numPartitions)
+ return counted.filter(lambda (k, v): v > 0).count()
+
+ def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream by applying `groupByKey` over a sliding window.
+ Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: Number of partitions of each RDD in the new DStream.
+ """
+ ls = self.mapValues(lambda x: [x])
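+ # the values of each key are kept as a list; the reduce function appends the
+ # values entering the window, and the inverse function drops the first
+ # len(b) values, i.e. those from the oldest batch that just left the window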
+ grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+ windowDuration, slideDuration, numPartitions)
+ return grouped.mapValues(ResultIterable)
+
+ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
+ numPartitions=None, filterFunc=None):
+ """
+ Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+
+ The reduced value over a new window is calculated using the old window's reduced value:
+ 1. reduce the new values that entered the window (e.g., adding new counts)
+ 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+
+ `invFunc` can be None; in that case all the RDDs in the window are reduced from
+ scratch, which can be slower than providing `invFunc`.
+
+ @param func: associative reduce function
+ @param invFunc: inverse function of `func`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ @param filterFunc: function to filter expired key-value pairs;
+ only pairs that satisfy the function are retained;
+ set this to None if you do not want to filter
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ reduced = self.reduceByKey(func, numPartitions)
+
+ def reduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ r = a.union(b).reduceByKey(func, numPartitions) if a else b
+ if filterFunc:
+ r = r.filter(filterFunc)
+ return r
+
+ def invReduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ joined = a.leftOuterJoin(b, numPartitions)
+ return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
+ if invReduceFunc:
+ jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
+ else:
+ jinvReduceFunc = None
+ if slideDuration is None:
+ slideDuration = self._slideDuration
+ dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+ jreduceFunc, jinvReduceFunc,
+ self._ssc._jduration(windowDuration),
+ self._ssc._jduration(slideDuration))
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
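+
+ # Example (illustrative sketch; `pairs` is assumed to be a DStream of
+ # (word, 1) tuples and the batch interval to divide both durations):
+ # incremental word counts over a 30s window sliding every 10s:
+ #
+ # counts = pairs.reduceByKeyAndWindow(lambda a, b: a + b,
+ # lambda a, b: a - b, 30, 10)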
+
+ def updateStateByKey(self, updateFunc, numPartitions=None):
+ """
+ Return a new "state" DStream where the state for each key is updated by applying
+ the given function on the previous state of the key and the new values of the key.
+
+ @param updateFunc: State update function. If this function returns None, then the
+ corresponding state key-value pair will be eliminated.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def reduceFunc(t, a, b):
+ if a is None:
+ g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
+ else:
+ g = a.cogroup(b, numPartitions)
+ g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
+ state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
+ return state.filter(lambda (k, v): v is not None)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc,
+ self._sc.serializer, self._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
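+
+ # Example (illustrative sketch; `pairs` is assumed to be a DStream of
+ # (word, 1) tuples and checkpointing to be enabled on the context):
+ # keep a running count per word across batches:
+ #
+ # def update(new_values, last_sum):
+ # return sum(new_values) + (last_sum or 0)
+ # running_counts = pairs.updateStateByKey(update)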
+
+
+class TransformedDStream(DStream):
+ """
+ TransformedDStream is a DStream generated by a Python function that
+ transforms each RDD of a DStream into another RDD.
+
+ Multiple consecutive transformations of a DStream can be combined into
+ one transformation.
+ """
+ def __init__(self, prev, func):
+ self._ssc = prev._ssc
+ self._sc = self._ssc._sc
+ self._jrdd_deserializer = self._sc.serializer
+ self.is_cached = False
+ self.is_checkpointed = False
+ self._jdstream_val = None
+
+ if (isinstance(prev, TransformedDStream) and
+ not prev.is_cached and not prev.is_checkpointed):
+ prev_func = prev.func
+ self.func = lambda t, rdd: func(t, prev_func(t, rdd))
+ self.prev = prev.prev
+ else:
+ self.prev = prev
+ self.func = func
+
+ @property
+ def _jdstream(self):
+ if self._jdstream_val is not None:
+ return self._jdstream_val
+
+ jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+ self._jdstream_val = dstream.asJavaDStream()
+ return self._jdstream_val
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
new file mode 100644
index 0000000000000..a8d876d0fa3b3
--- /dev/null
+++ b/python/pyspark/streaming/tests.py
@@ -0,0 +1,545 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from itertools import chain
+import time
+import operator
+import unittest
+import tempfile
+
+from pyspark.context import SparkConf, SparkContext, RDD
+from pyspark.streaming.context import StreamingContext
+
+
+class PySparkStreamingTestCase(unittest.TestCase):
+
+ timeout = 10 # seconds
+ duration = 1
+
+ def setUp(self):
+ class_name = self.__class__.__name__
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ self.sc = SparkContext(appName=class_name, conf=conf)
+ self.sc.setCheckpointDir("/tmp")
+ # TODO: decrease duration to speed up tests
+ self.ssc = StreamingContext(self.sc, self.duration)
+
+ def tearDown(self):
+ self.ssc.stop()
+
+ def wait_for(self, result, n):
+ start_time = time.time()
+ while len(result) < n and time.time() - start_time < self.timeout:
+ time.sleep(0.01)
+ if len(result) < n:
+ print "timeout after", self.timeout
+
+ def _take(self, dstream, n):
+ """
+ Return the first `n` elements in the stream (will start and stop).
+ """
+ results = []
+
+ def take(_, rdd):
+ if rdd and len(results) < n:
+ results.extend(rdd.take(n - len(results)))
+
+ dstream.foreachRDD(take)
+
+ self.ssc.start()
+ self.wait_for(results, n)
+ return results
+
+ def _collect(self, dstream, n, block=True):
+ """
+ Collect each RDD into the returned list.
+
+ :return: list, which will have the collected items.
+ """
+ result = []
+
+ def get_output(_, rdd):
+ if rdd and len(result) < n:
+ r = rdd.collect()
+ if r:
+ result.append(r)
+
+ dstream.foreachRDD(get_output)
+
+ if not block:
+ return result
+
+ self.ssc.start()
+ self.wait_for(result, n)
+ return result
+
+ def _test_func(self, input, func, expected, sort=False, input2=None):
+ """
+ @param input: dataset for the test. This should be a list of lists.
+ @param func: wrapped function. This function should return a DStream.
+ @param expected: expected output for this testcase.
+ """
+ if not isinstance(input[0], RDD):
+ input = [self.sc.parallelize(d, 1) for d in input]
+ input_stream = self.ssc.queueStream(input)
+ if input2 and not isinstance(input2[0], RDD):
+ input2 = [self.sc.parallelize(d, 1) for d in input2]
+ input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
+ # Apply test function to stream.
+ if input2:
+ stream = func(input_stream, input_stream2)
+ else:
+ stream = func(input_stream)
+
+ result = self._collect(stream, len(expected))
+ if sort:
+ self._sort_result_based_on_key(result)
+ self._sort_result_based_on_key(expected)
+ self.assertEqual(expected, result)
+
+ def _sort_result_based_on_key(self, outputs):
+ """Sort the list based on first value."""
+ for output in outputs:
+ output.sort(key=lambda x: x[0])
+
+
+class BasicOperationTests(PySparkStreamingTestCase):
+
+ def test_map(self):
+ """Basic operation test for DStream.map."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.map(str)
+ expected = map(lambda x: map(str, x), input)
+ self._test_func(input, func, expected)
+
+ def test_flatMap(self):
+ """Basic operation test for DStream.faltMap."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.flatMap(lambda x: (x, x * 2))
+ expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
+ input)
+ self._test_func(input, func, expected)
+
+ def test_filter(self):
+ """Basic operation test for DStream.filter."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.filter(lambda x: x % 2 == 0)
+ expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
+ self._test_func(input, func, expected)
+
+ def test_count(self):
+ """Basic operation test for DStream.count."""
+ input = [range(5), range(10), range(20)]
+
+ def func(dstream):
+ return dstream.count()
+ expected = map(lambda x: [len(x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduce(self):
+ """Basic operation test for DStream.reduce."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.reduce(operator.add)
+ expected = map(lambda x: [reduce(operator.add, x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduceByKey(self):
+ """Basic operation test for DStream.reduceByKey."""
+ input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
+ [("", 1), ("", 1), ("", 1), ("", 1)],
+ [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
+
+ def func(dstream):
+ return dstream.reduceByKey(operator.add)
+ expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_mapValues(self):
+ """Basic operation test for DStream.mapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 2), (3, 3)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.mapValues(lambda x: x + 10)
+ expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
+ [("", 14), (1, 11), (2, 12), (3, 13)],
+ [(1, 11), (2, 11), (3, 11), (4, 11)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_flatMapValues(self):
+ """Basic operation test for DStream.flatMapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 1), (3, 1)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.flatMapValues(lambda x: (x, x + 10))
+ expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
+ ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
+ [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
+ [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
+ self._test_func(input, func, expected)
+
+ def test_glom(self):
+ """Basic operation test for DStream.glom."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.glom()
+ expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
+ self._test_func(rdds, func, expected)
+
+ def test_mapPartitions(self):
+ """Basic operation test for DStream.mapPartitions."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ def f(iterator):
+ yield sum(iterator)
+ return dstream.mapPartitions(f)
+ expected = [[3, 7], [11, 15], [19, 23]]
+ self._test_func(rdds, func, expected)
+
+ def test_countByValue(self):
+ """Basic operation test for DStream.countByValue."""
+ input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
+
+ def func(dstream):
+ return dstream.countByValue()
+ expected = [[4], [4], [3]]
+ self._test_func(input, func, expected)
+
+ def test_groupByKey(self):
+ """Basic operation test for DStream.groupByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ return dstream.groupByKey().mapValues(list)
+
+ expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
+ [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
+ [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_combineByKey(self):
+ """Basic operation test for DStream.combineByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ def add(a, b):
+ return a + str(b)
+ return dstream.combineByKey(str, add, add)
+ expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
+ [(1, "111"), (2, "11"), (3, "1")],
+ [("a", "11"), ("b", "1"), ("", "111")]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_repartition(self):
+ input = [range(1, 5), range(5, 9)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.repartition(1).glom()
+ expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+ self._test_func(rdds, func, expected)
+
+ def test_union(self):
+ input1 = [range(3), range(5), range(6)]
+ input2 = [range(3, 6), range(5, 6)]
+
+ def func(d1, d2):
+ return d1.union(d2)
+
+ expected = [range(6), range(6), range(6)]
+ self._test_func(input1, func, expected, input2=input2)
+
+ def test_cogroup(self):
+ input = [[(1, 1), (2, 1), (3, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+ input2 = [[(1, 2)],
+ [(4, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
+
+ def func(d1, d2):
+ return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
+
+ expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+ [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+ [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
+ self._test_func(input, func, expected, sort=True, input2=input2)
+
+ def test_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.join(b)
+
+ expected = [[('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_left_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.leftOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_right_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.rightOuterJoin(b)
+
+ expected = [[('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_full_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.fullOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_update_state_by_key(self):
+
+ def updater(vs, s):
+ if not s:
+ s = []
+ s.extend(vs)
+ return s
+
+ input = [[('k', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.updateStateByKey(updater)
+
+ expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+ expected = [[('k', v)] for v in expected]
+ self._test_func(input, func, expected)
+
+
+class WindowFunctionTests(PySparkStreamingTestCase):
+
+ timeout = 20
+
+ def test_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.window(3, 1).count()
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.countByWindow(3, 1)
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window_large(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByWindow(5, 1)
+
+ expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_value_and_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByValueAndWindow(5, 1)
+
+ expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
+ self._test_func(input, func, expected)
+
+ def test_group_by_key_and_window(self):
+ input = [[('a', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
+
+ expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
+ [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
+ self._test_func(input, func, expected)
+
+ def test_reduce_by_invalid_window(self):
+ input1 = [range(3), range(5), range(1), range(6)]
+ d1 = self.ssc.queueStream(input1)
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
+
+
+class StreamingContextTests(PySparkStreamingTestCase):
+
+ duration = 0.1
+
+ def _add_input_stream(self):
+ inputs = map(lambda x: range(1, x), range(101))
+ stream = self.ssc.queueStream(inputs)
+ self._collect(stream, 1, block=False)
+
+ def test_stop_only_streaming_context(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop(False)
+ self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+ def test_stop_multiple_times(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop()
+ self.ssc.stop()
+
+ def test_queue_stream(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ result = self._collect(dstream, 3)
+ self.assertEqual(input, result)
+
+ def test_text_file_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream2 = self.ssc.textFileStream(d).map(int)
+ result = self._collect(dstream2, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "w") as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+ self.wait_for(result, 2)
+ self.assertEqual([range(10), range(10)], result)
+
+ def test_union(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ dstream2 = self.ssc.queueStream(input)
+ dstream3 = self.ssc.union(dstream, dstream2)
+ result = self._collect(dstream3, 3)
+ expected = [i * 2 for i in input]
+ self.assertEqual(expected, result)
+
+ def test_transform(self):
+ dstream1 = self.ssc.queueStream([[1]])
+ dstream2 = self.ssc.queueStream([[2]])
+ dstream3 = self.ssc.queueStream([[3]])
+
+ def func(rdds):
+ rdd1, rdd2, rdd3 = rdds
+ return rdd2.union(rdd3).union(rdd1)
+
+ dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+ self.assertEqual([2, 3, 1], self._take(dstream, 3))
+
+
+class CheckpointTests(PySparkStreamingTestCase):
+
+ def setUp(self):
+ pass
+
+ def test_get_or_create(self):
+ inputd = tempfile.mkdtemp()
+ outputd = tempfile.mkdtemp() + "/"
+
+ def updater(vs, s):
+ return sum(vs, s or 0)
+
+ def setup():
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ sc = SparkContext(conf=conf)
+ ssc = StreamingContext(sc, 0.5)
+ dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
+ wc = dstream.updateStateByKey(updater)
+ wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
+ wc.checkpoint(.5)
+ return ssc
+
+ cpd = tempfile.mkdtemp("test_streaming_cps")
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+
+ def check_output(n):
+ while not os.listdir(outputd):
+ time.sleep(0.1)
+ time.sleep(1) # make sure mtime is larger than the previous one
+ with open(os.path.join(inputd, str(n)), 'w') as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+
+ while True:
+ p = os.path.join(outputd, max(os.listdir(outputd)))
+ if '_SUCCESS' not in os.listdir(p):
+ # not finished
+ time.sleep(0.01)
+ continue
+ ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
+ d = ordd.values().map(int).collect()
+ if not d:
+ time.sleep(0.01)
+ continue
+ self.assertEqual(10, len(d))
+ s = set(d)
+ self.assertEqual(1, len(s))
+ m = s.pop()
+ if n > m:
+ continue
+ self.assertEqual(n, m)
+ break
+
+ check_output(1)
+ check_output(2)
+ ssc.stop(True, True)
+
+ time.sleep(1)
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+ check_output(3)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
new file mode 100644
index 0000000000000..86ee5aa04f252
--- /dev/null
+++ b/python/pyspark/streaming/util.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from datetime import datetime
+import traceback
+
+from pyspark import SparkContext, RDD
+
+
+class TransformFunction(object):
+ """
+ This class wraps a function RDD[X] -> RDD[Y] that was passed to
+ DStream.transform(), allowing it to be called from Java via Py4J's
+ callback server.
+
+ Java calls this function with a sequence of JavaRDDs and this function
+ returns a single JavaRDD pointer back to Java.
+ """
+ _emptyRDD = None
+
+ def __init__(self, ctx, func, *deserializers):
+ self.ctx = ctx
+ self.func = func
+ self.deserializers = deserializers
+
+ def call(self, milliseconds, jrdds):
+ try:
+ if self.ctx is None:
+ self.ctx = SparkContext._active_spark_context
+ if not self.ctx or not self.ctx._jsc:
+ # stopped
+ return
+
+ # extend deserializers with the first one
+ sers = self.deserializers
+ if len(sers) < len(jrdds):
+ sers += (sers[0],) * (len(jrdds) - len(sers))
+
+ rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+ for jrdd, ser in zip(jrdds, sers)]
+ t = datetime.fromtimestamp(milliseconds / 1000.0)
+ r = self.func(t, *rdds)
+ if r:
+ return r._jrdd
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunction(%s)" % self.func
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
+
+
+class TransformFunctionSerializer(object):
+ """
+ This class implements a serializer for PythonTransformFunction Java
+ objects.
+
+ This is necessary because the Java PythonTransformFunction objects are
+ actually Py4J references to Python objects and thus are not directly
+ serializable. When Java needs to serialize a PythonTransformFunction,
+ it uses this class to invoke Python, which returns the serialized function
+ as a byte array.
+ """
+ def __init__(self, ctx, serializer, gateway=None):
+ self.ctx = ctx
+ self.serializer = serializer
+ self.gateway = gateway or self.ctx._gateway
+ self.gateway.jvm.PythonDStream.registerSerializer(self)
+
+ def dumps(self, id):
+ try:
+ func = self.gateway.gateway_property.pool[id]
+ return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+ except Exception:
+ traceback.print_exc()
+
+ def loads(self, bytes):
+ try:
+ f, deserializers = self.serializer.loads(str(bytes))
+ return TransformFunction(self.ctx, f, *deserializers)
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunctionSerializer(%s)" % self.serializer
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
+
+
+def rddToFileName(prefix, suffix, timestamp):
+ """
+ Return string prefix-time(.suffix)
+
+ >>> rddToFileName("spark", None, 12345678910)
+ 'spark-12345678910'
+ >>> rddToFileName("spark", "tmp", 12345678910)
+ 'spark-12345678910.tmp'
+ """
+ if isinstance(timestamp, datetime):
+ seconds = time.mktime(timestamp.timetuple())
+ timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
+ if suffix is None:
+ return prefix + "-" + str(timestamp)
+ else:
+ return prefix + "-" + str(timestamp) + "." + suffix
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29df754c6fd29..f5ccf31abb3fa 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,7 +34,11 @@
from platform import python_implementation
if sys.version_info[:2] <= (2, 6):
- import unittest2 as unittest
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
else:
import unittest
@@ -67,10 +71,10 @@
SPARK_HOME = os.environ["SPARK_HOME"]
-class TestMerger(unittest.TestCase):
+class MergerTests(unittest.TestCase):
def setUp(self):
- self.N = 1 << 16
+ self.N = 1 << 14
self.l = [i for i in xrange(self.N)]
self.data = zip(self.l, self.l)
self.agg = Aggregator(lambda x: [x],
@@ -115,7 +119,7 @@ def test_medium_dataset(self):
sum(xrange(self.N)) * 3)
def test_huge_dataset(self):
- m = ExternalMerger(self.agg, 10)
+ m = ExternalMerger(self.agg, 10, partitions=3)
m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@@ -123,7 +127,7 @@ def test_huge_dataset(self):
m._cleanup()
-class TestSorter(unittest.TestCase):
+class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
@@ -244,16 +248,25 @@ def tearDown(self):
sys.path = self._old_sys_path
-class TestCheckpoint(PySparkTestCase):
+class ReusedPySparkTestCase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+
+class CheckpointTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.checkpointDir.name)
self.sc.setCheckpointDir(self.checkpointDir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.checkpointDir.name)
def test_basic_checkpointing(self):
@@ -288,7 +301,7 @@ def test_checkpoint_and_restore(self):
self.assertEquals([1, 2, 3, 4], recovered.collect())
-class TestAddFile(PySparkTestCase):
+class AddFileTests(PySparkTestCase):
def test_add_py_file(self):
# To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +367,7 @@ def func(x):
self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
-class TestRDDFunctions(PySparkTestCase):
+class RDDTests(ReusedPySparkTestCase):
def test_id(self):
rdd = self.sc.parallelize(range(10))
@@ -365,12 +378,6 @@ def test_id(self):
self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id())
- def test_failed_sparkcontext_creation(self):
- # Regression test for SPARK-1550
- self.sc.stop()
- self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
- self.sc = SparkContext("local")
-
def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970
x = u"\u00A1Hola, mundo!"
@@ -467,8 +474,12 @@ def test_large_broadcast(self):
def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]
- m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
- self.assertEquals(N, m)
+ rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+ self.assertEquals(N, rdd.first())
+ self.assertTrue(rdd._broadcast is not None)
+ rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+ self.assertEqual(1, rdd.first())
+ self.assertTrue(rdd._broadcast is None)
def test_zip_with_different_serializers(self):
a = self.sc.parallelize(range(5))
@@ -632,10 +643,39 @@ def test_distinct(self):
self.assertEquals(result.count(), 3)
-class TestSQL(PySparkTestCase):
+class ProfilerTests(PySparkTestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ conf = SparkConf().set("spark.python.profile", "true")
+ self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)
+
+ def test_profiler(self):
+
+ def heavy_foo(x):
+ for i in range(1 << 20):
+ x = 1
+ rdd = self.sc.parallelize(range(100))
+ rdd.foreach(heavy_foo)
+ profiles = self.sc._profile_stats
+ self.assertEqual(1, len(profiles))
+ id, acc, _ = profiles[0]
+ stats = acc.value
+ self.assertTrue(stats is not None)
+ width, stat_list = stats.get_print_list([])
+ func_names = [func_name for fname, n, func_name in stat_list]
+ self.assertTrue("heavy_foo" in func_names)
+
+ self.sc.show_profiles()
+ d = tempfile.gettempdir()
+ self.sc.dump_profiles(d)
+ self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+
+
+class SQLTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.sqlCtx = SQLContext(self.sc)
def test_udf(self):
@@ -643,6 +683,12 @@ def test_udf(self):
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)
+ def test_udf2(self):
+ self.sqlCtx.registerFunction("strlen", lambda string: len(string))
+ self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+ self.assertEqual(u"4", res[0])
+
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
@@ -720,27 +766,19 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual("2", row.d)
-class TestIO(PySparkTestCase):
+class InputFormatTests(ReusedPySparkTestCase):
- def test_stdout_redirection(self):
- import subprocess
-
- def func(x):
- subprocess.check_call('ls', shell=True)
- self.sc.parallelize([1]).foreach(func)
-
-
-class TestInputFormat(PySparkTestCase):
-
- def setUp(self):
- PySparkTestCase.setUp(self)
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
- self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
- def tearDown(self):
- PySparkTestCase.tearDown(self)
- shutil.rmtree(self.tempdir.name)
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
def test_sequencefiles(self):
basepath = self.tempdir.name
@@ -920,15 +958,13 @@ def test_converters(self):
self.assertEqual(maps, em)
-class TestOutputFormat(PySparkTestCase):
+class OutputFormatTests(ReusedPySparkTestCase):
def setUp(self):
- PySparkTestCase.setUp(self)
self.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(self.tempdir.name)
def tearDown(self):
- PySparkTestCase.tearDown(self)
shutil.rmtree(self.tempdir.name, ignore_errors=True)
def test_sequencefiles(self):
@@ -1209,8 +1245,7 @@ def test_malformed_RDD(self):
basepath + "/malformed/sequence"))
-class TestDaemon(unittest.TestCase):
-
+class DaemonTests(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
@@ -1256,7 +1291,7 @@ def test_termination_sigterm(self):
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
-class TestWorker(PySparkTestCase):
+class WorkerTests(PySparkTestCase):
def test_cancel_task(self):
temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1308,11 +1343,6 @@ def run():
rdd = self.sc.parallelize(range(100), 1)
self.assertEqual(100, rdd.map(str).count())
- def test_fd_leak(self):
- N = 1100 # fd limit is 1024 by default
- rdd = self.sc.parallelize(range(N), N)
- self.assertEquals(N, rdd.count())
-
def test_after_exception(self):
def raise_exception(_):
raise Exception()
@@ -1345,7 +1375,7 @@ def test_accumulator_when_reuse_worker(self):
self.assertEqual(sum(range(100)), acc1.value)
-class TestSparkSubmit(unittest.TestCase):
+class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
@@ -1458,6 +1488,8 @@ def test_single_script_on_cluster(self):
|sc = SparkContext()
|print sc.parallelize([1, 2, 3]).map(foo).collect()
""")
+ # this will fail if you have different spark.executor.memory
+ # in conf/spark-defaults.conf
proc = subprocess.Popen(
[self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
stdout=subprocess.PIPE)
@@ -1466,7 +1498,11 @@ def test_single_script_on_cluster(self):
self.assertIn("[2, 4, 6]", out)
-class ContextStopTests(unittest.TestCase):
+class ContextTests(unittest.TestCase):
+
+ def test_failed_sparkcontext_creation(self):
+ # Regression test for SPARK-1550
+ self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
def test_stop(self):
sc = SparkContext()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c1f6e3e4a1f40..8257dddfee1c3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,6 +23,8 @@
import time
import socket
import traceback
+import cProfile
+import pstats
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,10 +92,21 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
- (func, deserializer, serializer) = command
+ (func, stats, deserializer, serializer) = command
init_time = time.time()
- iterator = deserializer.load_stream(infile)
- serializer.dump_stream(func(split_index, iterator), outfile)
+
+ def process():
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
+
+ if stats:
+ p = cProfile.Profile()
+ p.runcall(process)
+ st = pstats.Stats(p)
+ st.stream = None # make it picklable
+ stats.add(st.strip_dirs())
+ else:
+ process()
except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
diff --git a/python/run-tests b/python/run-tests
index a7ec270c7da21..80acd002ab7eb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
cd "$FWDIR/python"
FAILED=0
+LOG_FILE=unit-tests.log
-rm -f unit-tests.log
+rm -f $LOG_FILE
# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL
rm -rf metastore warehouse
function run_test() {
- echo "Running test: $1"
+ echo "Running test: $1" | tee -a $LOG_FILE
- SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+ SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE
FAILED=$((PIPESTATUS[0]||$FAILED))
@@ -48,6 +49,44 @@ function run_test() {
fi
}
+function run_core_tests() {
+ echo "Run core tests ..."
+ run_test "pyspark/rdd.py"
+ run_test "pyspark/context.py"
+ run_test "pyspark/conf.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
+ PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+ run_test "pyspark/shuffle.py"
+ run_test "pyspark/tests.py"
+}
+
+function run_sql_tests() {
+ echo "Run sql tests ..."
+ run_test "pyspark/sql.py"
+}
+
+function run_mllib_tests() {
+ echo "Run mllib tests ..."
+ run_test "pyspark/mllib/classification.py"
+ run_test "pyspark/mllib/clustering.py"
+ run_test "pyspark/mllib/feature.py"
+ run_test "pyspark/mllib/linalg.py"
+ run_test "pyspark/mllib/random.py"
+ run_test "pyspark/mllib/recommendation.py"
+ run_test "pyspark/mllib/regression.py"
+ run_test "pyspark/mllib/stat.py"
+ run_test "pyspark/mllib/tree.py"
+ run_test "pyspark/mllib/util.py"
+ run_test "pyspark/mllib/tests.py"
+}
+
+function run_streaming_tests() {
+ echo "Run streaming tests ..."
+ run_test "pyspark/streaming/util.py"
+ run_test "pyspark/streaming/tests.py"
+}
+
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@@ -60,29 +99,10 @@ fi
echo "Testing with Python version:"
$PYSPARK_PYTHON --version
-run_test "pyspark/rdd.py"
-run_test "pyspark/context.py"
-run_test "pyspark/conf.py"
-run_test "pyspark/sql.py"
-# These tests are included in the module-level docs, and so must
-# be handled on a higher level rather than within the python file.
-export PYSPARK_DOC_TEST=1
-run_test "pyspark/broadcast.py"
-run_test "pyspark/accumulators.py"
-run_test "pyspark/serializers.py"
-unset PYSPARK_DOC_TEST
-run_test "pyspark/shuffle.py"
-run_test "pyspark/tests.py"
-run_test "pyspark/mllib/classification.py"
-run_test "pyspark/mllib/clustering.py"
-run_test "pyspark/mllib/linalg.py"
-run_test "pyspark/mllib/random.py"
-run_test "pyspark/mllib/recommendation.py"
-run_test "pyspark/mllib/regression.py"
-run_test "pyspark/mllib/stat.py"
-run_test "pyspark/mllib/tests.py"
-run_test "pyspark/mllib/tree.py"
-run_test "pyspark/mllib/util.py"
+run_core_tests
+run_sql_tests
+run_mllib_tests
+run_streaming_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@@ -90,19 +110,9 @@ if [ $(which pypy) ]; then
echo "Testing with PyPy version:"
$PYSPARK_PYTHON --version
- run_test "pyspark/rdd.py"
- run_test "pyspark/context.py"
- run_test "pyspark/conf.py"
- run_test "pyspark/sql.py"
- # These tests are included in the module-level docs, and so must
- # be handled on a higher level rather than within the python file.
- export PYSPARK_DOC_TEST=1
- run_test "pyspark/broadcast.py"
- run_test "pyspark/accumulators.py"
- run_test "pyspark/serializers.py"
- unset PYSPARK_DOC_TEST
- run_test "pyspark/shuffle.py"
- run_test "pyspark/tests.py"
+ run_core_tests
+ run_sql_tests
+ run_streaming_tests
fi
if [[ $FAILED == 0 ]]; then
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 6ddb6accd696b..646c68e60c2e9 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -84,9 +84,11 @@ import org.apache.spark.util.Utils
* @author Moez A. Abdel-Gawad
* @author Lex Spoon
*/
- class SparkIMain(initialSettings: Settings, val out: JPrintWriter)
- extends SparkImports with Logging {
- imain =>
+ class SparkIMain(
+ initialSettings: Settings,
+ val out: JPrintWriter,
+ propagateExceptions: Boolean = false)
+ extends SparkImports with Logging { imain =>
val conf = new SparkConf()
@@ -816,6 +818,10 @@ import org.apache.spark.util.Utils
val resultName = FixedSessionNames.resultName
def bindError(t: Throwable) = {
+ // Immediately throw the exception if we are asked to propagate it
+ if (propagateExceptions) {
+ throw unwrap(t)
+ }
if (!bindExceptions) // avoid looping if already binding
throw t
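To illustrate the new `propagateExceptions` flag, here is a minimal, hypothetical embedding of the interpreter; the `usejavacp` setup and the thrown `RuntimeException` are assumptions for the sketch, not part of this change.

```scala
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.JPrintWriter

import org.apache.spark.repl.SparkIMain

object EmbeddedReplSketch {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true  // use the JVM classpath for the embedded compiler

    val imain = new SparkIMain(settings, new JPrintWriter(Console.out, true),
      propagateExceptions = true)
    try {
      // With propagateExceptions = true, bindError rethrows the unwrapped exception
      // instead of binding it into the session, so the caller sees it directly.
      imain.interpret("""throw new RuntimeException("boom")""")
    } catch {
      case e: RuntimeException => println(s"caught from the interpreter: ${e.getMessage}")
    }
  }
}
```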
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
index 9c4896e49698c..52098993f5c3c 100644
--- a/repl/src/test/resources/log4j.properties
+++ b/repl/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index 3e2ee7541f40d..6a79e76a34db8 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import com.google.common.io.Files
-
import org.apache.spark.{SparkConf, TestUtils}
import org.apache.spark.util.Utils
@@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
override def beforeAll() {
super.beforeAll()
- tempDir1 = Files.createTempDir()
- tempDir1.deleteOnExit()
- tempDir2 = Files.createTempDir()
- tempDir2.deleteOnExit()
+ tempDir1 = Utils.createTempDir()
+ tempDir2 = Utils.createTempDir()
url1 = "file://" + tempDir1
urls2 = List(tempDir2.toURI.toURL).toArray
childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1"))
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c8763eb277052..91c9c52c3c98a 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -22,7 +22,6 @@ import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.commons.lang3.StringEscapeUtils
@@ -190,8 +189,7 @@ class ReplSuite extends FunSuite {
}
test("interacting with files") {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh
index 2718d6cba1c9a..1d154e62ed5b6 100755
--- a/sbin/spark-config.sh
+++ b/sbin/spark-config.sh
@@ -33,7 +33,7 @@ this="$config_bin/$script"
export SPARK_PREFIX="`dirname "$this"`"/..
export SPARK_HOME="${SPARK_PREFIX}"
-export SPARK_CONF_DIR="$SPARK_HOME/conf"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
# Add the PySpark classes to the PYTHONPATH:
export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH"
export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh
index bd476b400e1c3..cba475e2dd8c8 100755
--- a/sbin/spark-daemon.sh
+++ b/sbin/spark-daemon.sh
@@ -62,7 +62,7 @@ then
shift
fi
-startStop=$1
+option=$1
shift
command=$1
shift
@@ -122,9 +122,9 @@ if [ "$SPARK_NICENESS" = "" ]; then
fi
-case $startStop in
+case $option in
- (start)
+ (start|spark-submit)
mkdir -p "$SPARK_PID_DIR"
@@ -142,8 +142,14 @@ case $startStop in
spark_rotate_log "$log"
echo starting $command, logging to $log
- cd "$SPARK_PREFIX"
- nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null &
+ if [ $option == spark-submit ]; then
+ source "$SPARK_HOME"/bin/utils.sh
+ gatherSparkSubmitOpts "$@"
+ nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-submit --class $command \
+ "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" >> "$log" 2>&1 < /dev/null &
+ else
+ nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/bin/spark-class $command "$@" >> "$log" 2>&1 < /dev/null &
+ fi
newpid=$!
echo $newpid > $pid
sleep 2
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh
index ba953e763faab..50e8e06418b07 100755
--- a/sbin/start-thriftserver.sh
+++ b/sbin/start-thriftserver.sh
@@ -27,7 +27,6 @@ set -o posix
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
-CLASS_NOT_FOUND_EXIT_STATUS=101
function usage {
echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]"
@@ -49,17 +48,6 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
exit 0
fi
-source "$FWDIR"/bin/utils.sh
-SUBMIT_USAGE_FUNCTION=usage
-gatherSparkSubmitOpts "$@"
+export SUBMIT_USAGE_FUNCTION=usage
-"$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
-exit_status=$?
-
-if [[ exit_status -eq CLASS_NOT_FOUND_EXIT_STATUS ]]; then
- echo
- echo "Failed to load Hive Thrift server main class $CLASS."
- echo "You need to build Spark with -Phive."
-fi
-
-exit $exit_status
+exec "$FWDIR"/sbin/spark-daemon.sh spark-submit $CLASS 1 "$@"
diff --git a/python/epydoc.conf b/sbin/stop-thriftserver.sh
old mode 100644
new mode 100755
similarity index 55%
rename from python/epydoc.conf
rename to sbin/stop-thriftserver.sh
index 8593e08deda19..4031a00d4a689
--- a/python/epydoc.conf
+++ b/sbin/stop-thriftserver.sh
@@ -1,4 +1,4 @@
-[epydoc] # Epydoc section marker (required by ConfigParser)
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
@@ -17,22 +17,9 @@
# limitations under the License.
#
-# Information about the project.
-name: Spark 1.0.0 Python API Docs
-url: http://spark.apache.org
+# Stops the thrift server on the machine this script is executed on.
-# The list of modules to document. Modules can be named using
-# dotted names, module filenames, or package directory names.
-# This option may be repeated.
-modules: pyspark
+sbin="`dirname "$0"`"
+sbin="`cd "$sbin"; pwd`"
-# Write html output to the directory "apidocs"
-output: html
-target: docs/
-
-private: no
-
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join
- pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
- pyspark.rddsampler pyspark.daemon
- pyspark.mllib.tests pyspark.shuffle
+"$sbin"/spark-daemon.sh stop org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 1
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index c54f8b72ebf42..0ff521706c71a 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -141,5 +141,5 @@
-
+
diff --git a/sql/README.md b/sql/README.md
index 31f9152344086..c84534da9a3d3 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -44,38 +44,37 @@ Type in expressions to have them evaluated.
Type :help for more information.
scala> val query = sql("SELECT * FROM (SELECT * FROM src) a")
-query: org.apache.spark.sql.ExecutedQuery =
-SELECT * FROM (SELECT * FROM src) a
-=== Query Plan ===
-Project [key#6:0.0,value#7:0.1]
- HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None
+query: org.apache.spark.sql.SchemaRDD =
+== Query Plan ==
+== Physical Plan ==
+HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None
```
Query results are RDDs and can be operated as such.
```
scala> query.collect()
-res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]...
+res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]...
```
You can also build further queries on top of these RDDs using the query DSL.
```
-scala> query.where('key === 100).toRdd.collect()
-res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100])
+scala> query.where('key === 100).collect()
+res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100])
```
-From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects.
+From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects.
```scala
-scala> query.logicalPlan
-res1: catalyst.plans.logical.LogicalPlan =
-Project {key#0,value#1}
- Project {key#0,value#1}
+scala> query.queryExecution.analyzed
+res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
+Project [key#10,value#11]
+ Project [key#10,value#11]
MetastoreRelation default, src, None
-scala> query.logicalPlan transform {
+scala> query.queryExecution.analyzed transform {
| case Project(projectList, child) if projectList == child.output => child
| }
-res2: catalyst.plans.logical.LogicalPlan =
-Project {key#0,value#1}
+res5: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
+Project [key#10,value#11]
MetastoreRelation default, src, None
```
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 88a8fa7c28e0f..3d4296f9d7068 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -33,7 +33,7 @@ object ScalaReflection {
/** Converts Scala objects to catalyst rows / types */
def convertToCatalyst(a: Any): Any = a match {
- case o: Option[_] => o.orNull
+ case o: Option[_] => o.map(convertToCatalyst).orNull
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
@@ -77,8 +77,9 @@ object ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+ case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+ case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
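A quick sketch of what the `convertToCatalyst` fix changes; the `Record` case class is made up for illustration.

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

object ConvertToCatalystSketch {
  case class Record(id: Int, name: String)

  def main(args: Array[String]): Unit = {
    // Previously Some(Record(1, "a")) was unwrapped to the raw Record; with
    // o.map(convertToCatalyst).orNull the nested value is converted to a catalyst
    // row, just like a bare Record argument.
    println(ScalaReflection.convertToCatalyst(Some(Record(1, "a"))))
    println(ScalaReflection.convertToCatalyst(None))  // still null
  }
}
```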
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
new file mode 100644
index 0000000000000..04467342e6ab5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.language.implicitConversions
+import scala.util.parsing.combinator.lexical.StdLexical
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.{PackratParsers, RegexParsers}
+import scala.util.parsing.input.CharArrayReader.EofCh
+
+import org.apache.spark.sql.catalyst.plans.logical._
+
+private[sql] abstract class AbstractSparkSQLParser
+ extends StandardTokenParsers with PackratParsers {
+
+ def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match {
+ case Success(plan, _) => plan
+ case failureOrError => sys.error(failureOrError.toString)
+ }
+
+ protected case class Keyword(str: String)
+
+ protected def start: Parser[LogicalPlan]
+
+ // Returns the whole input string
+ protected lazy val wholeInput: Parser[String] = new Parser[String] {
+ def apply(in: Input): ParseResult[String] =
+ Success(in.source.toString, in.drop(in.source.length()))
+ }
+
+ // Returns the rest of the input string that has not been parsed yet
+ protected lazy val restInput: Parser[String] = new Parser[String] {
+ def apply(in: Input): ParseResult[String] =
+ Success(
+ in.source.subSequence(in.offset, in.source.length()).toString,
+ in.drop(in.source.length()))
+ }
+}
+
+class SqlLexical(val keywords: Seq[String]) extends StdLexical {
+ case class FloatLit(chars: String) extends Token {
+ override def toString = chars
+ }
+
+ reserved ++= keywords.flatMap(w => allCaseVersions(w))
+
+ delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":", "[", "]", "."
+ )
+
+ override lazy val token: Parser[Token] =
+ ( identChar ~ (identChar | digit).* ^^
+ { case first ~ rest => processIdent((first :: rest).mkString) }
+ | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+ case i ~ None => NumericLit(i.mkString)
+ case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString)
+ }
+ | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+ { case chars => StringLit(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar = letter | elem('_')
+
+ override def whitespace: Parser[Any] =
+ ( whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+ | '#' ~ chrExcept(EofCh, '\n').*
+ | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+ | '/' ~ '*' ~ failure("unclosed comment")
+ ).*
+
+ /** Generate all variations of upper and lower case of a given string */
+ def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+ if (s == "") {
+ Stream(prefix)
+ } else {
+ allCaseVersions(s.tail, prefix + s.head.toLower) ++
+ allCaseVersions(s.tail, prefix + s.head.toUpper)
+ }
+ }
+}
+
+/**
+ * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
+ * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
+ *
+ * @param fallback A function that parses an input string to a logical plan
+ */
+private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+
+ // A parser for the key-value part of the "SET [key = [value ]]" syntax
+ private object SetCommandParser extends RegexParsers {
+ private val key: Parser[String] = "(?m)[^=]+".r
+
+ private val value: Parser[String] = "(?m).*$".r
+
+ private val pair: Parser[LogicalPlan] =
+ (key ~ ("=".r ~> value).?).? ^^ {
+ case None => SetCommand(None)
+ case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
+ }
+
+ def apply(input: String): LogicalPlan = parseAll(pair, input) match {
+ case Success(plan, _) => plan
+ case x => sys.error(x.toString)
+ }
+ }
+
+ protected val AS = Keyword("AS")
+ protected val CACHE = Keyword("CACHE")
+ protected val LAZY = Keyword("LAZY")
+ protected val SET = Keyword("SET")
+ protected val TABLE = Keyword("TABLE")
+ protected val SOURCE = Keyword("SOURCE")
+ protected val UNCACHE = Keyword("UNCACHE")
+
+ protected implicit def asParser(k: Keyword): Parser[String] =
+ lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+
+ private val reservedWords: Seq[String] =
+ this
+ .getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+ override val lexical = new SqlLexical(reservedWords)
+
+ override protected lazy val start: Parser[LogicalPlan] =
+ cache | uncache | set | shell | source | others
+
+ private lazy val cache: Parser[LogicalPlan] =
+ CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
+ case isLazy ~ tableName ~ plan =>
+ CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+ }
+
+ private lazy val uncache: Parser[LogicalPlan] =
+ UNCACHE ~ TABLE ~> ident ^^ {
+ case tableName => UncacheTableCommand(tableName)
+ }
+
+ private lazy val set: Parser[LogicalPlan] =
+ SET ~> restInput ^^ {
+ case input => SetCommandParser(input)
+ }
+
+ private lazy val shell: Parser[LogicalPlan] =
+ "!" ~> restInput ^^ {
+ case input => ShellCommand(input.trim)
+ }
+
+ private lazy val source: Parser[LogicalPlan] =
+ SOURCE ~> restInput ^^ {
+ case input => SourceCommand(input.trim)
+ }
+
+ private lazy val others: Parser[LogicalPlan] =
+ wholeInput ^^ {
+ case input => fallback(input)
+ }
+}
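As a usage sketch for the new top-level parser: the demo object below is hypothetical and sits in the catalyst package because `SparkSQLParser` is `private[sql]`; everything it calls comes from the patch above.

```scala
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object SparkSQLParserSketch {
  def main(args: Array[String]): Unit = {
    val fallback = new SqlParser                  // parses ordinary SELECT statements
    val parser = new SparkSQLParser(fallback.apply)

    // Handled directly by SparkSQLParser:
    val cached: LogicalPlan = parser("CACHE LAZY TABLE src AS SELECT * FROM src")
    val set: LogicalPlan = parser("SET spark.sql.shuffle.partitions=10")

    // Everything else is delegated, unchanged, to the fallback parser:
    val select: LogicalPlan = parser("SELECT key, value FROM src WHERE key = 100")

    Seq(cached, set, select).foreach(println)
  }
}
```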
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 862f78702c4e6..b4d606d37e732 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -18,10 +18,6 @@
package org.apache.spark.sql.catalyst
import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
-import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
@@ -39,50 +35,30 @@ import org.apache.spark.sql.catalyst.types._
* This is currently included mostly for illustrative purposes. Users wanting more complete support
* for a SQL-like language should check out the HiveQL support in the sql/hive sub-project.
*/
-class SqlParser extends StandardTokenParsers with PackratParsers {
-
- def apply(input: String): LogicalPlan = {
- // Special-case out set commands since the value fields can be
- // complex to handle without RegexParsers. Also this approach
- // is clearer for the several possible cases of set commands.
- if (input.trim.toLowerCase.startsWith("set")) {
- input.trim.drop(3).split("=", 2).map(_.trim) match {
- case Array("") => // "set"
- SetCommand(None, None)
- case Array(key) => // "set key"
- SetCommand(Some(key), None)
- case Array(key, value) => // "set key=value"
- SetCommand(Some(key), Some(value))
- }
- } else {
- phrase(query)(new lexical.Scanner(input)) match {
- case Success(r, x) => r
- case x => sys.error(x.toString)
- }
- }
- }
-
- protected case class Keyword(str: String)
-
+class SqlParser extends AbstractSparkSQLParser {
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+ protected val ABS = Keyword("ABS")
protected val ALL = Keyword("ALL")
protected val AND = Keyword("AND")
+ protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AS = Keyword("AS")
protected val ASC = Keyword("ASC")
- protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AVG = Keyword("AVG")
protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CACHE = Keyword("CACHE")
+ protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
+ protected val ELSE = Keyword("ELSE")
+ protected val END = Keyword("END")
+ protected val EXCEPT = Keyword("EXCEPT")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
- protected val LAST = Keyword("LAST")
protected val FROM = Keyword("FROM")
protected val FULL = Keyword("FULL")
protected val GROUP = Keyword("GROUP")
@@ -91,46 +67,47 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
protected val INSERT = Keyword("INSERT")
+ protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
protected val IS = Keyword("IS")
protected val JOIN = Keyword("JOIN")
+ protected val LAST = Keyword("LAST")
protected val LEFT = Keyword("LEFT")
+ protected val LIKE = Keyword("LIKE")
protected val LIMIT = Keyword("LIMIT")
+ protected val LOWER = Keyword("LOWER")
protected val MAX = Keyword("MAX")
protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
protected val OR = Keyword("OR")
- protected val OVERWRITE = Keyword("OVERWRITE")
- protected val LIKE = Keyword("LIKE")
- protected val RLIKE = Keyword("RLIKE")
- protected val UPPER = Keyword("UPPER")
- protected val LOWER = Keyword("LOWER")
- protected val REGEXP = Keyword("REGEXP")
protected val ORDER = Keyword("ORDER")
protected val OUTER = Keyword("OUTER")
+ protected val OVERWRITE = Keyword("OVERWRITE")
+ protected val REGEXP = Keyword("REGEXP")
protected val RIGHT = Keyword("RIGHT")
+ protected val RLIKE = Keyword("RLIKE")
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
+ protected val SQRT = Keyword("SQRT")
protected val STRING = Keyword("STRING")
+ protected val SUBSTR = Keyword("SUBSTR")
+ protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
+ protected val THEN = Keyword("THEN")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
- protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
+ protected val UPPER = Keyword("UPPER")
+ protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")
- protected val INTERSECT = Keyword("INTERSECT")
- protected val EXCEPT = Keyword("EXCEPT")
- protected val SUBSTR = Keyword("SUBSTR")
- protected val SUBSTRING = Keyword("SUBSTRING")
- protected val SQRT = Keyword("SQRT")
- protected val ABS = Keyword("ABS")
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
- this.getClass
+ this
+ .getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)
@@ -144,88 +121,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
}
}
- protected lazy val query: Parser[LogicalPlan] = (
- select * (
- UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
- INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } |
- EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} |
- UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
+ protected lazy val start: Parser[LogicalPlan] =
+ ( select *
+ ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
+ | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
+ | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
+ | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
)
- | insert | cache | unCache
- )
+ | insert
+ )
protected lazy val select: Parser[LogicalPlan] =
- SELECT ~> opt(DISTINCT) ~ projections ~
- opt(from) ~ opt(filter) ~
- opt(grouping) ~
- opt(having) ~
- opt(orderBy) ~
- opt(limit) <~ opt(";") ^^ {
- case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
- val base = r.getOrElse(NoRelation)
- val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
- val withProjection =
- g.map {g =>
- Aggregate(assignAliases(g), assignAliases(p), withFilter)
- }.getOrElse(Project(assignAliases(p), withFilter))
- val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
- val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
- val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
- val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
- withLimit
- }
+ SELECT ~> DISTINCT.? ~
+ repsep(projection, ",") ~
+ (FROM ~> relations).? ~
+ (WHERE ~> expression).? ~
+ (GROUP ~ BY ~> rep1sep(expression, ",")).? ~
+ (HAVING ~> expression).? ~
+ (ORDER ~ BY ~> ordering).? ~
+ (LIMIT ~> expression).? ^^ {
+ case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
+ val base = r.getOrElse(NoRelation)
+ val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
+ val withProjection = g
+ .map(Aggregate(_, assignAliases(p), withFilter))
+ .getOrElse(Project(assignAliases(p), withFilter))
+ val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
+ val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
+ val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving)
+ val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder)
+ withLimit
+ }
protected lazy val insert: Parser[LogicalPlan] =
- INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ {
- case o ~ r ~ s =>
- val overwrite: Boolean = o.getOrElse("") == "OVERWRITE"
- InsertIntoTable(r, Map[String, Option[String]](), s, overwrite)
- }
-
- protected lazy val cache: Parser[LogicalPlan] =
- CACHE ~ TABLE ~> ident ~ opt(AS ~> select) <~ opt(";") ^^ {
- case tableName ~ None =>
- CacheCommand(tableName, true)
- case tableName ~ Some(plan) =>
- CacheTableAsSelectCommand(tableName, plan)
+ INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ {
+ case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined)
}
-
- protected lazy val unCache: Parser[LogicalPlan] =
- UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ {
- case tableName => CacheCommand(tableName, false)
- }
-
- protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
protected lazy val projection: Parser[Expression] =
- expression ~ (opt(AS) ~> opt(ident)) ^^ {
- case e ~ None => e
- case e ~ Some(a) => Alias(e, a)()
+ expression ~ (AS.? ~> ident.?) ^^ {
+ case e ~ a => a.fold(e)(Alias(e, _)())
}
- protected lazy val from: Parser[LogicalPlan] = FROM ~> relations
-
- protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation
-
// Based very loosely on the MySQL Grammar.
// http://dev.mysql.com/doc/refman/5.0/en/join.html
protected lazy val relations: Parser[LogicalPlan] =
- relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } |
- relation
+ ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) }
+ | relation
+ )
protected lazy val relation: Parser[LogicalPlan] =
- joinedRelation |
- relationFactor
+ joinedRelation | relationFactor
protected lazy val relationFactor: Parser[LogicalPlan] =
- ident ~ (opt(AS) ~> opt(ident)) ^^ {
- case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
- } |
- "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) }
+ ( ident ~ (opt(AS) ~> opt(ident)) ^^ {
+ case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
+ }
+ | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
+ )
protected lazy val joinedRelation: Parser[LogicalPlan] =
- relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ {
- case r1 ~ jt ~ _ ~ r2 ~ cond =>
+ relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ {
+ case r1 ~ jt ~ r2 ~ cond =>
Join(r1, r2, joinType = jt.getOrElse(Inner), cond)
}
@@ -233,151 +190,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
ON ~> expression
protected lazy val joinType: Parser[JoinType] =
- INNER ^^^ Inner |
- LEFT ~ SEMI ^^^ LeftSemi |
- LEFT ~ opt(OUTER) ^^^ LeftOuter |
- RIGHT ~ opt(OUTER) ^^^ RightOuter |
- FULL ~ opt(OUTER) ^^^ FullOuter
-
- protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e }
-
- protected lazy val orderBy: Parser[Seq[SortOrder]] =
- ORDER ~> BY ~> ordering
+ ( INNER ^^^ Inner
+ | LEFT ~ SEMI ^^^ LeftSemi
+ | LEFT ~ OUTER.? ^^^ LeftOuter
+ | RIGHT ~ OUTER.? ^^^ RightOuter
+ | FULL ~ OUTER.? ^^^ FullOuter
+ )
protected lazy val ordering: Parser[Seq[SortOrder]] =
- rep1sep(singleOrder, ",") |
- rep1sep(expression, ",") ~ opt(direction) ^^ {
- case exps ~ None => exps.map(SortOrder(_, Ascending))
- case exps ~ Some(d) => exps.map(SortOrder(_, d))
- }
+ ( rep1sep(singleOrder, ",")
+ | rep1sep(expression, ",") ~ direction.? ^^ {
+ case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending)))
+ }
+ )
protected lazy val singleOrder: Parser[SortOrder] =
- expression ~ direction ^^ { case e ~ o => SortOrder(e,o) }
+ expression ~ direction ^^ { case e ~ o => SortOrder(e, o) }
protected lazy val direction: Parser[SortDirection] =
- ASC ^^^ Ascending |
- DESC ^^^ Descending
-
- protected lazy val grouping: Parser[Seq[Expression]] =
- GROUP ~> BY ~> rep1sep(expression, ",")
-
- protected lazy val having: Parser[Expression] =
- HAVING ~> expression
-
- protected lazy val limit: Parser[Expression] =
- LIMIT ~> expression
+ ( ASC ^^^ Ascending
+ | DESC ^^^ Descending
+ )
- protected lazy val expression: Parser[Expression] = orExpression
+ protected lazy val expression: Parser[Expression] =
+ orExpression
protected lazy val orExpression: Parser[Expression] =
- andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) })
+ andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) })
protected lazy val andExpression: Parser[Expression] =
- comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) })
+ comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) })
protected lazy val comparisonExpression: Parser[Expression] =
- termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } |
- termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } |
- termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } |
- termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } |
- termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } |
- termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
- termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
- termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ {
- case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
- } |
- termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
- termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
- termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } |
- termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
- case e1 ~ _ ~ _ ~ e2 => In(e1, e2)
- } |
- termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
- case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2))
- } |
- termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } |
- termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } |
- NOT ~> termExpression ^^ {e => Not(e)} |
- termExpression
+ ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) }
+ | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) }
+ | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) }
+ | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) }
+ | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) }
+ | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
+ | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
+ | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ {
+ case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
+ }
+ | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
+ | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
+ | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) }
+ | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
+ case e1 ~ e2 => In(e1, e2)
+ }
+ | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
+ case e1 ~ e2 => Not(In(e1, e2))
+ }
+ | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) }
+ | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) }
+ | NOT ~> termExpression ^^ {e => Not(e)}
+ | termExpression
+ )
protected lazy val termExpression: Parser[Expression] =
- productExpression * (
- "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } |
- "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } )
+ productExpression *
+ ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) }
+ | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) }
+ )
protected lazy val productExpression: Parser[Expression] =
- baseExpression * (
- "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } |
- "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } |
- "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) }
- )
+ baseExpression *
+ ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
+ | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
+ | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
+ )
protected lazy val function: Parser[Expression] =
- SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } |
- SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } |
- COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
- COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
- COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
- APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
- case exp => ApproxCountDistinct(exp)
- } |
- APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
- case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
- } |
- FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
- LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } |
- AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
- MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
- MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
- UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } |
- LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } |
- IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
- case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
- } |
- (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
- case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
- } |
- (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
- case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
- } |
- SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
- ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } |
- ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
- case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
- }
+ ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) }
+ | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
+ | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
+ | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
+ | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) }
+ | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
+ { case exp => ApproxCountDistinct(exp) }
+ | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
+ { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
+ | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
+ | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
+ | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
+ | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) }
+ | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) }
+ | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) }
+ | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
+ | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
+ { case c ~ t ~ f => If(c, t, f) }
+ | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+ (ELSE ~> expression).? <~ END ^^ {
+ case casePart ~ altPart ~ elsePart =>
+ val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
+ Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
+ }
+ CaseWhen(altExprs ++ elsePart.toList)
+ }
+ | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
+ { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
+ | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
+ { case s ~ p ~ l => Substring(s, p, l) }
+ | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
+ | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
+ | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
+ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+ )
protected lazy val cast: Parser[Expression] =
- CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) }
+ CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
protected lazy val literal: Parser[Literal] =
- numericLit ^^ {
- case i if i.toLong > Int.MaxValue => Literal(i.toLong)
- case i => Literal(i.toInt)
- } |
- NULL ^^^ Literal(null, NullType) |
- floatLit ^^ {case f => Literal(f.toDouble) } |
- stringLit ^^ {case s => Literal(s, StringType) }
+ ( numericLit ^^ {
+ case i if i.toLong > Int.MaxValue => Literal(i.toLong)
+ case i => Literal(i.toInt)
+ }
+ | NULL ^^^ Literal(null, NullType)
+ | floatLit ^^ {case f => Literal(f.toDouble) }
+ | stringLit ^^ {case s => Literal(s, StringType) }
+ )
protected lazy val floatLit: Parser[String] =
elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
protected lazy val baseExpression: PackratParser[Expression] =
- expression ~ "[" ~ expression <~ "]" ^^ {
- case base ~ _ ~ ordinal => GetItem(base, ordinal)
- } |
- (expression <~ ".") ~ ident ^^ {
- case base ~ fieldName => GetField(base, fieldName)
- } |
- TRUE ^^^ Literal(true, BooleanType) |
- FALSE ^^^ Literal(false, BooleanType) |
- cast |
- "(" ~> expression <~ ")" |
- function |
- "-" ~> literal ^^ UnaryMinus |
- dotExpressionHeader |
- ident ^^ UnresolvedAttribute |
- "*" ^^^ Star(None) |
- literal
+ ( expression ~ ("[" ~> expression <~ "]") ^^
+ { case base ~ ordinal => GetItem(base, ordinal) }
+ | (expression <~ ".") ~ ident ^^
+ { case base ~ fieldName => GetField(base, fieldName) }
+ | TRUE ^^^ Literal(true, BooleanType)
+ | FALSE ^^^ Literal(false, BooleanType)
+ | cast
+ | "(" ~> expression <~ ")"
+ | function
+ | "-" ~> literal ^^ UnaryMinus
+ | dotExpressionHeader
+ | ident ^^ UnresolvedAttribute
+ | "*" ^^^ Star(None)
+ | literal
+ )
protected lazy val dotExpressionHeader: Parser[Expression] =
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
@@ -387,55 +338,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected lazy val dataType: Parser[DataType] =
STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
}
-
-class SqlLexical(val keywords: Seq[String]) extends StdLexical {
- case class FloatLit(chars: String) extends Token {
- override def toString = chars
- }
-
- reserved ++= keywords.flatMap(w => allCaseVersions(w))
-
- delimiters += (
- "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":", "[", "]", "."
- )
-
- override lazy val token: Parser[Token] = (
- identChar ~ rep( identChar | digit ) ^^
- { case first ~ rest => processIdent(first :: rest mkString "") }
- | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
- case i ~ None => NumericLit(i mkString "")
- case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
- }
- | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
- { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
- | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
- { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
- | EofCh ^^^ EOF
- | '\'' ~> failure("unclosed string literal")
- | '\"' ~> failure("unclosed string literal")
- | delim
- | failure("illegal character")
- )
-
- override def identChar = letter | elem('_')
-
- override def whitespace: Parser[Any] = rep(
- whitespaceChar
- | '/' ~ '*' ~ comment
- | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
- | '#' ~ rep( chrExcept(EofCh, '\n') )
- | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
- | '/' ~ '*' ~ failure("unclosed comment")
- )
-
- /** Generate all variations of upper and lower case of a given string */
- def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
- if (s == "") {
- Stream(prefix)
- } else {
- allCaseVersions(s.tail, prefix + s.head.toLower) ++
- allCaseVersions(s.tail, prefix + s.head.toUpper)
- }
- }
-}
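A small sketch of the rewritten grammar in action, exercising the newly added `CASE WHEN` support; the table name `src` is arbitrary, since parsing does not resolve it.

```scala
import org.apache.spark.sql.catalyst.SqlParser

object SqlParserSketch {
  def main(args: Array[String]): Unit = {
    val parser = new SqlParser
    // CASE WHEN ... THEN ... ELSE ... END is now parsed into a CaseWhen expression.
    val plan = parser(
      "SELECT key, CASE WHEN key = 100 THEN 'hundred' ELSE 'other' END FROM src")
    println(plan)
  }
}
```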
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 71810b798bd04..82553063145b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
- CheckResolution),
+ CheckResolution,
+ CheckAggregation),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
@@ -88,11 +89,40 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
}
+ /**
+ * Checks for non-aggregated attributes that do not appear in the GROUP BY clause
+ */
+ object CheckAggregation extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ plan.transform {
+ case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
+ def isValidAggregateExpression(expr: Expression): Boolean = expr match {
+ case _: AggregateExpression => true
+ case e: Attribute => groupingExprs.contains(e)
+ case e if groupingExprs.contains(e) => true
+ case e if e.references.isEmpty => true
+ case e => e.children.forall(isValidAggregateExpression)
+ }
+
+ aggregateExprs.foreach { e =>
+ if (!isValidAggregateExpression(e)) {
+ throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
+ }
+ }
+
+ aggregatePlan
+ }
+ }
+ }
+
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) =>
+ i.copy(
+ table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias)))
case UnresolvedRelation(databaseName, name, alias) =>
catalog.lookupRelation(databaseName, name, alias)
}
@@ -201,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
+ case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
-
+
Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}
-
}
-
+
protected def containsAggregate(condition: Expression): Boolean =
condition
.collect { case ae: AggregateExpression => ae }
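In spark-shell terms, the new `CheckAggregation` batch turns the classic missing-GROUP-BY mistake into an analysis-time error. The sketch below is hedged: it assumes an existing `sqlContext` and a registered table `src(key INT, value STRING)`, neither of which is part of this patch.

```scala
import org.apache.spark.sql.catalyst.errors.TreeNodeException

// key is grouped, so this analyzes fine:
sqlContext.sql("SELECT key, COUNT(*) FROM src GROUP BY key").queryExecution.analyzed

try {
  // value is neither aggregated nor in the GROUP BY clause, so analysis now fails fast:
  sqlContext.sql("SELECT value, COUNT(*) FROM src GROUP BY key").queryExecution.analyzed
} catch {
  case e: TreeNodeException[_] => println(e.getMessage)  // "Expression not in GROUP BY: value#..."
}
```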
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 616f1e2ecb60f..2059a91ba0612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -87,7 +87,7 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
tableName: String,
alias: Option[String] = None): LogicalPlan = {
val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName)
- val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+ val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName"))
val tableWithQualifiers = Subquery(tblName, table)
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 79e5283e86a37..7c480de107e7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -220,20 +220,39 @@ trait HiveTypeCoercion {
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
+ // We should cast all timestamp/date/string comparisons into string comparisons
+ case p: BinaryPredicate if p.left.dataType == StringType
+ && p.right.dataType == DateType =>
+ p.makeCopy(Array(p.left, Cast(p.right, StringType)))
+ case p: BinaryPredicate if p.left.dataType == DateType
+ && p.right.dataType == StringType =>
+ p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
- p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
+ p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
- p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
+ p.makeCopy(Array(Cast(p.left, StringType), p.right))
+ case p: BinaryPredicate if p.left.dataType == TimestampType
+ && p.right.dataType == DateType =>
+ p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
+ case p: BinaryPredicate if p.left.dataType == DateType
+ && p.right.dataType == TimestampType =>
+ p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
- case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
- i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
+ case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
+ i.makeCopy(Array(Cast(a, StringType), b))
+ case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
+ i.makeCopy(Array(Cast(a, StringType), b))
+ case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
+ i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
+ case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
+ i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
@@ -283,6 +302,8 @@ trait HiveTypeCoercion {
// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
+ // Casting DateType to boolean should yield null.
+ case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
@@ -348,8 +369,11 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e
// Decimal and Double remain the same
- case d: Divide if d.dataType == DoubleType => d
- case d: Divide if d.dataType == DecimalType => d
+ case d: Divide if d.resolved && d.dataType == DoubleType => d
+ case d: Divide if d.resolved && d.dataType == DecimalType => d
+
+ case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
+ case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
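A sketch of the intended rewrite for date/string predicates; the attribute names are invented, and only classes already touched above appear.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

object DateCoercionSketch {
  def main(args: Array[String]): Unit = {
    val d = AttributeReference("d", DateType, nullable = true)()
    val s = AttributeReference("s", StringType, nullable = true)()

    // Before coercion the predicate compares a DateType to a StringType:
    val original = EqualTo(d, s)
    // After the rules above run, the date side is compared as a string (as Hive does),
    // i.e. the predicate is rewritten into something equivalent to:
    val coerced = EqualTo(Cast(d, StringType), s)
    println(original)
    println(coerced)
  }
}
```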
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 67570a6f73c36..77d84e1687e1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -88,7 +88,7 @@ case class Star(
mapFunction: Attribute => Expression = identity[Attribute])
extends Attribute with trees.LeafNode[Expression] {
- override def name = throw new UnresolvedException(this, "exprId")
+ override def name = throw new UnresolvedException(this, "name")
override def exprId = throw new UnresolvedException(this, "exprId")
override def dataType = throw new UnresolvedException(this, "dataType")
override def nullable = throw new UnresolvedException(this, "nullable")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index deb622c39faf5..75b6e37c2a1f9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
@@ -119,6 +119,7 @@ package object dsl {
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
+ implicit def dateToLiteral(d: Date) = Literal(d)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -174,6 +175,9 @@ package object dsl {
/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = true)()
+ /** Creates a new AttributeReference of type date */
+ def date = AttributeReference(s, DateType, nullable = true)()
+
/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = true)()
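A sketch of the two DSL additions above. The `import ...dsl.expressions._` path and the `===` operator follow the existing catalyst test conventions (an assumption here); only `date` and the implicit `Date` literal conversion are new.

```scala
import java.sql.Date

import org.apache.spark.sql.catalyst.dsl.expressions._

object DateDslSketch {
  def main(args: Array[String]): Unit = {
    val birthday = 'birthday.date                             // AttributeReference of DateType
    val predicate = birthday === Date.valueOf("2014-10-01")   // Date lifted to a Literal implicitly
    println(predicate)
  }
}
```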
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index c3a08bbdb6bc7..2b4969b7cfec0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -17,19 +17,26 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.Star
+
protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a.exprId.hashCode()
- override def equals(other: Any) = other match {
- case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
- case otherAttribute => false
+ override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
+ case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
+ case (a1, a2) => a1 == a2
}
}
object AttributeSet {
- /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
- def apply(baseSet: Seq[Attribute]) = {
- new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
- }
+ def apply(a: Attribute) =
+ new AttributeSet(Set(new AttributeEquals(a)))
+
+ /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
+ def apply(baseSet: Seq[Expression]) =
+ new AttributeSet(
+ baseSet
+ .flatMap(_.references)
+ .map(new AttributeEquals(_)).toSet)
}
/**
@@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+
+ override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
}
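
With the reworked AttributeEquals, membership in an AttributeSet is decided by exprId for AttributeReferences and by ordinary equality for other attributes such as Star. A rough sketch of the intended semantics, assuming two references that share an exprId but differ in name:

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
  import org.apache.spark.sql.catalyst.types.IntegerType

  val a       = AttributeReference("a", IntegerType, nullable = true)()
  val renamed = AttributeReference("renamed", IntegerType, nullable = true)(exprId = a.exprId)

  // Same exprId => treated as the same attribute, regardless of the cosmetic name change.
  val set = AttributeSet(a :: Nil)
  assert(set.contains(renamed))
  assert(AttributeSet(Seq(a, renamed)).size == 1)
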
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f626d09f037bc..8e5ee12e314bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -17,18 +17,21 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
override def foldable = child.foldable
override def nullable = (child.dataType, dataType) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
+ case (StringType, DateType) => true
case _ => child.nullable
}
@@ -42,6 +45,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// UDFToString
private[this] def castToString: Any => Any = child.dataType match {
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
+ case DateType => buildCast[Date](_, dateToString)
case TimestampType => buildCast[Timestamp](_, timestampToString)
case _ => buildCast[Any](_, _.toString)
}
@@ -56,7 +60,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case StringType =>
buildCast[String](_, _.length() != 0)
case TimestampType =>
- buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
+ buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0)
+ case DateType =>
+ // Hive returns null when casting from date to boolean
+ buildCast[Date](_, d => null)
case LongType =>
buildCast[Long](_, _ != 0)
case IntegerType =>
@@ -95,6 +102,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
buildCast[Short](_, s => new Timestamp(s))
case ByteType =>
buildCast[Byte](_, b => new Timestamp(b))
+ case DateType =>
+ buildCast[Date](_, d => new Timestamp(d.getTime))
// TimestampWritable.decimalToTimestamp
case DecimalType =>
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
@@ -130,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
// Converts Timestamp to string according to Hive TimestampWritable convention
private[this] def timestampToString(ts: Timestamp): String = {
val timestampString = ts.toString
- val formatted = Cast.threadLocalDateFormat.get.format(ts)
+ val formatted = Cast.threadLocalTimestampFormat.get.format(ts)
if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
formatted + timestampString.substring(19)
@@ -139,6 +148,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
}
+ // Converts Timestamp to a date string according to Hive DateWritable convention
+ private[this] def timestampToDateString(ts: Timestamp): String = {
+ Cast.threadLocalDateFormat.get.format(ts)
+ }
+
+ // DateConverter
+ private[this] def castToDate: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s =>
+ try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null }
+ )
+ case TimestampType =>
+ // Throw away any precision finer than seconds, matching Hive's behavior.
+ // Timestamp.nanos ranges over 0 to 999,999,999, i.e. always less than a second.
+ buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000))
+ // Hive rejects these casts with a SemanticException at analysis time.
+ // We cannot compare results against Hive when it raises an exception, so returning null
+ // is more reasonable here, since the query itself is grammatically valid.
+ case _ => _ => null
+ }
+
+ // Date cannot be cast to long, according to Hive
+ private[this] def dateToLong(d: Date) = null
+
+ // Date cannot be cast to double, according to Hive
+ private[this] def dateToDouble(d: Date) = null
+
+ // Converts Date to string according to Hive DateWritable convention
+ private[this] def dateToString(d: Date): String = {
+ Cast.threadLocalDateFormat.get.format(d)
+ }
+
+ // LongConverter
private[this] def castToLong: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toLong catch {
@@ -146,6 +188,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
+ case DateType =>
+ buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
case DecimalType =>
@@ -154,6 +198,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}
+ // IntConverter
private[this] def castToInt: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toInt catch {
@@ -161,6 +206,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
+ case DateType =>
+ buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
case DecimalType =>
@@ -169,6 +216,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}
+ // ShortConverter
private[this] def castToShort: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toShort catch {
@@ -176,6 +224,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
+ case DateType =>
+ buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
case DecimalType =>
@@ -184,6 +234,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}
+ // ByteConverter
private[this] def castToByte: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toByte catch {
@@ -191,6 +242,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
+ case DateType =>
+ buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
case DecimalType =>
@@ -199,6 +252,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}
+ // DecimalConverter
private[this] def castToDecimal: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
@@ -206,6 +260,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
+ case DateType =>
+ buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
// Note that we lose precision here.
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
@@ -213,6 +269,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
}
+ // DoubleConverter
private[this] def castToDouble: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toDouble catch {
@@ -220,6 +277,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
+ case DateType =>
+ buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
case DecimalType =>
@@ -228,6 +287,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}
+ // FloatConverter
private[this] def castToFloat: Any => Any = child.dataType match {
case StringType =>
buildCast[String](_, s => try s.toFloat catch {
@@ -235,6 +295,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
+ case DateType =>
+ buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
case DecimalType =>
@@ -245,17 +307,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
private[this] lazy val cast: Any => Any = dataType match {
case dt if dt == child.dataType => identity[Any]
- case StringType => castToString
- case BinaryType => castToBinary
- case DecimalType => castToDecimal
+ case StringType => castToString
+ case BinaryType => castToBinary
+ case DecimalType => castToDecimal
+ case DateType => castToDate
case TimestampType => castToTimestamp
- case BooleanType => castToBoolean
- case ByteType => castToByte
- case ShortType => castToShort
- case IntegerType => castToInt
- case FloatType => castToFloat
- case LongType => castToLong
- case DoubleType => castToDouble
+ case BooleanType => castToBoolean
+ case ByteType => castToByte
+ case ShortType => castToShort
+ case IntegerType => castToInt
+ case FloatType => castToFloat
+ case LongType => castToLong
+ case DoubleType => castToDouble
}
override def eval(input: Row): Any = {
@@ -267,6 +330,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
object Cast {
// `SimpleDateFormat` is not thread-safe.
private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] {
+ override def initialValue() = {
+ new SimpleDateFormat("yyyy-MM-dd")
+ }
+ }
+
+ // `SimpleDateFormat` is not thread-safe.
+ private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
override def initialValue() = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
}
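
Taken together, the date branches added above mirror Hive's cast matrix: date converts to and from string and timestamp, while casts to numeric and boolean types yield null instead of failing. A small sketch of the expected results, using the two thread-local formatters defined in the companion object:

  import java.sql.{Date, Timestamp}
  import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
  import org.apache.spark.sql.catalyst.types._

  // date -> string uses the yyyy-MM-dd formatter
  Cast(Literal(Date.valueOf("1970-01-01")), StringType)    // evaluates to "1970-01-01"
  // date -> numeric returns null, matching Hive
  Cast(Literal(Date.valueOf("1970-01-01")), IntegerType)   // evaluates to null
  // timestamp -> date drops sub-second precision; rendered as a string it shows only the day
  Cast(Cast(Literal(Timestamp.valueOf("1970-01-01 12:34:56.7")), DateType), StringType)  // "1970-01-01"
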
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index ef1d12531f109..e7e81a21fdf03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -39,6 +39,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
}
new GenericRow(outputArray)
}
+
+ override def toString = s"Row => [${exprArray.mkString(",")}]"
}
/**
@@ -137,6 +139,9 @@ class JoinedRow extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -226,6 +231,9 @@ class JoinedRow2 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -309,6 +317,9 @@ class JoinedRow3 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -392,6 +403,9 @@ class JoinedRow4 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -475,6 +489,9 @@ class JoinedRow5 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index d68a4fabeac77..d00ec39774c35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable {
def getShort(i: Int): Short
def getByte(i: Int): Byte
def getString(i: Int): String
+ def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
override def toString() =
s"[${this.mkString(",")}]"
@@ -118,6 +119,7 @@ object EmptyRow extends Row {
def getShort(i: Int): Short = throw new UnsupportedOperationException
def getByte(i: Int): Byte = throw new UnsupportedOperationException
def getString(i: Int): String = throw new UnsupportedOperationException
+ override def getAs[T](i: Int): T = throw new UnsupportedOperationException
def copy() = this
}
@@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)
- override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
- override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
- override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
- override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value }
- override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
- override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
+ override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
+ override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
+ override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
+ override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
+ override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
+ override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
override def setNullAt(i: Int): Unit = { values(i) = null }
- override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value }
+ override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
- override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value }
+ override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
override def copy() = new GenericRow(values.clone())
}
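
getAs[T] gives Row a generic accessor whose default is apply(i).asInstanceOf[T]; the JoinedRow variants and SpecificMutableRow override it so the index is routed to the correct underlying row or mutable slot. A minimal usage sketch against a GenericRow:

  import java.sql.Date
  import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}

  val row: Row = new GenericRow(Array[Any](1, "spark", Date.valueOf("1970-01-01")))

  row.getInt(0)          // existing typed accessor
  row.getAs[String](1)   // generic accessor: apply(1).asInstanceOf[String]
  row.getAs[Date](2)     // handy for types without a dedicated getter, such as Date
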
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
similarity index 97%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 9cbab3d5d0d0d..570379c533e1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def iterator: Iterator[Any] = values.map(_.boxed).iterator
- def setString(ordinal: Int, value: String) = update(ordinal, value)
+ override def setString(ordinal: Int, value: String) = update(ordinal, value)
- def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def getByte(i: Int): Byte = {
values(i).asInstanceOf[MutableByte].value
}
+
+ override def getAs[T](i: Int): T = {
+ values(i).boxed.asInstanceOf[T]
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 1eb55715794a7..1a4ac06c7a79d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
-case object DynamicType extends DataType {
- def simpleString: String = "dynamic"
-}
+case object DynamicType extends DataType
/**
* Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 78a0c55e4bbe5..ba240233cae61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types._
@@ -33,6 +33,7 @@ object Literal {
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(d, DecimalType)
case t: Timestamp => Literal(t, TimestampType)
+ case d: Date => Literal(d, DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 59fb0311a9c44..d023db44d8543 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -57,12 +57,14 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>
+ override def references = AttributeSet(this)
+
def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def withName(newName: String): Attribute
def toAttribute = this
- def newInstance: Attribute
+ def newInstance(): Attribute
}
@@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
(val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
extends Attribute with trees.LeafNode[Expression] {
- override def references = AttributeSet(this :: Nil)
-
override def equals(other: Any) = other match {
case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
case _ => false
@@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
h
}
- override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
+ override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 329af332d0fa1..1e22b2d03c672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.BooleanType
-
object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
apply(BindReferences.bindReference(expression, inputSchema))
@@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}
+/**
+ * Optimized version of the In clause, for the case where all values in the IN list are
+ * static literals.
+ */
+case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
+ extends Predicate {
+
+ def children = child
+
+ def nullable = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
+
+ override def eval(input: Row): Any = {
+ hset.contains(value.eval(input))
+ }
+}
+
case class And(left: Expression, right: Expression) extends BinaryPredicate {
def symbol = "&&"
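
InSet trades In's per-row scan over the expression list for a single hash lookup in a prebuilt set. A sketch of constructing and evaluating one directly; a literal stands in for the probed value so no row binding is needed:

  import scala.collection.immutable.HashSet
  import org.apache.spark.sql.catalyst.expressions.{InSet, Literal}

  val list  = Seq(Literal(1), Literal(2), Literal(3))
  val hset  = HashSet[Any](1, 2, 3)
  val probe = Literal(2)

  // `child` keeps value +: list so the expression still reports its references
  // correctly after the optimizer swaps In for InSet.
  val expr = InSet(probe, hset, probe +: list)
  expr.eval(null)   // true: one hset.contains call instead of evaluating every list element
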
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a4133feae8166..3693b41404fd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
+import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyFilters,
SimplifyCasts,
- SimplifyCaseConversionExpressions) ::
+ SimplifyCaseConversionExpressions,
+ OptimizeIn) ::
Batch("Filter Pushdown", FixedPoint(100),
UnionPushdown,
CombineFilters,
@@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
}
+/**
+ * Replaces [[In (value, seq[Literal])]] with the optimized version [[InSet (value, HashSet[Literal])]],
+ * which is much faster.
+ */
+object OptimizeIn extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+ val hSet = list.map(e => e.eval(null))
+ InSet(v, HashSet() ++ hSet, v +: list)
+ }
+ }
+}
+
/**
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
@@ -299,6 +315,18 @@ object BooleanSimplification extends Rule[LogicalPlan] {
case (_, _) => or
}
+ case not @ Not(exp) =>
+ exp match {
+ case Literal(true, BooleanType) => Literal(false)
+ case Literal(false, BooleanType) => Literal(true)
+ case GreaterThan(l, r) => LessThanOrEqual(l, r)
+ case GreaterThanOrEqual(l, r) => LessThan(l, r)
+ case LessThan(l, r) => GreaterThanOrEqual(l, r)
+ case LessThanOrEqual(l, r) => GreaterThan(l, r)
+ case Not(e) => e
+ case _ => not
+ }
+
// Turn "if (true) a else b" into "a", and if (false) a else b" into "b".
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue
}
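
The new Not branch pushes negation into the expression instead of leaving a NOT node for every row to evaluate: boolean literals are folded, comparisons are inverted, and double negation is removed. A few illustrative rewrites, written with the Catalyst DSL:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.catalyst.types.BooleanType

  Not(GreaterThan('a.int, 'b.int))        // simplified to LessThanOrEqual('a.int, 'b.int)
  Not(LessThan('a.int, 'b.int))           // simplified to GreaterThanOrEqual('a.int, 'b.int)
  Not(Literal(true, BooleanType))         // folded to Literal(false)
  Not(Not(GreaterThan('a.int, 'b.int)))   // double negation removed
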
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index af9e4d86e995a..dcbbb62c0aca4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -31,6 +31,25 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
*/
def outputSet: AttributeSet = AttributeSet(output)
+ /**
+ * All Attributes that appear in expressions from this operator. Note that this set does not
+ * include attributes that are implicitly referenced by being passed through to the output tuple.
+ */
+ def references: AttributeSet = AttributeSet(expressions.flatMap(_.references))
+
+ /**
+ * The set of all attributes that are input to this operator by its children.
+ */
+ def inputSet: AttributeSet =
+ AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
+
+ /**
+ * Attributes that are referenced by expressions but not provided by this node's children.
+ * Subclasses should override this method if they produce attributes internally, as it is used by
+ * assertions designed to prevent the construction of invalid plans.
+ */
+ def missingInput: AttributeSet = references -- inputSet
+
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
* Users should not expect a specific directionality. If a specific directionality is needed,
@@ -132,4 +151,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/** Prints out the schema in the tree format */
def printSchema(): Unit = println(schemaString)
+
+ protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""
+
+ override def simpleString = statePrefix + super.simpleString
}
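
missingInput is what drives the new "!" prefix in simpleString: any attribute an operator references but does not receive from its children marks the plan as structurally invalid. A rough sketch:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}

  val a = 'a.int
  val b = 'b.int
  val c = 'c.int                     // not produced by the relation below

  val badPlan = Filter(GreaterThan(c, Literal(1)), LocalRelation(a, b))

  badPlan.missingInput               // AttributeSet containing only `c`
  badPlan.simpleString               // rendered with a leading "!" because missingInput is non-empty
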
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 28d863e58beca..882e9c6110089 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
@@ -52,12 +53,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
}
- /**
- * Returns the set of attributes that this node takes as
- * input from its children.
- */
- lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output))
-
/**
* Returns true if this expression and all its children have been resolved to a specific schema
* and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
@@ -67,11 +62,54 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
+ override protected def statePrefix = if (!resolved) "'" else super.statePrefix
+
/**
* Returns true if all its children of this query plan have been resolved.
*/
def childrenResolved: Boolean = !children.exists(!_.resolved)
+ /**
+ * Returns true when the given logical plan will return the same results as this logical plan.
+ *
+ * Since it is likely undecidable in general to determine whether two given plans will produce the same
+ * results, it is okay for this function to return false, even if the results are actually
+ * the same. Such behavior will not affect correctness, only the application of performance
+ * enhancements like caching. However, it is not acceptable to return true if the results could
+ * possibly be different.
+ *
+ * By default this function performs a modified version of equality that is tolerant of cosmetic
+ * differences such as attribute naming or expression id differences. Logical operators that
+ * can do better should override this function.
+ */
+ def sameResult(plan: LogicalPlan): Boolean = {
+ plan.getClass == this.getClass &&
+ plan.children.size == children.size && {
+ logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]")
+ cleanArgs == plan.cleanArgs
+ } &&
+ (plan.children, children).zipped.forall(_ sameResult _)
+ }
+
+ /** Args that have been cleaned such that differences in expression id should not affect equality. */
+ protected lazy val cleanArgs: Seq[Any] = {
+ val input = children.flatMap(_.output)
+ productIterator.map {
+ // Children are checked using sameResult above.
+ case tn: TreeNode[_] if children contains tn => null
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case s: Option[_] => s.map {
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case other => other
+ }
+ case s: Seq[_] => s.map {
+ case e: Expression => BindReferences.bindReference(e, input, allowFailures = true)
+ case other => other
+ }
+ case other => other
+ }.toSeq
+ }
+
/**
* Optionally resolves the given string to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
index f8fe558511bfd..19769986ef58c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
}
override protected def stringArgs = Iterator(output)
+
+ override def sameResult(plan: LogicalPlan): Boolean = plan match {
+ case LocalRelation(otherOutput, otherData) =>
+ otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
+ case _ => false
+ }
}
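
Combined with the LocalRelation override above, sameResult lets the planner recognize two plans that differ only in expression ids, which is exactly the tolerance needed for cache matching. A sketch of the intent, comparing two projections over structurally identical relations:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

  val r1 = LocalRelation('a.int, 'b.int)
  val r2 = LocalRelation('a.int, 'b.int)   // same schema, but fresh exprIds

  // cleanArgs binds the project lists to child output ordinals, so the differing exprIds
  // no longer matter; the LocalRelation override compares dataTypes and data.
  val p1 = Project(Seq(r1.output.head), r1)
  val p2 = Project(Seq(r2.output.head), r2)
  assert(p1.sameResult(p2))
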
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 391508279bb80..14b03c7445c13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -105,8 +105,8 @@ case class InsertIntoTable(
child: LogicalPlan,
overwrite: Boolean)
extends LogicalPlan {
- // The table being inserted into is a child for the purposes of transformations.
- override def children = table :: child :: Nil
+
+ override def children = child :: Nil
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
@@ -138,11 +138,6 @@ case class Aggregate(
child: LogicalPlan)
extends UnaryNode {
- /** The set of all AttributeReferences required for this aggregation. */
- def references =
- AttributeSet(
- groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references))
-
override def output = aggregateExpressions.map(_.toAttribute)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index 8366639fa0e8b..b8ba2ee428a20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command {
}
/**
- * Commands of the form "SET (key) (= value)".
+ * Commands of the form "SET [key [= value]]".
*/
-case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+case class SetCommand(kv: Option[(String, Option[String])]) extends Command {
override def output = Seq(
AttributeReference("", StringType, nullable = false)())
}
@@ -56,9 +56,15 @@ case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends
}
/**
- * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command.
+ * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command.
*/
-case class CacheCommand(tableName: String, doCache: Boolean) extends Command
+case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean)
+ extends Command
+
+/**
+ * Returned for the "UNCACHE TABLE tableName" command.
+ */
+case class UncacheTableCommand(tableName: String) extends Command
/**
* Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
@@ -77,6 +83,12 @@ case class DescribeCommand(
}
/**
- * Returned for the "CACHE TABLE tableName AS SELECT .." command.
+ * Returned for the "! shellCommand" command
+ */
+case class ShellCommand(cmd: String) extends Command
+
+
+/**
+ * Returned for the "SOURCE file" command
*/
-case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command
+case class SourceCommand(filePath: String) extends Command
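
Folding the old (key, value) pair of Options into a single Option keeps the three grammatical forms of SET in one-to-one correspondence with pattern cases. A quick sketch of how each form is represented (the property name is illustrative):

  SetCommand(None)                                                  // SET            -> list all properties
  SetCommand(Some(("spark.sql.shuffle.partitions", None)))          // SET key        -> show one property
  SetCommand(Some(("spark.sql.shuffle.partitions", Some("10"))))    // SET key=value  -> assign a property
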
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index c7d73d3990c3a..0cf139ebde417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -17,73 +17,127 @@
package org.apache.spark.sql.catalyst.types
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
-import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
+import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
+import org.json4s.JsonAST.JValue
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
-/**
- * Utility functions for working with DataTypes.
- */
-object DataType extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- "StringType" ^^^ StringType |
- "FloatType" ^^^ FloatType |
- "IntegerType" ^^^ IntegerType |
- "ByteType" ^^^ ByteType |
- "ShortType" ^^^ ShortType |
- "DoubleType" ^^^ DoubleType |
- "LongType" ^^^ LongType |
- "BinaryType" ^^^ BinaryType |
- "BooleanType" ^^^ BooleanType |
- "DecimalType" ^^^ DecimalType |
- "TimestampType" ^^^ TimestampType
-
- protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
- }
- protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
- case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
- }
+object DataType {
+ def fromJson(json: String): DataType = parseDataType(parse(json))
- protected lazy val structField: Parser[StructField] =
- ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
- case name ~ tpe ~ nullable =>
- StructField(name, tpe, nullable = nullable)
+ private object JSortedObject {
+ def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
}
+ }
- protected lazy val boolVal: Parser[Boolean] =
- "true" ^^^ true |
- "false" ^^^ false
+ // NOTE: Map fields must be sorted in alphabetical order to stay consistent with the Python side.
+ private def parseDataType(json: JValue): DataType = json match {
+ case JString(name) =>
+ PrimitiveType.nameToType(name)
+
+ case JSortedObject(
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
+ ArrayType(parseDataType(t), n)
+
+ case JSortedObject(
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
+ MapType(parseDataType(k), parseDataType(v), n)
+
+ case JSortedObject(
+ ("fields", JArray(fields)),
+ ("type", JString("struct"))) =>
+ StructType(fields.map(parseStructField))
+ }
- protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => new StructType(fields)
- }
+ private def parseStructField(json: JValue): StructField = json match {
+ case JSortedObject(
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable)
+ }
- protected lazy val dataType: Parser[DataType] =
- arrayType |
- mapType |
- structType |
- primitiveType
+ @deprecated("Use DataType.fromJson instead")
+ def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
+
+ private object CaseClassStringParser extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ ( "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DecimalType" ^^^ DecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
+
+ protected lazy val arrayType: Parser[DataType] =
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+ case name ~ tpe ~ nullable =>
+ StructField(name, tpe, nullable = nullable)
+ }
+
+ protected lazy val boolVal: Parser[Boolean] =
+ ( "true" ^^^ true
+ | "false" ^^^ false
+ )
+
+ protected lazy val structType: Parser[DataType] =
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+ case fields => new StructType(fields)
+ }
+
+ protected lazy val dataType: Parser[DataType] =
+ ( arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
+
+ /**
+ * Parses a string representation of a DataType.
+ *
+ * TODO: Generate parser as pickler...
+ */
+ def apply(asString: String): DataType = parseAll(dataType, asString) match {
+ case Success(result, _) => result
+ case failure: NoSuccess =>
+ throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
+ }
- /**
- * Parses a string representation of a DataType.
- *
- * TODO: Generate parser as pickler...
- */
- def apply(asString: String): DataType = parseAll(dataType, asString) match {
- case Success(result, _) => result
- case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure")
}
protected[types] def buildFormattedString(
@@ -111,15 +165,19 @@ abstract class DataType {
def isPrimitive: Boolean = false
- def simpleString: String
-}
+ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
+
+ private[sql] def jsonValue: JValue = typeName
+
+ def json: String = compact(render(jsonValue))
-case object NullType extends DataType {
- def simpleString: String = "null"
+ def prettyJson: String = pretty(render(jsonValue))
}
+case object NullType extends DataType
+
object NativeType {
- def all = Seq(
+ val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
@@ -139,6 +197,12 @@ trait PrimitiveType extends DataType {
override def isPrimitive = true
}
+object PrimitiveType {
+ private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all
+
+ private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+}
+
abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
@@ -154,19 +218,26 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "string"
}
-case object BinaryType extends DataType with PrimitiveType {
+case object BinaryType extends NativeType with PrimitiveType {
private[sql] type JvmType = Array[Byte]
- def simpleString: String = "binary"
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
+ private[sql] val ordering = new Ordering[JvmType] {
+ def compare(x: Array[Byte], y: Array[Byte]): Int = {
+ for (i <- 0 until x.length; if i < y.length) {
+ val res = x(i).compareTo(y(i))
+ if (res != 0) return res
+ }
+ x.length - y.length
+ }
+ }
}
case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "boolean"
}
case object TimestampType extends NativeType {
@@ -177,8 +248,16 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
+}
+
+case object DateType extends NativeType {
+ private[sql] type JvmType = Date
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- def simpleString: String = "timestamp"
+ private[sql] val ordering = new Ordering[JvmType] {
+ def compare(x: Date, y: Date) = x.compareTo(y)
+ }
}
abstract class NumericType extends NativeType with PrimitiveType {
@@ -212,7 +291,6 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "long"
}
case object IntegerType extends IntegralType {
@@ -221,7 +299,6 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "integer"
}
case object ShortType extends IntegralType {
@@ -230,7 +307,6 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "short"
}
case object ByteType extends IntegralType {
@@ -239,7 +315,6 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
- def simpleString: String = "byte"
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
@@ -261,7 +336,6 @@ case object DecimalType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[BigDecimal]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = BigDecimalAsIfIntegral
- def simpleString: String = "decimal"
}
case object DoubleType extends FractionalType {
@@ -271,7 +345,6 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
- def simpleString: String = "double"
}
case object FloatType extends FractionalType {
@@ -281,7 +354,6 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
- def simpleString: String = "float"
}
object ArrayType {
@@ -299,11 +371,14 @@ object ArrayType {
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
builder.append(
- s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n")
+ s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
DataType.buildFormattedString(elementType, s"$prefix |", builder)
}
- def simpleString: String = "array"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("elementType" -> elementType.jsonValue) ~
+ ("containsNull" -> containsNull)
}
/**
@@ -315,9 +390,15 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
case class StructField(name: String, dataType: DataType, nullable: Boolean) {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
+ builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", builder)
}
+
+ private[sql] def jsonValue: JValue = {
+ ("name" -> name) ~
+ ("type" -> dataType.jsonValue) ~
+ ("nullable" -> nullable)
+ }
}
object StructType {
@@ -338,8 +419,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
- nameToField.get(name).getOrElse(
- throw new IllegalArgumentException(s"Field ${name} does not exist."))
+ nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
}
/**
@@ -348,7 +428,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
- if (!nonExistFields.isEmpty) {
+ if (nonExistFields.nonEmpty) {
throw new IllegalArgumentException(
s"Field ${nonExistFields.mkString(",")} does not exist.")
}
@@ -374,7 +454,9 @@ case class StructType(fields: Seq[StructField]) extends DataType {
fields.foreach(field => field.buildFormattedString(prefix, builder))
}
- def simpleString: String = "struct"
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("fields" -> fields.map(_.jsonValue))
}
object MapType {
@@ -397,12 +479,16 @@ case class MapType(
valueType: DataType,
valueContainsNull: Boolean) extends DataType {
private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
- builder.append(s"${prefix}-- value: ${valueType.simpleString} " +
- s"(valueContainsNull = ${valueContainsNull})\n")
+ builder.append(s"$prefix-- key: ${keyType.typeName}\n")
+ builder.append(s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
DataType.buildFormattedString(keyType, s"$prefix |", builder)
DataType.buildFormattedString(valueType, s"$prefix |", builder)
}
- def simpleString: String = "map"
+ override private[sql] def jsonValue: JValue =
+ ("type" -> typeName) ~
+ ("keyType" -> keyType.jsonValue) ~
+ ("valueType" -> valueType.jsonValue) ~
+ ("valueContainsNull" -> valueContainsNull)
}
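
The JSON form replaces the old case-class string as the canonical serialized representation of a schema; the string parser survives only behind the deprecated fromCaseClassString. A short round-trip sketch with a simple struct:

  import org.apache.spark.sql.catalyst.types._

  val schema = StructType(Seq(
    StructField("name", StringType, nullable = true),
    StructField("age",  IntegerType, nullable = false)))

  val json = schema.json
  // {"type":"struct","fields":[{"name":"name","type":"string","nullable":true},
  //                            {"name":"age","type":"integer","nullable":false}]}

  assert(DataType.fromJson(json) == schema)   // parses back to an equal DataType
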
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 428607d8c8253..488e373854bb3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -53,7 +53,8 @@ case class OptionalData(
floatField: Option[Float],
shortField: Option[Short],
byteField: Option[Byte],
- booleanField: Option[Boolean])
+ booleanField: Option[Boolean],
+ structField: Option[PrimitiveData])
case class ComplexData(
arrayField: Seq[Int],
@@ -100,7 +101,7 @@ class ScalaReflectionSuite extends FunSuite {
nullable = true))
}
- test("optinal data") {
+ test("optional data") {
val schema = schemaFor[OptionalData]
assert(schema === Schema(
StructType(Seq(
@@ -110,7 +111,8 @@ class ScalaReflectionSuite extends FunSuite {
StructField("floatField", FloatType, nullable = true),
StructField("shortField", ShortType, nullable = true),
StructField("byteField", ByteType, nullable = true),
- StructField("booleanField", BooleanType, nullable = true))),
+ StructField("booleanField", BooleanType, nullable = true),
+ StructField("structField", schemaFor[PrimitiveData].dataType, nullable = true))),
nullable = true))
}
@@ -228,4 +230,17 @@ class ScalaReflectionSuite extends FunSuite {
assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3)))
assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3))))
}
+
+ test("convert PrimitiveData to catalyst") {
+ val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
+ val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
+ assert(convertToCatalyst(data) === convertedData)
+ }
+
+ test("convert Option[Product] to catalyst") {
+ val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
+ val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
+ val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
+ assert(convertToCatalyst(data) === convertedData)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 5809a108ff62e..7b45738c4fc95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.types.IntegerType
+import org.apache.spark.sql.catalyst.types._
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
@@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+ val testRelation2 = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", StringType)(),
+ AttributeReference("c", DoubleType)(),
+ AttributeReference("d", DecimalType)(),
+ AttributeReference("e", ShortType)())
before {
caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation)
@@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
val e = intercept[RuntimeException] {
caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None))
}
- assert(e.getMessage === "Table Not Found: tAbLe")
+ assert(e.getMessage == "Table Not Found: tAbLe")
assert(
caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) ===
@@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}
assert(e.getMessage().toLowerCase.contains("unresolved plan"))
}
+
+ test("divide should be casted into fractional types") {
+ val testRelation2 = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", StringType)(),
+ AttributeReference("c", DoubleType)(),
+ AttributeReference("d", DecimalType)(),
+ AttributeReference("e", ShortType)())
+
+ val expr0 = 'a / 2
+ val expr1 = 'a / 'b
+ val expr2 = 'a / 'c
+ val expr3 = 'a / 'd
+ val expr4 = 'e / 'e
+ val plan = caseInsensitiveAnalyze(Project(
+ Alias(expr0, s"Analyzer($expr0)")() ::
+ Alias(expr1, s"Analyzer($expr1)")() ::
+ Alias(expr2, s"Analyzer($expr2)")() ::
+ Alias(expr3, s"Analyzer($expr3)")() ::
+ Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
+ val pl = plan.asInstanceOf[Project].projectList
+ assert(pl(0).dataType == DoubleType)
+ assert(pl(1).dataType == DoubleType)
+ assert(pl(2).dataType == DoubleType)
+ assert(pl(3).dataType == DecimalType)
+ assert(pl(4).dataType == DoubleType)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 63931af4bac3d..6dc5942023f9e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+
+import scala.collection.immutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.Matchers._
@@ -25,6 +27,7 @@ import org.scalautils.TripleEqualsSupport.Spread
import org.apache.spark.sql.catalyst.types._
+
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("INSET") {
+ val hS = HashSet[Any]() + 1 + 2
+ val nS = HashSet[Any]() + 1 + 2 + null
+ val one = Literal(1)
+ val two = Literal(2)
+ val three = Literal(3)
+ val nl = Literal(null)
+ val s = Seq(one, two)
+ val nullS = Seq(one, two, null)
+ checkEvaluation(InSet(one, hS, one +: s), true)
+ checkEvaluation(InSet(two, hS, two +: s), true)
+ checkEvaluation(InSet(two, nS, two +: nullS), true)
+ checkEvaluation(InSet(nl, nS, nl +: nullS), true)
+ checkEvaluation(InSet(three, hS, three +: s), false)
+ checkEvaluation(InSet(three, nS, three +: nullS), false)
+ checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
+ }
+
test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
@@ -231,8 +252,11 @@ class ExpressionEvaluationSuite extends FunSuite {
test("data type casting") {
- val sts = "1970-01-01 00:00:01.1"
- val ts = Timestamp.valueOf(sts)
+ val sd = "1970-01-01"
+ val d = Date.valueOf(sd)
+ val sts = sd + " 00:00:02"
+ val nts = sts + ".1"
+ val ts = Timestamp.valueOf(nts)
checkEvaluation("abdef" cast StringType, "abdef")
checkEvaluation("abdef" cast DecimalType, null)
@@ -245,8 +269,15 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
- checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
+ checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd)
+ checkEvaluation(Cast(Literal(d) cast StringType, DateType), d)
+ checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts)
checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts)
+ // all convert to string type to check
+ checkEvaluation(
+ Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd)
+ checkEvaluation(
+ Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts)
checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
@@ -295,6 +326,12 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}
+ test("date") {
+ val d1 = Date.valueOf("1970-01-01")
+ val d2 = Date.valueOf("1970-01-02")
+ checkEvaluation(Literal(d1) < Literal(d2), true)
+ }
+
test("timestamp") {
val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
@@ -302,6 +339,17 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(ts1) < Literal(ts2), true)
}
+ test("date casting") {
+ val d = Date.valueOf("1970-01-01")
+ checkEvaluation(Cast(d, ShortType), null)
+ checkEvaluation(Cast(d, IntegerType), null)
+ checkEvaluation(Cast(d, LongType), null)
+ checkEvaluation(Cast(d, FloatType), null)
+ checkEvaluation(Cast(d, DoubleType), null)
+ checkEvaluation(Cast(d, StringType), "1970-01-01")
+ checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
+ }
+
test("timestamp casting") {
val millis = 15 * 1000 + 2
val seconds = millis * 1000 + 2
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
new file mode 100644
index 0000000000000..97a78ec971c39
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types._
+
+// For implicit conversions
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class OptimizeInSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("AnalysisNodes", Once,
+ EliminateAnalysisOperators) ::
+ Batch("ConstantFolding", Once,
+ ConstantFolding,
+ BooleanSimplification,
+ OptimizeIn) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("OptimizedIn test: In clause optimized to InSet") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
+ .analyze
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2,
+ UnresolvedAttribute("a") +: Seq(Literal(1), Literal(2))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("OptimizedIn test: In clause not optimized in case filter has attributes") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+ .analyze
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
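For reference, the rewrite these two tests exercise can be sketched outside Catalyst. The snippet below is a simplified, hypothetical model (its Expr/Literal/In/InSet case classes stand in for the real Catalyst expressions); it only rewrites an In whose value list is entirely literal, mirroring the second test above.

    import scala.collection.immutable.HashSet

    sealed trait Expr
    case class Literal(value: Any) extends Expr
    case class Attribute(name: String) extends Expr
    case class In(value: Expr, list: Seq[Expr]) extends Expr
    case class InSet(value: Expr, hset: HashSet[Any], child: Seq[Expr]) extends Expr

    // Rewrite In -> InSet only when every list element is a literal; a list that
    // still contains attributes (second test) is left untouched.
    def optimizeIn(e: Expr): Expr = e match {
      case In(v, list) if list.forall(_.isInstanceOf[Literal]) =>
        InSet(v, HashSet[Any]() ++ list.map { case Literal(x) => x }, v +: list)
      case other => other
    }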
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
new file mode 100644
index 0000000000000..e8a793d107451
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * Tests for `sameResult`, which determines when two logical plans will return the same results.
+ */
+class SameResultSuite extends FunSuite {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)
+
+ def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = {
+ val aAnalyzed = a.analyze
+ val bAnalyzed = b.analyze
+
+ if (aAnalyzed.sameResult(bAnalyzed) != result) {
+ val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n")
+ fail(s"Plans should return sameResult = $result\n$comparison")
+ }
+ }
+
+ test("relations") {
+ assertSameResult(testRelation, testRelation2)
+ }
+
+ test("projections") {
+ assertSameResult(testRelation.select('a), testRelation2.select('a))
+ assertSameResult(testRelation.select('b), testRelation2.select('b))
+ assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
+ assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))
+
+ assertSameResult(testRelation, testRelation2.select('a), false)
+ assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false)
+ }
+
+ test("filters") {
+ assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 37b4c8ffcba0b..37e88d72b9172 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -44,6 +44,11 @@ public abstract class DataType {
*/
public static final BooleanType BooleanType = new BooleanType();
+ /**
+ * Gets the DateType object.
+ */
+ public static final DateType DateType = new DateType();
+
/**
* Gets the TimestampType object.
*/
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java
new file mode 100644
index 0000000000000..6677793baa365
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+/**
+ * The data type representing java.sql.Date values.
+ *
+ * {@code DateType} is represented by the singleton object {@link DataType#DateType}.
+ */
+public class DateType extends DataType {
+ protected DateType() {}
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
new file mode 100644
index 0000000000000..5ab2b5316ab10
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
+
+/** Holds a cached logical plan and its data */
+private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
+
+/**
+ * Provides support in a SQLContext for caching query results and automatically using these cached
+ * results when subsequent queries are executed. Data is cached using byte buffers stored in an
+ * InMemoryRelation. This relation is automatically substituted into query plans that return the
+ * `sameResult` as the originally cached query.
+ */
+private[sql] trait CacheManager {
+ self: SQLContext =>
+
+ @transient
+ private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+
+ @transient
+ private val cacheLock = new ReentrantReadWriteLock
+
+ /** Returns true if the table is currently cached in-memory. */
+ def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty
+
+ /** Caches the specified table in-memory. */
+ def cacheTable(tableName: String): Unit = cacheQuery(table(tableName))
+
+ /** Removes the specified table from the in-memory cache. */
+ def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName))
+
+ /** Acquires a read lock on the cache for the duration of `f`. */
+ private def readLock[A](f: => A): A = {
+ val lock = cacheLock.readLock()
+ lock.lock()
+ try f finally {
+ lock.unlock()
+ }
+ }
+
+ /** Acquires a write lock on the cache for the duration of `f`. */
+ private def writeLock[A](f: => A): A = {
+ val lock = cacheLock.writeLock()
+ lock.lock()
+ try f finally {
+ lock.unlock()
+ }
+ }
+
+ private[sql] def clearCache(): Unit = writeLock {
+ cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.clear()
+ }
+
+ /**
+ * Caches the data produced by the logical representation of the given SchemaRDD. Unlike
+ * `RDD.cache()`, the default storage level is `MEMORY_AND_DISK` because recomputing the
+ * in-memory columnar representation of the underlying table is expensive.
+ */
+ private[sql] def cacheQuery(
+ query: SchemaRDD,
+ storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
+ val planToCache = query.queryExecution.optimizedPlan
+ if (lookupCachedData(planToCache).nonEmpty) {
+ logWarning("Asked to cache already cached data.")
+ } else {
+ cachedData +=
+ CachedData(
+ planToCache,
+ InMemoryRelation(
+ useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan))
+ }
+ }
+
+ /** Removes the data for the given SchemaRDD from the cache */
+ private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock {
+ val planToCache = query.queryExecution.optimizedPlan
+ val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache))
+ require(dataIndex >= 0, s"Table $query is not cached.")
+ cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ cachedData.remove(dataIndex)
+ }
+
+
+ /** Optionally returns cached data for the given SchemaRDD */
+ private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
+ lookupCachedData(query.queryExecution.optimizedPlan)
+ }
+
+ /** Optionally returns cached data for the given LogicalPlan. */
+ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
+ cachedData.find(_.plan.sameResult(plan))
+ }
+
+ /** Replaces segments of the given logical plan with cached versions where possible. */
+ private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
+ plan transformDown {
+ case currentFragment =>
+ lookupCachedData(currentFragment)
+ .map(_.cachedRepresentation.withOutput(currentFragment.output))
+ .getOrElse(currentFragment)
+ }
+ }
+
+ /**
+ * Invalidates the cache of any data that contains `plan`. Note that this function may
+ * over-invalidate.
+ */
+ private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock {
+ cachedData.foreach {
+ case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
+ data.cachedRepresentation.recache()
+ case _ =>
+ }
+ }
+}
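A minimal usage sketch of the entry points this trait mixes into SQLContext, assuming a SparkContext `sc` and a registered table "people" (both hypothetical here):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)

    sqlContext.cacheTable("people")      // registers an InMemoryRelation; buffers fill lazily
    assert(sqlContext.isCached("people"))

    // Any query whose optimized plan has the sameResult as the cached plan
    // is answered from the column buffers via useCachedData.
    sqlContext.sql("SELECT name FROM people WHERE age >= 18").collect()

    sqlContext.uncacheTable("people")    // unpersists the cached column buffers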
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index f6f4cf3b80d41..07e6e2eccddf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -35,6 +35,7 @@ private[spark] object SQLConf {
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata"
val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec"
+ val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord"
// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -131,6 +132,9 @@ private[sql] trait SQLConf {
private[spark] def inMemoryPartitionPruning: Boolean =
getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+ private[spark] def columnNameOfCorruptRecord: String =
+ getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record")
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
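The new key feeds the JSON code path changed below: when jsonRDD is given an explicit schema, records that fail to parse are expected to land in the column named by spark.sql.columnNameOfCorruptRecord. A hedged sketch, assuming a `sqlContext` and an RDD[String] `jsonLines`; the column name "_bad_json" and the schema are illustrative:

    import org.apache.spark.sql.catalyst.types.{StringType, StructField, StructType}

    sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_bad_json")

    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("_bad_json", StringType, nullable = true)))

    val people = sqlContext.jsonRDD(jsonLines, schema)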
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a42bedbe6c04e..23e7b2d270777 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
@@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.SparkStrategies
+import org.apache.spark.sql.execution.{SparkStrategies, _}
import org.apache.spark.sql.json._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.{Logging, SparkContext}
/**
* :: AlphaComponent ::
@@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext}
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with SQLConf
+ with CacheManager
with ExpressionConversions
with UDFRegistration
with Serializable {
@@ -65,12 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = true)
+
@transient
protected[sql] val optimizer = Optimizer
+
@transient
- protected[sql] val parser = new catalyst.SqlParser
+ protected[sql] val sqlParser = {
+ val fallback = new catalyst.SqlParser
+ new catalyst.SparkSQLParser(fallback(_))
+ }
- protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql)
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
@@ -96,7 +102,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
- new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self))
+ new SchemaRDD(this,
+ LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
}
/**
@@ -133,7 +140,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = {
// TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied
// schema differs from the existing schema on any field data type.
- val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self)
+ val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
new SchemaRDD(this, logicalPlan)
}
@@ -193,9 +200,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+ val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
val appliedSchema =
- Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ Option(schema).getOrElse(
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
applySchema(rowRDD, appliedSchema)
}
@@ -204,8 +214,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
- val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
+ val appliedSchema =
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
applySchema(rowRDD, appliedSchema)
}
@@ -272,45 +285,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): SchemaRDD =
new SchemaRDD(this, catalog.lookupRelation(None, tableName))
- /** Caches the specified table in-memory. */
- def cacheTable(tableName: String): Unit = {
- val currentTable = table(tableName).queryExecution.analyzed
- val asInMemoryRelation = currentTable match {
- case _: InMemoryRelation =>
- currentTable
-
- case _ =>
- InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan)
- }
-
- catalog.registerTable(None, tableName, asInMemoryRelation)
- }
-
- /** Removes the specified table from the in-memory cache. */
- def uncacheTable(tableName: String): Unit = {
- table(tableName).queryExecution.analyzed match {
- // This is kind of a hack to make sure that if this was just an RDD registered as a table,
- // we reregister the RDD as a table.
- case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) =>
- inMem.cachedColumnBuffers.unpersist()
- catalog.unregisterTable(None, tableName)
- catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self))
- case inMem: InMemoryRelation =>
- inMem.cachedColumnBuffers.unpersist()
- catalog.unregisterTable(None, tableName)
- case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
- }
- }
-
- /** Returns true if the table is currently cached in-memory. */
- def isCached(tableName: String): Boolean = {
- val relation = table(tableName).queryExecution.analyzed
- relation match {
- case _: InMemoryRelation => true
- case _ => false
- }
- }
-
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
@@ -401,10 +375,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val analyzed = ExtractPythonUdfs(analyzer(logical))
lazy val optimizedPlan = optimizer(analyzed)
+ lazy val withCachedData = useCachedData(optimizedPlan)
+
// TODO: Don't just pick the first one...
lazy val sparkPlan = {
SparkPlan.currentContext.set(self)
- planner(optimizedPlan).next()
+ planner(withCachedData).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
@@ -444,8 +420,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* It is only used by PySpark.
*/
private[sql] def parseDataType(dataTypeString: String): DataType = {
- val parser = org.apache.spark.sql.catalyst.types.DataType
- parser(dataTypeString)
+ DataType.fromJson(dataTypeString)
}
/**
@@ -526,6 +501,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
iter.map { m => new GenericRow(m): Row}
}
- new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self))
+ new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 3b873f7c62cb6..948122d42f0e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
+import org.apache.spark.storage.StorageLevel
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.api.java.JavaRDD
/**
@@ -358,7 +360,7 @@ class SchemaRDD(
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None) =
- new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan))
+ new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan))
/**
* Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit
@@ -442,8 +444,7 @@ class SchemaRDD(
*/
private def applySchema(rdd: RDD[Row]): SchemaRDD = {
new SchemaRDD(sqlContext,
- SparkLogicalPlan(
- ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext))
+ LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext))
}
// =======================================================================
@@ -497,4 +498,20 @@ class SchemaRDD(
override def subtract(other: RDD[Row], p: Partitioner)
(implicit ord: Ordering[Row] = null): SchemaRDD =
applySchema(super.subtract(other, p)(ord))
+
+ /** Overridden cache function will always use the in-memory columnar caching. */
+ override def cache(): this.type = {
+ sqlContext.cacheQuery(this)
+ this
+ }
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheQuery(this, newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.uncacheQuery(this, blocking)
+ this
+ }
}
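With these overrides, cache()/persist() on a SchemaRDD always go through the columnar cache added above rather than plain RDD caching. A short sketch, assuming `sqlContext` and a registered table "people":

    import org.apache.spark.storage.StorageLevel

    val people = sqlContext.sql("SELECT * FROM people")

    people.cache()                    // columnar cache, MEMORY_AND_DISK by default
    people.count()                    // materializes the column buffers
    people.unpersist(blocking = true) // routes through uncacheQuery, not RDD.unpersist

    // Or pick the storage level explicitly:
    people.persist(StorageLevel.MEMORY_ONLY)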
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index e52eeb3e1c47e..25ba7d88ba538 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.execution.LogicalRDD
/**
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
@@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike {
// For various commands (like DDL) and queries with side effects, we force query optimization to
// happen right away to let these side effects take place eagerly.
case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile =>
- queryExecution.toRdd
- SparkLogicalPlan(queryExecution.executedPlan)(sqlContext)
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
case _ =>
baseLogicalPlan
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 150ff8a42063d..f8171c3be3207 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils
@@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
}
}
- new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext))
+ new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
}
/**
@@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
val scalaRowRDD = rowRDD.rdd.map(r => r.row)
val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType]
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
@@ -148,10 +148,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
* It goes through the entire dataset once to determine the schema.
*/
def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
- val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
- val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
+ val appliedScalaSchema =
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord))
+ val scalaRowRDD =
+ JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
@@ -162,12 +166,16 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
*/
@Experimental
def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = {
+ val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
val appliedScalaSchema =
Option(asScalaDataType(schema)).getOrElse(
- JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType]
- val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(
+ json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType]
+ val scalaRowRDD = JsonRDD.jsonStringToRow(
+ json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
val logicalPlan =
- SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext)
+ LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index e9d04ce7aae4c..df01411f60a05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
import scala.collection.JavaConversions
import scala.math.BigDecimal
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
/**
@@ -114,7 +115,7 @@ object Row {
// they are actually accessed.
case row: ScalaRow => new Row(row)
case map: scala.collection.Map[_, _] =>
- JavaConversions.mapAsJavaMap(
+ mapAsSerializableJavaMap(
map.map {
case (key, value) => (toJavaValue(key), toJavaValue(value))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index c9faf0852142a..538dd5b734664 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -92,6 +92,9 @@ private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, STRING)
+private[sql] class DateColumnAccessor(buffer: ByteBuffer)
+ extends NativeColumnAccessor(buffer, DATE)
+
private[sql] class TimestampColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, TIMESTAMP)
@@ -118,6 +121,7 @@ private[sql] object ColumnAccessor {
case BYTE.typeId => new ByteColumnAccessor(dup)
case SHORT.typeId => new ShortColumnAccessor(dup)
case STRING.typeId => new StringColumnAccessor(dup)
+ case DATE.typeId => new DateColumnAccessor(dup)
case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
case BINARY.typeId => new BinaryColumnAccessor(dup)
case GENERIC.typeId => new GenericColumnAccessor(dup)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 2e61a981375aa..300cef15bf8a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -107,6 +107,8 @@ private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColum
private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
+
private[sql] class TimestampColumnBuilder
extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP)
@@ -151,6 +153,7 @@ private[sql] object ColumnBuilder {
case STRING.typeId => new StringColumnBuilder
case BINARY.typeId => new BinaryColumnBuilder
case GENERIC.typeId => new GenericColumnBuilder
+ case DATE.typeId => new DateColumnBuilder
case TIMESTAMP.typeId => new TimestampColumnBuilder
}).asInstanceOf[ColumnBuilder]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 203a714e03c97..b34ab255d084a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.columnar
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference}
@@ -190,6 +190,24 @@ private[sql] class StringColumnStats extends ColumnStats {
def collectedStatistics = Row(lower, upper, nullCount)
}
+private[sql] class DateColumnStats extends ColumnStats {
+ var upper: Date = null
+ var lower: Date = null
+ var nullCount = 0
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ if (!row.isNullAt(ordinal)) {
+ val value = row(ordinal).asInstanceOf[Date]
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ } else {
+ nullCount += 1
+ }
+ }
+
+ def collectedStatistics = Row(lower, upper, nullCount)
+}
+
private[sql] class TimestampColumnStats extends ColumnStats {
var upper: Timestamp = null
var lower: Timestamp = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 198b5756676aa..ab66c85c4f242 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import scala.reflect.runtime.universe.TypeTag
@@ -335,7 +335,26 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
}
-private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
+private[sql] object DATE extends NativeColumnType(DateType, 8, 8) {
+ override def extract(buffer: ByteBuffer) = {
+ val date = new Date(buffer.getLong())
+ date
+ }
+
+ override def append(v: Date, buffer: ByteBuffer): Unit = {
+ buffer.putLong(v.getTime)
+ }
+
+ override def getField(row: Row, ordinal: Int) = {
+ row(ordinal).asInstanceOf[Date]
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = {
+ row(ordinal) = value
+ }
+}
+
+private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
override def extract(buffer: ByteBuffer) = {
val timestamp = new Timestamp(buffer.getLong())
timestamp.setNanos(buffer.getInt())
@@ -376,7 +395,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = value
}
@@ -387,7 +406,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
@@ -407,6 +426,7 @@ private[sql] object ColumnType {
case ShortType => SHORT
case StringType => STRING
case BinaryType => BINARY
+ case DateType => DATE
case TimestampType => TIMESTAMP
case _ => GENERIC
}
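The new DATE column type serializes a java.sql.Date as its epoch-millisecond long (8 bytes), and the type ids below it each shift by one (TIMESTAMP 8 to 9, BINARY 9 to 10, GENERIC 10 to 11). A standalone sketch of the same encoding, independent of the ColumnType hierarchy:

    import java.nio.ByteBuffer
    import java.sql.Date

    val buffer = ByteBuffer.allocate(8)
    val d = Date.valueOf("1970-01-02")

    buffer.putLong(d.getTime)                      // what append(v, buffer) writes
    buffer.flip()
    val roundTripped = new Date(buffer.getLong())  // what extract(buffer) reads
    assert(roundTripped.getTime == d.getTime)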
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 8a3612cdf19be..22ab0e2613f21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -27,18 +27,24 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
+import org.apache.spark.storage.StorageLevel
private[sql] object InMemoryRelation {
- def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, child)()
+ def apply(
+ useCompression: Boolean,
+ batchSize: Int,
+ storageLevel: StorageLevel,
+ child: SparkPlan): InMemoryRelation =
+ new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)()
}
-private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row)
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
batchSize: Int,
+ storageLevel: StorageLevel,
child: SparkPlan)
(private var _cachedColumnBuffers: RDD[CachedBatch] = null)
extends LogicalPlan with MultiInstanceRelation {
@@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation(
// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
+ buildBuffers()
+ }
+
+ def recache() = {
+ _cachedColumnBuffers.unpersist()
+ _cachedColumnBuffers = null
+ buildBuffers()
+ }
+
+ private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitions { rowIterator =>
new Iterator[CachedBatch] {
@@ -75,24 +91,30 @@ private[sql] case class InMemoryRelation(
val stats = Row.fromSeq(
columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
- CachedBatch(columnBuilders.map(_.build()), stats)
+ CachedBatch(columnBuilders.map(_.build().array()), stats)
}
def hasNext = rowIterator.hasNext
}
- }.cache()
+ }.persist(storageLevel)
cached.setName(child.toString)
_cachedColumnBuffers = cached
}
+ def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+ InMemoryRelation(
+ newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers)
+ }
+
override def children = Seq.empty
override def newInstance() = {
new InMemoryRelation(
- output.map(_.newInstance),
+ output.map(_.newInstance()),
useCompression,
batchSize,
+ storageLevel,
child)(
_cachedColumnBuffers).asInstanceOf[this.type]
}
@@ -216,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan(
def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
val rows = cacheBatches.flatMap { cachedBatch =>
// Build column accessors
- val columnAccessors =
- requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+ val columnAccessors = requestedColumnIndices.map { batch =>
+ ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+ }
// Extract rows via column accessors
new Iterator[Row] {
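CachedBatch now stores each column as an Array[Byte] rather than a ByteBuffer, which keeps batches friendly to persist(storageLevel), and the scan wraps the bytes back into a buffer before handing them to a ColumnAccessor. A tiny sketch of that round trip; `buildColumn` is a hypothetical stand-in for ColumnBuilder.build():

    import java.nio.ByteBuffer

    def buildColumn(): ByteBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))

    // Cache time: keep the raw bytes so the batch serializes cheaply.
    val stored: Array[Byte] = buildColumn().array()

    // Scan time: wrap the bytes again for the accessor, as done above.
    val readable: ByteBuffer = ByteBuffer.wrap(stored)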
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
new file mode 100644
index 0000000000000..2ddf513b6fc98
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+object RDDConversions {
+ def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+ data.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val bufferedIterator = iterator.buffered
+ val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
+
+ bufferedIterator.map { r =>
+ var i = 0
+ while (i < mutableRow.length) {
+ mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
+ i += 1
+ }
+
+ mutableRow
+ }
+ }
+ }
+ }
+
+ /*
+ def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = {
+ LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
+ }
+ */
+}
+
+case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
+
+ def children = Nil
+
+ def newInstance() =
+ LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type]
+
+ override def sameResult(plan: LogicalPlan) = plan match {
+ case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+ case _ => false
+ }
+
+ @transient override lazy val statistics = Statistics(
+ // TODO: Instead of returning a default value here, find a way to return a meaningful size
+ // estimate for RDDs. See PR 1238 for more discussions.
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+ )
+}
+
+case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
+ override def execute() = rdd
+}
+
+@deprecated("Use LogicalRDD", "1.2.0")
+case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
+ override def execute() = rdd
+}
+
+@deprecated("Use LogicalRDD", "1.2.0")
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
+ extends LogicalPlan with MultiInstanceRelation {
+
+ def output = alreadyPlanned.output
+ override def children = Nil
+
+ override final def newInstance(): this.type = {
+ SparkLogicalPlan(
+ alreadyPlanned match {
+ case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+ case _ => sys.error("Multiple instances of the same relation detected.")
+ })(sqlContext).asInstanceOf[this.type]
+ }
+
+ override def sameResult(plan: LogicalPlan) = plan match {
+ case SparkLogicalPlan(ExistingRdd(_, rdd)) =>
+ rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id
+ case _ => false
+ }
+
+ @transient override lazy val statistics = Statistics(
+ // TODO: Instead of returning a default value here, find a way to return a meaningful size
+ // estimate for RDDs. See PR 1238 for more discussions.
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
+ )
+}
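RDDConversions.productToRowRdd turns an RDD of Products into an RDD[Row] by reusing a single GenericMutableRow per partition, and LogicalRDD wraps such an RDD as a leaf logical plan. A hedged sketch of wiring them together by hand, mirroring what createSchemaRDD now does (Person, `sc`, and `sqlContext` are assumed):

    import org.apache.spark.sql.SchemaRDD
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.execution.{LogicalRDD, RDDConversions}

    case class Person(name: String, age: Int)   // illustrative schema

    val people = sc.parallelize(Seq(Person("alice", 30), Person("bob", 25)))

    val rowRdd = RDDConversions.productToRowRdd(people)
    val plan = LogicalRDD(ScalaReflection.attributesFor[Person], rowRdd)(sqlContext)
    val schemaRdd = new SchemaRDD(sqlContext, plan)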
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index c386fd121c5de..38877c28de3a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -39,7 +39,8 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
- protected def generatorOutput: Seq[Attribute] = {
+ // This must be a val since the generator output expr ids are not preserved by serialization.
+ protected val generatorOutput: Seq[Attribute] = {
if (join && outer) {
generator.output.map(_.withNullability(true))
} else {
@@ -62,7 +63,7 @@ case class Generate(
newProjection(child.output ++ nullValues, child.output)
val joinProjection =
- newProjection(child.output ++ generator.output, child.output ++ generator.output)
+ newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
val joinedRow = new JoinedRow
iter.flatMap {row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8913985b028..b1a7948b66cb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}
-/**
- * :: DeveloperApi ::
- * Allows already planned SparkQueries to be linked into logical query plans.
- *
- * Note that in general it is not valid to use this class to link multiple copies of the same
- * physical operator into the same query plan as this violates the uniqueness of expression ids.
- * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just
- * replace the output attributes with new copies of themselves without breaking any attribute
- * linking.
- */
-@DeveloperApi
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext)
- extends LogicalPlan with MultiInstanceRelation {
-
- def output = alreadyPlanned.output
- override def children = Nil
-
- override final def newInstance(): this.type = {
- SparkLogicalPlan(
- alreadyPlanned match {
- case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
- case _ => sys.error("Multiple instance of the same relation detected.")
- })(sqlContext).asInstanceOf[this.type]
- }
-
- @transient override lazy val statistics = Statistics(
- // TODO: Instead of returning a default value here, find a way to return a meaningful size
- // estimate for RDDs. See PR 1238 for more discussions.
- sizeInBytes = BigInt(sqlContext.defaultSizeInBytes)
- )
-
-}
-
private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] {
self: Product =>
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 45687d960404c..79e4ddb8c4f5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.parquet._
+
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
@@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
- val semiJoin = execution.LeftSemiJoinHash(
+ val semiJoin = joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
- execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition) :: Nil
+ joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* evaluated by matching hash keys.
*
* This strategy applies a simple optimization based on the estimates of the physical sizes of
- * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an
+ * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an
* estimated physical size smaller than the user-settable threshold
* [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
* ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
* ''broadcasted'' to all of the executors involved in the join, as a
* [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
- * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
+ * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
*/
object HashJoin extends Strategy with PredicateHelper {
@@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
left: LogicalPlan,
right: LogicalPlan,
condition: Option[Expression],
- side: BuildSide) = {
- val broadcastHashJoin = execution.BroadcastHashJoin(
+ side: joins.BuildSide) = {
+ val broadcastHashJoin = execution.joins.BroadcastHashJoin(
leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
@@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
+ makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if sqlContext.autoBroadcastJoinThreshold > 0 &&
left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
- makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
+ makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- BuildRight
+ joins.BuildRight
} else {
- BuildLeft
+ joins.BuildLeft
}
- val hashJoin =
- execution.ShuffledHashJoin(
- leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+ val hashJoin = joins.ShuffledHashJoin(
+ leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
- execution.HashOuterJoin(
+ joins.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case _ => Nil
@@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft
- execution.BroadcastNestedLoopJoin(
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ joins.BuildRight
+ } else {
+ joins.BuildLeft
+ }
+ joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
}
@@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, _, None) =>
- execution.CartesianProduct(planLater(left), planLater(right)) :: Nil
+ execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
- execution.CartesianProduct(planLater(left), planLater(right))) :: Nil
+ execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil
case _ => Nil
}
}
@@ -272,10 +275,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil
case logical.LocalRelation(output, data) =>
- ExistingRdd(
+ val nPartitions = if (data.isEmpty) 1 else numPartitions
+ PhysicalRDD(
output,
- ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
+ RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
@@ -287,26 +292,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Generate(generator, join, outer, _, child) =>
execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
case logical.NoRelation =>
- execution.ExistingRdd(Nil, singleRowRdd) :: Nil
+ execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
- case e @ EvaluatePython(udf, child) =>
+ case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
- case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+ case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
case _ => Nil
}
}
case class CommandStrategy(context: SQLContext) extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.SetCommand(key, value) =>
- Seq(execution.SetCommand(key, value, plan.output)(context))
+ case logical.SetCommand(kv) =>
+ Seq(execution.SetCommand(kv, plan.output)(context))
case logical.ExplainCommand(logicalPlan, extended) =>
Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context))
- case logical.CacheCommand(tableName, cache) =>
- Seq(execution.CacheCommand(tableName, cache)(context))
- case logical.CacheTableAsSelectCommand(tableName, plan) =>
- Seq(execution.CacheTableAsSelectCommand(tableName, plan))
+ case logical.CacheTableCommand(tableName, optPlan, isLazy) =>
+ Seq(execution.CacheTableCommand(tableName, optPlan, isLazy))
+ case logical.UncacheTableCommand(tableName) =>
+ Seq(execution.UncacheTableCommand(tableName))
case _ => Nil
}
}
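CommandStrategy now plans the parser-produced CacheTableCommand and UncacheTableCommand nodes, including the new isLazy flag. The SQL forms below are a hedged sketch of what those nodes appear to correspond to; the exact grammar lives in SparkSQLParser, which is not part of this section:

    sqlContext.sql("CACHE TABLE people")                                        // isLazy = false
    sqlContext.sql("CACHE LAZY TABLE adults AS SELECT * FROM people WHERE age >= 18")
    sqlContext.sql("UNCACHE TABLE people")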
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index cac376608be29..977f3c9f32096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -210,45 +210,6 @@ case class Sort(
override def output = child.output
}
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object ExistingRdd {
- def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
- data.mapPartitions { iterator =>
- if (iterator.isEmpty) {
- Iterator.empty
- } else {
- val bufferedIterator = iterator.buffered
- val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
-
- bufferedIterator.map { r =>
- var i = 0
- while (i < mutableRow.length) {
- mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
- i += 1
- }
-
- mutableRow
- }
- }
- }
- }
-
- def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
- ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- override def execute() = rdd
-}
-
/**
* :: DeveloperApi ::
* Computes the set of distinct input rows using a HashSet.
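
For context, the deleted ExistingRdd.productToRowRdd (replaced by RDDConversions.productToRowRdd in the strategy change above) converts each partition of Products into rows while reusing a single mutable row. A standalone sketch of that buffer-reuse idea over plain iterators, with Array[Any] standing in for the mutable row (no Spark types):

```scala
// Per-partition Product-to-row conversion that reuses one mutable buffer.
// Callers must copy a row if they need to retain it, since the array is shared.
def productToRows[A <: Product](partition: Iterator[A]): Iterator[Array[Any]] = {
  if (partition.isEmpty) {
    Iterator.empty
  } else {
    val buffered = partition.buffered
    // Size the reusable row from the first element's arity.
    val mutableRow = new Array[Any](buffered.head.productArity)
    buffered.map { record =>
      var i = 0
      while (i < mutableRow.length) {
        mutableRow(i) = record.productElement(i)
        i += 1
      }
      mutableRow
    }
  }
}

// case class Person(name: String, age: Int)
// productToRows(Iterator(Person("a", 1), Person("b", 2))).foreach(r => println(r.toSeq))
```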
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index c2f48a902a3e9..5859eba408ee1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -37,7 +37,7 @@ trait Command {
* The `execute()` method of all the physical command classes should reference `sideEffectResult`
* so that the command can be executed eagerly right after the command query is created.
*/
- protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row]
+ protected lazy val sideEffectResult: Seq[Row] = Seq.empty[Row]
override def executeCollect(): Array[Row] = sideEffectResult.toArray
@@ -48,29 +48,28 @@ trait Command {
* :: DeveloperApi ::
*/
@DeveloperApi
-case class SetCommand(
- key: Option[String], value: Option[String], output: Seq[Attribute])(
+case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])(
@transient context: SQLContext)
extends LeafNode with Command with Logging {
- override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match {
- // Set value for key k.
- case (Some(k), Some(v)) =>
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ override protected lazy val sideEffectResult: Seq[Row] = kv match {
+ // Set value for the key.
+ case Some((key, Some(value))) =>
+ if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v"))
+ context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
+ Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
} else {
- context.setConf(k, v)
- Seq(Row(s"$k=$v"))
+ context.setConf(key, value)
+ Seq(Row(s"$key=$value"))
}
- // Query the value bound to key k.
- case (Some(k), _) =>
+ // Query the value bound to the key.
+ case Some((key, None)) =>
// TODO (lian) This is just a workaround to make the Simba ODBC driver work.
// Should remove this once we get the ODBC driver updated.
- if (k == "-v") {
+ if (key == "-v") {
val hiveJars = Seq(
"hive-exec-0.12.0.jar",
"hive-service-0.12.0.jar",
@@ -84,23 +83,20 @@ case class SetCommand(
Row("system:java.class.path=" + hiveJars),
Row("system:sun.java.command=shark.SharkServer2"))
} else {
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
} else {
- Seq(Row(s"$k=${context.getConf(k, "")}"))
+ Seq(Row(s"$key=${context.getConf(key, "")}"))
}
}
// Query all key-value pairs that are set in the SQLConf of the context.
- case (None, None) =>
+ case _ =>
context.getAllConfs.map { case (k, v) =>
Row(s"$k=$v")
}.toSeq
-
- case _ =>
- throw new IllegalArgumentException()
}
override def otherCopyArgs = context :: Nil
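
The reworked SetCommand folds the previous pair of Options into kv: Option[(String, Option[String])], so the three SET forms (assign, query one key, list everything) are covered without the old IllegalArgumentException fallback. A hedged sketch of that match shape over a plain mutable map (standing in for SQLConf; output strings are illustrative):

```scala
import scala.collection.mutable

// Sketch of the kv match used by the reworked SetCommand.
def runSet(kv: Option[(String, Option[String])],
           conf: mutable.Map[String, String]): Seq[String] = kv match {
  case Some((key, Some(value))) =>          // SET key=value
    conf(key) = value
    Seq(s"$key=$value")
  case Some((key, None)) =>                 // SET key   (query one key)
    Seq(s"$key=${conf.getOrElse(key, "<undefined>")}")
  case _ =>                                 // SET       (list all pairs)
    conf.map { case (k, v) => s"$k=$v" }.toSeq
}
```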
@@ -121,7 +117,7 @@ case class ExplainCommand(
extends LeafNode with Command {
// Run through the optimizer to generate the physical plan.
- override protected[sql] lazy val sideEffectResult: Seq[Row] = try {
+ override protected lazy val sideEffectResult: Seq[Row] = try {
// TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties.
val queryExecution = context.executePlan(logicalPlan)
val outputString = if (extended) queryExecution.toString else queryExecution.simpleString
@@ -138,49 +134,54 @@ case class ExplainCommand(
* :: DeveloperApi ::
*/
@DeveloperApi
-case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext)
+case class CacheTableCommand(
+ tableName: String,
+ plan: Option[LogicalPlan],
+ isLazy: Boolean)
extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult = {
- if (doCache) {
- context.cacheTable(tableName)
- } else {
- context.uncacheTable(tableName)
+ override protected lazy val sideEffectResult = {
+ import sqlContext._
+
+ plan.foreach(_.registerTempTable(tableName))
+ val schemaRDD = table(tableName)
+ schemaRDD.cache()
+
+ if (!isLazy) {
+ // Performs eager caching
+ schemaRDD.count()
}
+
Seq.empty[Row]
}
override def output: Seq[Attribute] = Seq.empty
}
+
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])(
- @transient context: SQLContext)
- extends LeafNode with Command {
-
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
- Row("# Registered as a temporary table", null, null) +:
- child.output.map(field => Row(field.name, field.dataType.toString, null))
+case class UncacheTableCommand(tableName: String) extends LeafNode with Command {
+ override protected lazy val sideEffectResult: Seq[Row] = {
+ sqlContext.table(tableName).unpersist()
+ Seq.empty[Row]
}
+
+ override def output: Seq[Attribute] = Seq.empty
}
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-case class CacheTableAsSelectCommand(tableName: String, logicalPlan: LogicalPlan)
+case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])(
+ @transient context: SQLContext)
extends LeafNode with Command {
-
- override protected[sql] lazy val sideEffectResult = {
- import sqlContext._
- logicalPlan.registerTempTable(tableName)
- cacheTable(tableName)
- Seq.empty[Row]
- }
- override def output: Seq[Attribute] = Seq.empty
-
+ override protected lazy val sideEffectResult: Seq[Row] = {
+ Row("# Registered as a temporary table", null, null) +:
+ child.output.map(field => Row(field.name, field.dataType.toString, null))
+ }
}
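
CacheTableCommand above optionally registers the given plan as a temporary table, marks it cached, and, when isLazy is false, forces materialization by running count() on it. A minimal sketch of that lazy/eager distinction with a toy cache (illustrative names, not SchemaRDD or SQLContext):

```scala
// Lazy vs. eager caching: cache() only marks the data; an eager cache also forces
// the computation, much like count() on the cached table.
final class CachedTable[A](compute: () => Seq[A]) {
  private var cached: Option[Seq[A]] = None
  private var marked = false

  def cache(): this.type = { marked = true; this }   // mark only; nothing is computed yet

  def materialize(): Seq[A] = cached match {
    case Some(data) => data
    case None =>
      val data = compute()
      if (marked) cached = Some(data)                // retain only if marked as cached
      data
  }

  def count(): Int = materialize().size
}

def cacheTable[A](table: CachedTable[A], isLazy: Boolean): Unit = {
  table.cache()
  if (!isLazy) {
    // Eager caching: force the computation now so later reads hit the cache.
    table.count()
  }
}
```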
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index a9535a750bcd7..61be5ed2db65c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
import org.apache.spark.sql.{SchemaRDD, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.types._
/**
* :: DeveloperApi ::
@@ -56,6 +57,23 @@ package object debug {
case _ =>
}
}
+
+ def typeCheck(): Unit = {
+ val plan = query.queryExecution.executedPlan
+ val visited = new collection.mutable.HashSet[TreeNodeRef]()
+ val debugPlan = plan transform {
+ case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
+ visited += new TreeNodeRef(s)
+ TypeCheck(s)
+ }
+ try {
+ println(s"Results returned: ${debugPlan.execute().count()}")
+ } catch {
+ case e: Exception =>
+ def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause)
+ println(s"Deepest Error: ${unwrap(e)}")
+ }
+ }
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
@@ -115,4 +133,71 @@ package object debug {
}
}
}
+
+ /**
+ * :: DeveloperApi ::
+ * Helper functions for checking that runtime types match a given schema.
+ */
+ @DeveloperApi
+ object TypeCheck {
+ def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
+ case (null, _) =>
+
+ case (row: Row, StructType(fields)) =>
+ row.zip(fields.map(_.dataType)).foreach { case (d, t) => typeCheck(d, t) }
+ case (s: Seq[_], ArrayType(elemType, _)) =>
+ s.foreach(typeCheck(_, elemType))
+ case (m: Map[_, _], MapType(keyType, valueType, _)) =>
+ m.keys.foreach(typeCheck(_, keyType))
+ m.values.foreach(typeCheck(_, valueType))
+
+ case (_: Long, LongType) =>
+ case (_: Int, IntegerType) =>
+ case (_: String, StringType) =>
+ case (_: Float, FloatType) =>
+ case (_: Byte, ByteType) =>
+ case (_: Short, ShortType) =>
+ case (_: Boolean, BooleanType) =>
+ case (_: Double, DoubleType) =>
+
+ case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Physical operator that checks, at runtime, that the rows produced by the child plan match the child's schema.
+ */
+ @DeveloperApi
+ private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan {
+ import TypeCheck._
+
+ override def nodeName = ""
+
+ /* Only required when defining this class in a REPL.
+ override def makeCopy(args: Array[Object]): this.type =
+ TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type]
+ */
+
+ def output = child.output
+
+ def children = child :: Nil
+
+ def execute() = {
+ child.execute().map { row =>
+ try typeCheck(row, child.schema) catch {
+ case e: Exception =>
+ sys.error(
+ s"""
+ |ERROR WHEN TYPE CHECKING QUERY
+ |==============================
+ |$e
+ |======== BAD TREE ============
+ |$child
+ """.stripMargin)
+ }
+ row
+ }
+ }
+ }
}
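
The new typeCheck() helper wraps each physical operator in a TypeCheck node whose typeCheck(data, schema) walks values alongside the schema, recursing into structs, arrays and maps and failing on any mismatch. A standalone sketch of that schema-directed recursion with a tiny type algebra (not Spark's DataType or Row):

```scala
// Schema-directed runtime type checking, mirroring debug.TypeCheck.typeCheck.
sealed trait DType
case object IntT extends DType
case object StringT extends DType
case class ArrayT(element: DType) extends DType
case class StructT(fields: Seq[DType]) extends DType

def typeCheck(data: Any, schema: DType): Unit = (data, schema) match {
  case (null, _)                  =>       // nulls are always accepted
  case (_: Int, IntT)             =>
  case (_: String, StringT)       =>
  case (s: Seq[_], ArrayT(elem))  => s.foreach(typeCheck(_, elem))
  case (row: Seq[_], StructT(fs)) =>
    row.zip(fs).foreach { case (d, t) => typeCheck(d, t) }
  case (d, t) =>
    sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
}

// typeCheck(Seq(1, "a"), StructT(Seq(IntT, StringT)))   // ok
// typeCheck(Seq(1, 2),   StructT(Seq(IntT, StringT)))   // fails on the second field
```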
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
deleted file mode 100644
index 2890a563bed48..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ /dev/null
@@ -1,624 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.util.{HashMap => JavaHashMap}
-
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.util.collection.CompactBuffer
-
-@DeveloperApi
-sealed abstract class BuildSide
-
-@DeveloperApi
-case object BuildLeft extends BuildSide
-
-@DeveloperApi
-case object BuildRight extends BuildSide
-
-trait HashJoin {
- self: SparkPlan =>
-
- val leftKeys: Seq[Expression]
- val rightKeys: Seq[Expression]
- val buildSide: BuildSide
- val left: SparkPlan
- val right: SparkPlan
-
- lazy val (buildPlan, streamedPlan) = buildSide match {
- case BuildLeft => (left, right)
- case BuildRight => (right, left)
- }
-
- lazy val (buildKeys, streamedKeys) = buildSide match {
- case BuildLeft => (leftKeys, rightKeys)
- case BuildRight => (rightKeys, leftKeys)
- }
-
- def output = left.output ++ right.output
-
- @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
- @transient lazy val streamSideKeyGenerator =
- newMutableProjection(streamedKeys, streamedPlan.output)
-
- def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
- // TODO: Use Spark's HashMap implementation.
-
- val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
- var currentRow: Row = null
-
- // Create a mapping of buildKeys -> rows
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[Row]()
- hashTable.put(rowKey, newMatchList)
- newMatchList
- } else {
- existingMatchList
- }
- matchList += currentRow.copy()
- }
- }
-
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatches: CompactBuffer[Row] = _
- private[this] var currentMatchPosition: Int = -1
-
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow2
-
- private[this] val joinKeys = streamSideKeyGenerator()
-
- override final def hasNext: Boolean =
- (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
- (streamIter.hasNext && fetchNext())
-
- override final def next() = {
- val ret = buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- currentMatchPosition += 1
- ret
- }
-
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false if the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatches = null
- currentMatchPosition = -1
-
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashTable.get(joinKeys.currentValue)
- }
- }
-
- if (currentHashMatches == null) {
- false
- } else {
- currentMatchPosition = 0
- true
- }
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Performs a hash based outer join for two child relations by shuffling the data using
- * the join keys. This operator requires loading the associated partition in both side into memory.
- */
-@DeveloperApi
-case class HashOuterJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- joinType: JoinType,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
-
- override def outputPartitioning: Partitioning = joinType match {
- case LeftOuter => left.outputPartitioning
- case RightOuter => right.outputPartitioning
- case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- override def output = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
- throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
-
- @transient private[this] lazy val DUMMY_LIST = Seq[Row](null)
- @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
-
- // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
- // iterator for performance purpose.
-
- private[this] def leftOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- leftIter.iterator.flatMap { l =>
- joinedRow.withLeft(l)
- var matched = false
- (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
- matched = true
- joinedRow.copy
- } else {
- Nil
- }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
- // records in right side.
- // If we didn't get any proper row, then append a single row with empty right
- joinedRow.withRight(rightNullRow).copy
- })
- }
- }
-
- private[this] def rightOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val leftNullRow = new GenericRow(left.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- rightIter.iterator.flatMap { r =>
- joinedRow.withRight(r)
- var matched = false
- (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
- matched = true
- joinedRow.copy
- } else {
- Nil
- }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
- // records in left side.
- // If we didn't get any proper row, then append a single row with empty left.
- joinedRow.withLeft(leftNullRow).copy
- })
- }
- }
-
- private[this] def fullOuterIterator(
- key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
- val joinedRow = new JoinedRow()
- val leftNullRow = new GenericRow(left.output.length)
- val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
-
- if (!key.anyNull) {
- // Store the positions of records in right, if one of its associated row satisfy
- // the join condition.
- val rightMatchedSet = scala.collection.mutable.Set[Int]()
- leftIter.iterator.flatMap[Row] { l =>
- joinedRow.withLeft(l)
- var matched = false
- rightIter.zipWithIndex.collect {
- // 1. For those matched (satisfy the join condition) records with both sides filled,
- // append them directly
-
- case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
- matched = true
- // if the row satisfy the join condition, add its index into the matched set
- rightMatchedSet.add(idx)
- joinedRow.copy
- }
- } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // 2. For those unmatched records in left, append additional records with empty right.
-
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all
- // of the records in right side.
- // If we didn't get any proper row, then append a single row with empty right.
- joinedRow.withRight(rightNullRow).copy
- })
- } ++ rightIter.zipWithIndex.collect {
- // 3. For those unmatched records in right, append additional records with empty left.
-
- // Re-visiting the records in right, and append additional row with empty left, if its not
- // in the matched set.
- case (r, idx) if (!rightMatchedSet.contains(idx)) => {
- joinedRow(leftNullRow, r).copy
- }
- }
- } else {
- leftIter.iterator.map[Row] { l =>
- joinedRow(l, rightNullRow).copy
- } ++ rightIter.iterator.map[Row] { r =>
- joinedRow(leftNullRow, r).copy
- }
- }
- }
-
- private[this] def buildHashTable(
- iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
- val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
- while (iter.hasNext) {
- val currentRow = iter.next()
- val rowKey = keyGenerator(currentRow)
-
- var existingMatchList = hashTable.get(rowKey)
- if (existingMatchList == null) {
- existingMatchList = new CompactBuffer[Row]()
- hashTable.put(rowKey, existingMatchList)
- }
-
- existingMatchList += currentRow.copy()
- }
-
- hashTable
- }
-
- def execute() = {
- left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
- // TODO this probably can be replaced by external sort (sort merged join?)
- // Build HashMap for current partition in left relation
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- // Build HashMap for current partition in right relation
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-
- import scala.collection.JavaConversions._
- val boundCondition =
- condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
- joinType match {
- case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
- leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case RightOuter => rightHashTable.keysIterator.flatMap { key =>
- rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
- fullOuterIterator(key,
- leftHashTable.getOrElse(key, EMPTY_LIST),
- rightHashTable.getOrElse(key, EMPTY_LIST))
- }
- case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Performs an inner hash join of two child relations by first shuffling the data using the join
- * keys.
- */
-@DeveloperApi
-case class ShuffledHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override def outputPartitioning: Partitioning = left.outputPartitioning
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- def execute() = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) {
- (buildIter, streamIter) => joinIterators(buildIter, streamIter)
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
- */
-@DeveloperApi
-case class LeftSemiJoinHash(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- val buildSide = BuildRight
-
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- override def output = left.output
-
- def execute() = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashSet = new java.util.HashSet[Row]()
- var currentRow: Row = null
-
- // Create a Hash set of buildKeys
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey)
- }
- }
- }
-
- val joinKeys = streamSideKeyGenerator()
- streamIter.filter(current => {
- !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
- })
- }
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * Performs an inner hash join of two child relations. When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
- * relation is not shuffled.
- */
-@DeveloperApi
-case class BroadcastHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
- override def requiredChildDistribution =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- @transient
- val broadcastFuture = future {
- sparkContext.broadcast(buildPlan.executeCollect())
- }
-
- def execute() = {
- val broadcastRelation = Await.result(broadcastFuture, 5.minute)
-
- streamedPlan.execute().mapPartitions { streamedIter =>
- joinIterators(broadcastRelation.value.iterator, streamedIter)
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
- * for hash join.
- */
-@DeveloperApi
-case class LeftSemiJoinBNL(
- streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- extends BinaryNode {
- // TODO: Override requiredChildDistribution.
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- def output = left.output
-
- /** The Streamed Relation */
- def left = streamed
- /** The Broadcast relation */
- def right = broadcast
-
- @transient lazy val boundCondition =
- InterpretedPredicate(
- condition
- .map(c => BindReferences.bindReference(c, left.output ++ right.output))
- .getOrElse(Literal(true)))
-
- def execute() = {
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
-
- streamed.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow
-
- streamedIter.filter(streamedRow => {
- var i = 0
- var matched = false
-
- while (i < broadcastedRelation.value.size && !matched) {
- val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matched = true
- }
- i += 1
- }
- matched
- })
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
- def output = left.output ++ right.output
-
- def execute() = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
-
- leftResults.cartesian(rightResults).mapPartitions { iter =>
- val joinedRow = new JoinedRow
- iter.map(r => joinedRow(r._1, r._2))
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class BroadcastNestedLoopJoin(
- left: SparkPlan,
- right: SparkPlan,
- buildSide: BuildSide,
- joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
- // TODO: Override requiredChildDistribution.
-
- /** BuildRight means the right relation <=> the broadcast relation. */
- val (streamed, broadcast) = buildSide match {
- case BuildRight => (left, right)
- case BuildLeft => (right, left)
- }
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- override def output = {
- joinType match {
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case _ =>
- left.output ++ right.output
- }
- }
-
- @transient lazy val boundCondition =
- InterpretedPredicate(
- condition
- .map(c => BindReferences.bindReference(c, left.output ++ right.output))
- .getOrElse(Literal(true)))
-
- def execute() = {
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
-
- /** All rows that either match both-way, or rows from streamed joined with nulls. */
- val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new CompactBuffer[Row]
- // TODO: Use Spark's BitSet.
- val includedBroadcastTuples =
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- val joinedRow = new JoinedRow
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
-
- streamedIter.foreach { streamedRow =>
- var i = 0
- var streamRowMatched = false
-
- while (i < broadcastedRelation.value.size) {
- // TODO: One bitset per partition instead of per row.
- val broadcastedRow = broadcastedRelation.value(i)
- buildSide match {
- case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
- streamRowMatched = true
- includedBroadcastTuples += i
- case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
- streamRowMatched = true
- includedBroadcastTuples += i
- case _ =>
- }
- i += 1
- }
-
- (streamRowMatched, joinType, buildSide) match {
- case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += joinedRow(streamedRow, rightNulls).copy()
- case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += joinedRow(leftNulls, streamedRow).copy()
- case _ =>
- }
- }
- Iterator((matchedRows, includedBroadcastTuples))
- }
-
- val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
- val allIncludedBroadcastTuples =
- if (includedBroadcastTuples.count == 0) {
- new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
- } else {
- includedBroadcastTuples.reduce(_ ++ _)
- }
-
- val leftNulls = new GenericMutableRow(left.output.size)
- val rightNulls = new GenericMutableRow(right.output.size)
- /** Rows from broadcasted joined with nulls. */
- val broadcastRowsWithNulls: Seq[Row] = {
- val buf: CompactBuffer[Row] = new CompactBuffer()
- var i = 0
- val rel = broadcastedRelation.value
- while (i < rel.length) {
- if (!allIncludedBroadcastTuples.contains(i)) {
- (joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
- case _ =>
- }
- }
- i += 1
- }
- buf.toSeq
- }
-
- // TODO: Breaks lineage.
- sparkContext.union(
- matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
new file mode 100644
index 0000000000000..8fd35880eedfe
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.concurrent._
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations. When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryNode with HashJoin {
+
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
+ override def requiredChildDistribution =
+ UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+ @transient
+ private val broadcastFuture = future {
+ val input: Array[Row] = buildPlan.executeCollect()
+ val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+ sparkContext.broadcast(hashed)
+ }
+
+ override def execute() = {
+ val broadcastRelation = Await.result(broadcastFuture, 5.minute)
+
+ streamedPlan.execute().mapPartitions { streamedIter =>
+ hashJoin(streamedIter, broadcastRelation.value)
+ }
+ }
+}
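
The rewritten BroadcastHashJoin starts collecting and hashing the build side on a Future at construction time and only awaits it (with a timeout) inside execute(), so the build work overlaps with everything that happens before the probe side runs. A self-contained sketch of that pattern using plain futures and an in-memory map in place of a Spark broadcast variable:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

// Start the build side early; block only when the probe side needs the table.
final class AsyncBuildSide[K, V](buildRows: () => Seq[(K, V)]) {
  // Kicked off eagerly at construction time, before anyone calls join().
  private val hashedFuture: Future[Map[K, Seq[V]]] =
    Future(buildRows().groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) })

  def join[S](streamed: Iterator[(K, S)]): Iterator[(S, V)] = {
    val hashed = Await.result(hashedFuture, 5.minutes)
    streamed.flatMap { case (k, s) => hashed.getOrElse(k, Nil).map(v => (s, v)) }
  }
}

// val j = new AsyncBuildSide(() => Seq(1 -> "a", 2 -> "b"))  // build starts here
// j.join(Iterator(1 -> "x", 3 -> "y")).toList                 // List(("x","a"))
```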
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
new file mode 100644
index 0000000000000..36aad13778bd2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class BroadcastNestedLoopJoin(
+ left: SparkPlan,
+ right: SparkPlan,
+ buildSide: BuildSide,
+ joinType: JoinType,
+ condition: Option[Expression]) extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ /** BuildRight means the right relation <=> the broadcast relation. */
+ private val (streamed, broadcast) = buildSide match {
+ case BuildRight => (left, right)
+ case BuildLeft => (right, left)
+ }
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def output = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case _ =>
+ left.output ++ right.output
+ }
+ }
+
+ @transient private lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+ override def execute() = {
+ val broadcastedRelation =
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ /** All rows that matched on both sides, together with streamed rows joined with nulls. */
+ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
+ val matchedRows = new CompactBuffer[Row]
+ // TODO: Use Spark's BitSet.
+ val includedBroadcastTuples =
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ val joinedRow = new JoinedRow
+ val leftNulls = new GenericMutableRow(left.output.size)
+ val rightNulls = new GenericMutableRow(right.output.size)
+
+ streamedIter.foreach { streamedRow =>
+ var i = 0
+ var streamRowMatched = false
+
+ while (i < broadcastedRelation.value.size) {
+ // TODO: One bitset per partition instead of per row.
+ val broadcastedRow = broadcastedRelation.value(i)
+ buildSide match {
+ case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
+ matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+ streamRowMatched = true
+ includedBroadcastTuples += i
+ case _ =>
+ }
+ i += 1
+ }
+
+ (streamRowMatched, joinType, buildSide) match {
+ case (false, LeftOuter | FullOuter, BuildRight) =>
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
+ case (false, RightOuter | FullOuter, BuildLeft) =>
+ matchedRows += joinedRow(leftNulls, streamedRow).copy()
+ case _ =>
+ }
+ }
+ Iterator((matchedRows, includedBroadcastTuples))
+ }
+
+ val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
+ val allIncludedBroadcastTuples =
+ if (includedBroadcastTuples.count == 0) {
+ new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+ } else {
+ includedBroadcastTuples.reduce(_ ++ _)
+ }
+
+ val leftNulls = new GenericMutableRow(left.output.size)
+ val rightNulls = new GenericMutableRow(right.output.size)
+ /** Rows from the broadcast side joined with nulls. */
+ val broadcastRowsWithNulls: Seq[Row] = {
+ val buf: CompactBuffer[Row] = new CompactBuffer()
+ var i = 0
+ val rel = broadcastedRelation.value
+ while (i < rel.length) {
+ if (!allIncludedBroadcastTuples.contains(i)) {
+ (joinType, buildSide) match {
+ case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
+ case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
+ case _ =>
+ }
+ }
+ i += 1
+ }
+ buf.toSeq
+ }
+
+ // TODO: Breaks lineage.
+ sparkContext.union(
+ matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
+ }
+}
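
For the outer-join cases, BroadcastNestedLoopJoin records per partition which broadcast rows matched at least once in a BitSet, ORs the per-partition sets together, and null-pads only the broadcast rows that were never marked. A small sketch of that bookkeeping with plain collections and the join condition as a function:

```scala
import scala.collection.mutable

// Returns the broadcast rows that matched no streamed row in any partition.
def unmatchedBroadcastRows[S, B](
    broadcast: IndexedSeq[B],
    streamedPartitions: Seq[Seq[S]],
    condition: (S, B) => Boolean): Seq[B] = {
  // One bit set per partition, marking the broadcast indices that matched there.
  val perPartition = streamedPartitions.map { part =>
    val included = new mutable.BitSet(broadcast.size)
    for (s <- part; i <- broadcast.indices if condition(s, broadcast(i)))
      included += i
    included
  }
  // OR the partial results, like reducing the per-partition bit sets on the driver.
  val allIncluded =
    if (perPartition.isEmpty) new mutable.BitSet(broadcast.size)
    else perPartition.reduce(_ | _)
  broadcast.indices.filterNot(allIncluded.contains).map(broadcast)
}
```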
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
new file mode 100644
index 0000000000000..76c14c02aab34
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
+ override def output = left.output ++ right.output
+
+ override def execute() = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.cartesian(rightResults).mapPartitions { iter =>
+ val joinedRow = new JoinedRow
+ iter.map(r => joinedRow(r._1, r._2))
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
new file mode 100644
index 0000000000000..4012d757d5f9a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.collection.CompactBuffer
+
+
+trait HashJoin {
+ self: SparkPlan =>
+
+ val leftKeys: Seq[Expression]
+ val rightKeys: Seq[Expression]
+ val buildSide: BuildSide
+ val left: SparkPlan
+ val right: SparkPlan
+
+ protected lazy val (buildPlan, streamedPlan) = buildSide match {
+ case BuildLeft => (left, right)
+ case BuildRight => (right, left)
+ }
+
+ protected lazy val (buildKeys, streamedKeys) = buildSide match {
+ case BuildLeft => (leftKeys, rightKeys)
+ case BuildRight => (rightKeys, leftKeys)
+ }
+
+ override def output = left.output ++ right.output
+
+ @transient protected lazy val buildSideKeyGenerator: Projection =
+ newProjection(buildKeys, buildPlan.output)
+
+ @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
+ newMutableProjection(streamedKeys, streamedPlan.output)
+
+ protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] =
+ {
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatches: CompactBuffer[Row] = _
+ private[this] var currentMatchPosition: Int = -1
+
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow2
+
+ private[this] val joinKeys = streamSideKeyGenerator()
+
+ override final def hasNext: Boolean =
+ (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+ (streamIter.hasNext && fetchNext())
+
+ override final def next() = {
+ val ret = buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+ }
+ currentMatchPosition += 1
+ ret
+ }
+
+ /**
+ * Searches the streamed iterator for the next row that has at least one match in the hash table.
+ *
+ * @return true if the search is successful, and false if the streamed iterator runs out of
+ * tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatches = null
+ currentMatchPosition = -1
+
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatches = hashedRelation.get(joinKeys.currentValue)
+ }
+ }
+
+ if (currentHashMatches == null) {
+ false
+ } else {
+ currentMatchPosition = 0
+ true
+ }
+ }
+ }
+ }
+}
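
hashJoin above streams the probe side through a hand-rolled iterator: it keeps a cursor into the current match list and pulls the next streamed row only when that list is exhausted, skipping rows whose key is null or has no matches. A standalone sketch of the same state machine over plain maps (generic keys instead of catalyst Rows):

```scala
// Streaming probe iterator: one joined pair per match of the current streamed row.
def probe[K, S, B](stream: Iterator[(K, S)],
                   hashed: Map[K, IndexedSeq[B]]): Iterator[(S, B)] =
  new Iterator[(S, B)] {
    private[this] var currentStreamed: S = _
    private[this] var currentMatches: IndexedSeq[B] = _
    private[this] var currentPos: Int = -1

    override def hasNext: Boolean =
      (currentPos != -1 && currentPos < currentMatches.size) ||
        (stream.hasNext && fetchNext())

    override def next(): (S, B) = {
      val out = (currentStreamed, currentMatches(currentPos))
      currentPos += 1
      out
    }

    // Advance the streamed side until a row whose key has at least one match is found.
    private def fetchNext(): Boolean = {
      currentMatches = null
      currentPos = -1
      while (currentMatches == null && stream.hasNext) {
        val (key, row) = stream.next()
        hashed.get(key).filter(_.nonEmpty).foreach { m =>
          currentStreamed = row
          currentMatches = m
        }
      }
      if (currentMatches == null) false else { currentPos = 0; true }
    }
  }
```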
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
new file mode 100644
index 0000000000000..b73041d306b36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{HashMap => JavaHashMap}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a hash-based outer join of two child relations by shuffling the data using
+ * the join keys. This operator requires loading the associated partitions of both sides into memory.
+ */
+@DeveloperApi
+case class HashOuterJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def output = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case x =>
+ throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+ }
+
+ @transient private[this] lazy val DUMMY_LIST = Seq[Row](null)
+ @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
+
+ // TODO: we should rewrite all of these iterators with our own implementations instead of the
+ // Scala iterators, for performance reasons.
+
+ private[this] def leftOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val rightNullRow = new GenericRow(right.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ leftIter.iterator.flatMap { l =>
+ joinedRow.withLeft(l)
+ var matched = false
+ (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
+ matched = true
+ joinedRow.copy
+ } else {
+ Nil
+ }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row, because
+ // we don't know whether it is needed until we have finished iterating over all of the
+ // records on the right side.
+ // If no right row satisfied the condition, append a single row with an empty right side.
+ joinedRow.withRight(rightNullRow).copy
+ })
+ }
+ }
+
+ private[this] def rightOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val leftNullRow = new GenericRow(left.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ rightIter.iterator.flatMap { r =>
+ joinedRow.withRight(r)
+ var matched = false
+ (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
+ matched = true
+ joinedRow.copy
+ } else {
+ Nil
+ }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row, because
+ // we don't know whether it is needed until we have finished iterating over all of the
+ // records on the left side.
+ // If no left row satisfied the condition, append a single row with an empty left side.
+ joinedRow.withLeft(leftNullRow).copy
+ })
+ }
+ }
+
+ private[this] def fullOuterIterator(
+ key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
+ val joinedRow = new JoinedRow()
+ val leftNullRow = new GenericRow(left.output.length)
+ val rightNullRow = new GenericRow(right.output.length)
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+
+ if (!key.anyNull) {
+ // Store the positions of the right-side records that satisfy
+ // the join condition.
+ val rightMatchedSet = scala.collection.mutable.Set[Int]()
+ leftIter.iterator.flatMap[Row] { l =>
+ joinedRow.withLeft(l)
+ var matched = false
+ rightIter.zipWithIndex.collect {
+ // 1. For those matched (satisfy the join condition) records with both sides filled,
+ // append them directly
+
+ case (r, idx) if (boundCondition(joinedRow.withRight(r))) => {
+ matched = true
+ // if the row satisfies the join condition, add its index to the matched set
+ rightMatchedSet.add(idx)
+ joinedRow.copy
+ }
+ } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
+ // 2. For unmatched records on the left, append additional records with an empty right side.
+
+ // DUMMY_LIST.filter(_ => !matched) is a tricky way to add an additional row, because
+ // we don't know whether it is needed until we have finished iterating over all
+ // of the records on the right side.
+ // If no right row satisfied the condition, append a single row with an empty right side.
+ joinedRow.withRight(rightNullRow).copy
+ })
+ } ++ rightIter.zipWithIndex.collect {
+ // 3. For unmatched records on the right, append additional records with an empty left side.
+
+ // Revisit the records on the right and append an additional row with an empty left side
+ // for each record whose index is not in the matched set.
+ case (r, idx) if (!rightMatchedSet.contains(idx)) => {
+ joinedRow(leftNullRow, r).copy
+ }
+ }
+ } else {
+ leftIter.iterator.map[Row] { l =>
+ joinedRow(l, rightNullRow).copy
+ } ++ rightIter.iterator.map[Row] { r =>
+ joinedRow(leftNullRow, r).copy
+ }
+ }
+ }
+
+ private[this] def buildHashTable(
+ iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
+ val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
+ while (iter.hasNext) {
+ val currentRow = iter.next()
+ val rowKey = keyGenerator(currentRow)
+
+ var existingMatchList = hashTable.get(rowKey)
+ if (existingMatchList == null) {
+ existingMatchList = new CompactBuffer[Row]()
+ hashTable.put(rowKey, existingMatchList)
+ }
+
+ existingMatchList += currentRow.copy()
+ }
+
+ hashTable
+ }
+
+ override def execute() = {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ // TODO this probably can be replaced by external sort (sort merged join?)
+ // Build HashMap for current partition in left relation
+ val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
+ // Build HashMap for current partition in right relation
+ val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+ val boundCondition =
+ condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
+ joinType match {
+ case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
+ leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case RightOuter => rightHashTable.keysIterator.flatMap { key =>
+ rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
+ fullOuterIterator(key,
+ leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightHashTable.getOrElse(key, EMPTY_LIST))
+ }
+ case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+ }
+ }
+ }
+}
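
The leftOuterIterator/rightOuterIterator above rely on the DUMMY_LIST.filter(_ => !matched) trick: a one-element list that is kept only when the matched flag stayed false, so exactly one null-padded row is appended after all candidate matches have been examined. A sketch of the same idea over plain sequences, with Option standing in for the null-padded side:

```scala
// Left outer join: emit all matches for a left row, and exactly one
// null-padded row if none of the right rows satisfied the condition.
def leftOuter[L, R](leftRows: Seq[L],
                    rightRows: Seq[R],
                    condition: (L, R) => Boolean): Seq[(L, Option[R])] =
  leftRows.flatMap { l =>
    var matched = false
    val joined = rightRows.collect {
      case r if condition(l, r) =>
        matched = true
        (l, Some(r))
    }
    // The single-element "dummy" list stands in for the null-padded row; it survives
    // the filter only when nothing on the right matched.
    joined ++ Seq((l, Option.empty[R])).filter(_ => !matched)
  }
```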
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
new file mode 100644
index 0000000000000..38b8993b03f82
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{HashMap => JavaHashMap}
+
+import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.util.collection.CompactBuffer
+
+
+/**
+ * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
+ * object.
+ */
+private[joins] sealed trait HashedRelation {
+ def get(key: Row): CompactBuffer[Row]
+}
+
+
+/**
+ * A general [[HashedRelation]] backed by a hash map that maps each key to a sequence of values.
+ */
+private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]])
+ extends HashedRelation with Serializable {
+
+ override def get(key: Row) = hashTable.get(key)
+}
+
+
+/**
+ * A specialized [[HashedRelation]] that maps a key to a single value. This implementation
+ * assumes the key is unique.
+ */
+private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row])
+ extends HashedRelation with Serializable {
+
+ override def get(key: Row) = {
+ val v = hashTable.get(key)
+ if (v eq null) null else CompactBuffer(v)
+ }
+
+ def getValue(key: Row): Row = hashTable.get(key)
+}
+
+
+// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
+
+
+private[joins] object HashedRelation {
+
+ def apply(
+ input: Iterator[Row],
+ keyGenerator: Projection,
+ sizeEstimate: Int = 64): HashedRelation = {
+
+ // TODO: Use Spark's HashMap implementation.
+ val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate)
+ var currentRow: Row = null
+
+ // Whether the join key is unique. If the key is unique, we can convert the underlying
+ // hash map into one specialized for this.
+ var keyIsUnique = true
+
+ // Create a mapping of buildKeys -> rows
+ while (input.hasNext) {
+ currentRow = input.next()
+ val rowKey = keyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[Row]()
+ hashTable.put(rowKey, newMatchList)
+ newMatchList
+ } else {
+ keyIsUnique = false
+ existingMatchList
+ }
+ matchList += currentRow.copy()
+ }
+ }
+
+ if (keyIsUnique) {
+ val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size)
+ val iter = hashTable.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ }
+ new UniqueKeyHashedRelation(uniqHashTable)
+ } else {
+ new GeneralHashedRelation(hashTable)
+ }
+ }
+}
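
HashedRelation.apply builds a general multi-map while tracking whether every key was inserted exactly once; if so, it re-packs the table into a key-to-single-row map (UniqueKeyHashedRelation) to save memory and one indirection per lookup. A simplified standalone sketch of that specialization, with generic keys and plain buffers instead of catalyst Rows:

```scala
import java.util.{HashMap => JavaHashMap}
import scala.collection.mutable.ArrayBuffer

sealed trait Hashed[K, V] { def get(key: K): Seq[V] }

// Multi-map form: a key may have several matching values.
final class GeneralHashed[K, V](table: JavaHashMap[K, ArrayBuffer[V]]) extends Hashed[K, V] {
  def get(key: K): Seq[V] = Option(table.get(key)).getOrElse(Nil)
}
// Specialized form: every key has exactly one value.
final class UniqueKeyHashed[K, V](table: JavaHashMap[K, V]) extends Hashed[K, V] {
  def get(key: K): Seq[V] = Option(table.get(key)).toSeq
}

def buildHashed[K, V](input: Iterator[(K, V)], sizeEstimate: Int = 64): Hashed[K, V] = {
  val table = new JavaHashMap[K, ArrayBuffer[V]](sizeEstimate)
  var keyIsUnique = true
  input.foreach { case (k, v) =>
    val existing = table.get(k)
    if (existing == null) table.put(k, ArrayBuffer(v))
    else { keyIsUnique = false; existing += v }
  }
  if (keyIsUnique) {
    // Re-pack into a single-value table when no key ever collided.
    val uniq = new JavaHashMap[K, V](table.size)
    val it = table.entrySet().iterator()
    while (it.hasNext) {
      val e = it.next()
      uniq.put(e.getKey, e.getValue.head)
    }
    new UniqueKeyHashed(uniq)
  } else {
    new GeneralHashed(table)
  }
}
```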
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
new file mode 100644
index 0000000000000..60003d1900d85
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join keys
+ * that can be used for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+ extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def output = left.output
+
+ /** The streamed relation */
+ override def left = streamed
+ /** The broadcast relation */
+ override def right = broadcast
+
+ @transient private lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+ override def execute() = {
+ val broadcastedRelation =
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ streamed.execute().mapPartitions { streamedIter =>
+ val joinedRow = new JoinedRow
+
+ streamedIter.filter(streamedRow => {
+ var i = 0
+ var matched = false
+
+ while (i < broadcastedRelation.value.size && !matched) {
+ val broadcastedRow = broadcastedRelation.value(i)
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ matched = true
+ }
+ i += 1
+ }
+ matched
+ })
+ }
+ }
+}
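
The left-semi semantics implemented above can be summarized with a small, hedged sketch in plain Scala (invented names, no Spark types): a streamed row is kept, at most once, as soon as any broadcast row satisfies the condition.

```scala
// Minimal sketch of the left-semi logic in LeftSemiJoinBNL: keep a streamed
// element if some element of the (broadcast) relation matches it.
object LeftSemiBnlSketch {
  def leftSemi[A, B](streamed: Seq[A], broadcast: IndexedSeq[B])
                    (condition: (A, B) => Boolean): Seq[A] =
    streamed.filter(s => broadcast.exists(b => condition(s, b)))

  def main(args: Array[String]): Unit = {
    val left  = Seq(1, 2, 3, 4)
    val right = IndexedSeq(2, 4, 4, 6)
    // Keeps 2 and 4 exactly once each, even though 4 matches twice.
    println(leftSemi(left, right)(_ == _))  // List(2, 4)
  }
}
```
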
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
new file mode 100644
index 0000000000000..ea7babf3be948
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
+import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Builds the right table's join keys into a HashSet, then iterates over the left table to check
+ * whether each row's join keys are in the hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode with HashJoin {
+
+ override val buildSide = BuildRight
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def output = left.output
+
+ override def execute() = {
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashSet = new java.util.HashSet[Row]()
+ var currentRow: Row = null
+
+ // Create a Hash set of buildKeys
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashSet.contains(rowKey)
+ if (!keyExists) {
+ hashSet.add(rowKey)
+ }
+ }
+ }
+
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.filter(current => {
+ !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+ })
+ }
+ }
+}
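
A hedged, Spark-free sketch of the same per-partition idea: build a hash set of the build-side keys (ignoring null keys), then keep streamed rows whose non-null key is present in the set. Names are invented for the example.

```scala
// Sketch of LeftSemiJoinHash's per-partition logic with ordinary collections;
// absent keys (modeled with Option) never match, mirroring the anyNull checks.
object LeftSemiHashSketch {
  def leftSemiHash[A, K](build: Iterator[A], stream: Iterator[A])(key: A => Option[K]): Iterator[A] = {
    val keys = new java.util.HashSet[K]()
    build.foreach(row => key(row).foreach(keys.add))   // skip rows with no usable key
    stream.filter(row => key(row).exists(keys.contains))
  }

  def main(args: Array[String]): Unit = {
    val build  = Iterator(Some(1), Some(2), None)
    val stream = Iterator(Some(2), None, Some(3), Some(1))
    println(leftSemiHash(build, stream)(identity).toList)  // List(Some(2), Some(1))
  }
}
```
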
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
new file mode 100644
index 0000000000000..418c1c23e5546
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations by first shuffling the data using the join
+ * keys.
+ */
+@DeveloperApi
+case class ShuffledHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryNode with HashJoin {
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def execute() = {
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+ hashJoin(streamIter, hashed)
+ }
+ }
+}
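
For intuition, here is a hedged sketch of the per-partition inner hash join that this operator performs once the data is co-partitioned, written with plain Scala collections and invented names rather than Spark's plans and rows.

```scala
// Sketch: hash the build side by key, then stream the other side and emit
// one output tuple per matching build-side value.
object ShuffledHashJoinSketch {
  def innerHashJoin[K, A, B](build: Iterator[(K, A)], stream: Iterator[(K, B)]): Iterator[(K, A, B)] = {
    val hashed: Map[K, Seq[A]] = build.toSeq.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2) }
    for {
      (k, b) <- stream
      a      <- hashed.getOrElse(k, Nil).iterator
    } yield (k, a, b)
  }

  def main(args: Array[String]): Unit = {
    val buildSide  = Iterator(1 -> "a", 2 -> "b", 2 -> "c")
    val streamSide = Iterator(2 -> "x", 3 -> "y", 1 -> "z")
    println(innerHashJoin(buildSide, streamSide).toList)
    // List((2,b,x), (2,c,x), (1,a,z))
  }
}
```
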
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
new file mode 100644
index 0000000000000..7f2ab1765b28f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Physical execution operators for join operations.
+ */
+package object joins {
+
+ @DeveloperApi
+ sealed abstract class BuildSide
+
+ @DeveloperApi
+ case object BuildRight extends BuildSide
+
+ @DeveloperApi
+ case object BuildLeft extends BuildSide
+
+}
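
A tiny, hedged sketch of how a `BuildSide`-style flag is typically consumed (the `Side` type and plan strings below are invented stand-ins, not Spark API): pattern match once to decide which child is hashed and which is streamed.

```scala
// Stand-in ADT mirroring BuildLeft/BuildRight, plus a helper that picks
// (buildPlan, streamedPlan) from the two children.
sealed trait Side
case object BuildLeftSide extends Side
case object BuildRightSide extends Side

object BuildSideSketch {
  def pickSides[P](left: P, right: P, side: Side): (P, P) = side match {
    case BuildLeftSide  => (left, right)   // (buildPlan, streamedPlan)
    case BuildRightSide => (right, left)
  }

  def main(args: Array[String]): Unit =
    println(pickSides("leftPlan", "rightPlan", BuildRightSide))  // (rightPlan,leftPlan)
}
```
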
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 0977da3e8577c..be729e5d244b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
}
}
+object EvaluatePython {
+ def apply(udf: PythonUDF, child: LogicalPlan) =
+ new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+}
+
/**
* :: DeveloperApi ::
* Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
*/
@DeveloperApi
-case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode {
- val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)()
+case class EvaluatePython(
+ udf: PythonUDF,
+ child: LogicalPlan,
+ resultAttribute: AttributeReference)
+ extends logical.UnaryNode {
def output = child.output :+ resultAttribute
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 0f27fd13e7379..61ee960aad9d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.json
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
+import java.sql.Timestamp
+import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD
@@ -34,16 +36,19 @@ private[sql] object JsonRDD extends Logging {
private[sql] def jsonStringToRow(
json: RDD[String],
- schema: StructType): RDD[Row] = {
- parseJson(json).map(parsed => asRow(parsed, schema))
+ schema: StructType,
+ columnNameOfCorruptRecords: String): RDD[Row] = {
+ parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema))
}
private[sql] def inferSchema(
json: RDD[String],
- samplingRatio: Double = 1.0): StructType = {
+ samplingRatio: Double = 1.0,
+ columnNameOfCorruptRecords: String): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
- val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _)
+ val allKeys =
+ parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _)
createSchema(allKeys)
}
@@ -273,7 +278,9 @@ private[sql] object JsonRDD extends Logging {
case atom => atom
}
- private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = {
+ private def parseJson(
+ json: RDD[String],
+ columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = {
// According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72],
// ObjectMapper will not return BigDecimal when
// "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled
@@ -288,12 +295,16 @@ private[sql] object JsonRDD extends Logging {
// For example: for {"key": 1, "key":2}, we will get "key"->2.
val mapper = new ObjectMapper()
iter.flatMap { record =>
- val parsed = mapper.readValue(record, classOf[Object]) match {
- case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
- case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
- }
+ try {
+ val parsed = mapper.readValue(record, classOf[Object]) match {
+ case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
+ case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
+ }
- parsed
+ parsed
+ } catch {
+ case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil
+ }
}
})
}
@@ -361,6 +372,14 @@ private[sql] object JsonRDD extends Logging {
}
}
+ private def toTimestamp(value: Any): Timestamp = {
+ value match {
+ case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
+ case value: java.lang.Long => new Timestamp(value)
+ case value: java.lang.String => Timestamp.valueOf(value)
+ }
+ }
+
private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
if (value == null) {
null
@@ -377,6 +396,7 @@ private[sql] object JsonRDD extends Logging {
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
+ case TimestampType => toTimestamp(value)
}
}
}
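
The corrupt-record handling added above boils down to a try/catch around Jackson's parser. A hedged, standalone sketch (requires only jackson-databind; the object and column names are invented, and conversion of the parsed value is elided):

```scala
// Parse each line with Jackson; on a parse error, store the raw line under
// the configured "corrupt record" column instead of failing the job.
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper

object CorruptRecordSketch {
  def parseLine(mapper: ObjectMapper, line: String, corruptColumn: String): Map[String, Any] =
    try {
      mapper.readValue(line, classOf[Object])   // throws on malformed JSON
      Map("parsed" -> line)                     // conversion of the parsed value elided
    } catch {
      case _: JsonProcessingException => Map(corruptColumn -> line)
    }

  def main(args: Array[String]): Unit = {
    val mapper = new ObjectMapper()
    println(parseLine(mapper, """{"a": 1}""", "_unparsed"))     // Map(parsed -> {"a": 1})
    println(parseLine(mapper, """{"a":{, b:3}""", "_unparsed")) // Map(_unparsed -> {"a":{, b:3})
  }
}
```
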
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index f513eae9c2d13..e98d151286818 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -165,6 +165,16 @@ package object sql {
@DeveloperApi
val TimestampType = catalyst.types.TimestampType
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `java.sql.Date` values.
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ val DateType = catalyst.types.DateType
+
/**
* :: DeveloperApi ::
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index ffb732347d30a..5c6fa78ae3895 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -289,9 +289,9 @@ case class InsertIntoParquetTable(
def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = new AppendingParquetOutputFormat(taskIdOffset)
@@ -331,13 +331,21 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
// override to choose output filename so not overwrite existing ones
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
- val taskId: TaskID = context.getTaskAttemptID.getTaskID
+ val taskId: TaskID = getTaskAttemptID(context).getTaskID
val partition: Int = taskId.getId
val filename = s"part-r-${partition + offset}.parquet"
val committer: FileOutputCommitter =
getOutputCommitter(context).asInstanceOf[FileOutputCommitter]
new Path(committer.getWorkPath, filename)
}
+
+ // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2.
+ // The signature of TaskAttemptContext.getTaskAttemptID is the same in both versions, so the
+ // call is source-compatible but NOT binary-compatible: the opcode for a method call on a class
+ // is INVOKEVIRTUAL, while for an interface it is INVOKEINTERFACE.
+ private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = {
+ context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID]
+ }
}
/**
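
The reflective workaround above generalizes: when a callee is a class in one dependency version and an interface in another, resolving the method at runtime avoids baking INVOKEVIRTUAL or INVOKEINTERFACE into the compiled bytecode. A hedged sketch using a plain `String` method instead of Hadoop's `TaskAttemptContext`:

```scala
// Invoke a no-argument method by name via reflection, so the call site does
// not depend on whether the declaring type is a class or an interface.
object ReflectiveCallSketch {
  def callNoArg[T](target: AnyRef, methodName: String): T =
    target.getClass.getMethod(methodName).invoke(target).asInstanceOf[T]

  def main(args: Array[String]): Unit = {
    // Equivalent to "spark".length, but resolved at runtime rather than bound
    // to a particular class/interface at compile time.
    val len = callNoArg[Integer]("spark", "length")
    println(len)  // 5
  }
}
```
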
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 2941b9793597f..e6389cf77a4c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet
import java.io.IOException
+import scala.util.Try
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
@@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
}
def convertFromString(string: String): Seq[Attribute] = {
- DataType(string) match {
+ Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match {
case s: StructType => s.toAttributes
case other => sys.error(s"Can convert $string to row")
}
}
def convertToString(schema: Seq[Attribute]): String = {
- StructType.fromAttributes(schema).toString
+ StructType.fromAttributes(schema).json
}
def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = {
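
The `Try(...).getOrElse(...)` fallback in `convertFromString` is a general pattern for reading metadata written by older versions: try the new representation first, fall back to the legacy one on failure. A hedged sketch with hypothetical parsers standing in for `DataType.fromJson` and `DataType.fromCaseClassString`:

```scala
// Prefer the new (JSON) format, fall back to the legacy format if parsing fails.
import scala.util.Try

object SchemaStringFallbackSketch {
  // Hypothetical parsers; the real code returns a DataType, not an Int.
  def parseJsonSchema(s: String): Int =
    if (s.startsWith("{")) s.length else sys.error("not JSON")
  def parseLegacySchema(s: String): Int = s.length

  def parse(s: String): Int =
    Try(parseJsonSchema(s)).getOrElse(parseLegacySchema(s))

  def main(args: Array[String]): Unit = {
    println(parse("""{"type":"struct"}"""))  // parsed as JSON
    println(parse("StructType(List())"))     // falls back to the legacy parser
  }
}
```
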
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 77353f4eb0227..e44cb08309523 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -41,6 +41,7 @@ protected[sql] object DataTypeConversions {
case StringType => JDataType.StringType
case BinaryType => JDataType.BinaryType
case BooleanType => JDataType.BooleanType
+ case DateType => JDataType.DateType
case TimestampType => JDataType.TimestampType
case DecimalType => JDataType.DecimalType
case DoubleType => JDataType.DoubleType
@@ -80,6 +81,8 @@ protected[sql] object DataTypeConversions {
BinaryType
case booleanType: org.apache.spark.sql.api.java.BooleanType =>
BooleanType
+ case dateType: org.apache.spark.sql.api.java.DateType =>
+ DateType
case timestampType: org.apache.spark.sql.api.java.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.DecimalType =>
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index c7e0ff1cf6494..fbed0a782dd3e 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -30,7 +30,7 @@ log4j.appender.FA=org.apache.log4j.FileAppender
log4j.appender.FA.append=false
log4j.appender.FA.file=target/unit-tests.log
log4j.appender.FA.layout=org.apache.log4j.PatternLayout
-log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Set the logger level of File Appender to WARN
log4j.appender.FA.Threshold = INFO
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 591592841e9fe..444bc95009c31 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,118 +18,188 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
+ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+ s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+
+ def rddIdOf(tableName: String): Int = {
+ val executedPlan = table(tableName).queryExecution.executedPlan
+ executedPlan.collect {
+ case InMemoryColumnarTableScan(_, _, relation) =>
+ relation.cachedColumnBuffers.id
+ case _ =>
+ fail(s"Table $tableName is not cached\n" + executedPlan)
+ }.head
+ }
+
+ def isMaterialized(rddId: Int): Boolean = {
+ sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty
+ }
+
test("too big for memory") {
val data = "*" * 10000
- sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData")
- cacheTable("bigData")
- assert(table("bigData").count() === 1000000L)
- uncacheTable("bigData")
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData")
+ table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+ assert(table("bigData").count() === 200000L)
+ table("bigData").unpersist()
+ }
+
+ test("calling .cache() should use in-memory columnar caching") {
+ table("testData").cache()
+ assertCached(table("testData"))
+ }
+
+ test("calling .unpersist() should drop in-memory columnar cache") {
+ table("testData").cache()
+ table("testData").count()
+ table("testData").unpersist(blocking = true)
+ assertCached(table("testData"), 0)
+ }
+
+ test("isCached") {
+ cacheTable("testData")
+
+ assertCached(table("testData"))
+ assert(table("testData").queryExecution.withCachedData match {
+ case _: InMemoryRelation => true
+ case _ => false
+ })
+
+ uncacheTable("testData")
+ assert(!isCached("testData"))
+ assert(table("testData").queryExecution.withCachedData match {
+ case _: InMemoryRelation => false
+ case _ => true
+ })
}
test("SPARK-1669: cacheTable should be idempotent") {
assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation])
cacheTable("testData")
- table("testData").queryExecution.analyzed match {
- case _: InMemoryRelation =>
- case _ =>
- fail("testData should be cached")
+ assertCached(table("testData"))
+
+ assertResult(1, "InMemoryRelation not found, testData should have been cached") {
+ table("testData").queryExecution.withCachedData.collect {
+ case r: InMemoryRelation => r
+ }.size
}
cacheTable("testData")
- table("testData").queryExecution.analyzed match {
- case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) =>
- fail("cacheTable is not idempotent")
-
- case _ =>
+ assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") {
+ table("testData").queryExecution.withCachedData.collect {
+ case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => r
+ }.size
}
}
test("read from cached table and uncache") {
- TestSQLContext.cacheTable("testData")
-
- checkAnswer(
- TestSQLContext.table("testData"),
- testData.collect().toSeq
- )
-
- TestSQLContext.table("testData").queryExecution.analyzed match {
- case _ : InMemoryRelation => // Found evidence of caching
- case noCache => fail(s"No cache node found in plan $noCache")
- }
-
- TestSQLContext.uncacheTable("testData")
+ cacheTable("testData")
+ checkAnswer(table("testData"), testData.collect().toSeq)
+ assertCached(table("testData"))
- checkAnswer(
- TestSQLContext.table("testData"),
- testData.collect().toSeq
- )
-
- TestSQLContext.table("testData").queryExecution.analyzed match {
- case cachePlan: InMemoryRelation =>
- fail(s"Table still cached after uncache: $cachePlan")
- case noCache => // Table uncached successfully
- }
+ uncacheTable("testData")
+ checkAnswer(table("testData"), testData.collect().toSeq)
+ assertCached(table("testData"), 0)
}
test("correct error on uncache of non-cached table") {
intercept[IllegalArgumentException] {
- TestSQLContext.uncacheTable("testData")
+ uncacheTable("testData")
}
}
- test("SELECT Star Cached Table") {
- TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar")
- TestSQLContext.cacheTable("selectStar")
- TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect()
- TestSQLContext.uncacheTable("selectStar")
+ test("SELECT star from cached table") {
+ sql("SELECT * FROM testData").registerTempTable("selectStar")
+ cacheTable("selectStar")
+ checkAnswer(
+ sql("SELECT * FROM selectStar WHERE key = 1"),
+ Seq(Row(1, "1")))
+ uncacheTable("selectStar")
}
test("Self-join cached") {
val unCachedAnswer =
- TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
- TestSQLContext.cacheTable("testData")
+ sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect()
+ cacheTable("testData")
checkAnswer(
- TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
+ sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"),
unCachedAnswer.toSeq)
- TestSQLContext.uncacheTable("testData")
+ uncacheTable("testData")
}
test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
- TestSQLContext.sql("CACHE TABLE testData")
- TestSQLContext.table("testData").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => // Found evidence of caching
- case _ => fail(s"Table 'testData' should be cached")
- }
- assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached")
+ sql("CACHE TABLE testData")
+ assertCached(table("testData"))
- TestSQLContext.sql("UNCACHE TABLE testData")
- TestSQLContext.table("testData").queryExecution.executedPlan match {
- case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached")
- case _ => // Found evidence of uncaching
- }
- assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
- }
-
- test("CACHE TABLE tableName AS SELECT Star Table") {
- TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
- assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
- TestSQLContext.uncacheTable("testCacheTable")
- }
-
- test("'CACHE TABLE tableName AS SELECT ..'") {
- TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
- assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
- TestSQLContext.uncacheTable("testCacheTable")
+ val rddId = rddIdOf("testData")
+ assert(
+ isMaterialized(rddId),
+ "Eagerly cached in-memory table should have already been materialized")
+
+ sql("UNCACHE TABLE testData")
+ assert(!isCached("testData"), "Table 'testData' should not be cached")
+ assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+ }
+
+ test("CACHE TABLE tableName AS SELECT * FROM anotherTable") {
+ sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
+ assertCached(table("testCacheTable"))
+
+ val rddId = rddIdOf("testCacheTable")
+ assert(
+ isMaterialized(rddId),
+ "Eagerly cached in-memory table should have already been materialized")
+
+ uncacheTable("testCacheTable")
+ assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+ }
+
+ test("CACHE TABLE tableName AS SELECT ...") {
+ sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10")
+ assertCached(table("testCacheTable"))
+
+ val rddId = rddIdOf("testCacheTable")
+ assert(
+ isMaterialized(rddId),
+ "Eagerly cached in-memory table should have already been materialized")
+
+ uncacheTable("testCacheTable")
+ assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
+ }
+
+ test("CACHE LAZY TABLE tableName") {
+ sql("CACHE LAZY TABLE testData")
+ assertCached(table("testData"))
+
+ val rddId = rddIdOf("testData")
+ assert(
+ !isMaterialized(rddId),
+ "Lazily cached in-memory table shouldn't be materialized eagerly")
+
+ sql("SELECT COUNT(*) FROM testData").collect()
+ assert(
+ isMaterialized(rddId),
+ "Lazily cached in-memory table should have been materialized")
+
+ uncacheTable("testData")
+ assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 8fb59c5830f6d..100ecb45e9e88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.types.DataType
+
class DataTypeSuite extends FunSuite {
test("construct an ArrayType") {
@@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite {
struct(Set("b", "d", "e", "f"))
}
}
+
+ def checkDataTypeJsonRepr(dataType: DataType): Unit = {
+ test(s"JSON - $dataType") {
+ assert(DataType.fromJson(dataType.json) === dataType)
+ }
+ }
+
+ checkDataTypeJsonRepr(BooleanType)
+ checkDataTypeJsonRepr(ByteType)
+ checkDataTypeJsonRepr(ShortType)
+ checkDataTypeJsonRepr(IntegerType)
+ checkDataTypeJsonRepr(LongType)
+ checkDataTypeJsonRepr(FloatType)
+ checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(TimestampType)
+ checkDataTypeJsonRepr(StringType)
+ checkDataTypeJsonRepr(BinaryType)
+ checkDataTypeJsonRepr(ArrayType(DoubleType, true))
+ checkDataTypeJsonRepr(ArrayType(StringType, false))
+ checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
+ checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+ checkDataTypeJsonRepr(
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false))))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index d001abb7e1fcc..45e58afe9d9a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest {
(1, 1, 1, 2) :: Nil)
}
+ test("SPARK-3858 generator qualifiers are discarded") {
+ checkAnswer(
+ arrayData.as('ad)
+ .generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
+ .select("ex.data".attr),
+ Seq(1, 2, 3, 2, 3, 4).map(Seq(_)))
+ }
+
test("average") {
checkAnswer(
testData2.groupBy()(avg('a)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 6c7697ece8c56..07f4d2946c1b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 08376eb5e5c4e..15f6ba4f72bbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
import java.util.TimeZone
@@ -42,7 +43,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TimeZone.setDefault(origZone)
}
-
test("SPARK-3176 Added Parser of SQL ABS()") {
checkAnswer(
sql("SELECT ABS(-1.3)"),
@@ -61,7 +61,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
4)
}
-
test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
@@ -190,6 +189,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+ checkAnswer(
+ sql("SELECT b FROM binaryData ORDER BY a ASC"),
+ (1 to 5).map(Row(_)).toSeq)
+
+ checkAnswer(
+ sql("SELECT b FROM binaryData ORDER BY a DESC"),
+ (1 to 5).map(Row(_)).toSeq.reverse)
+
checkAnswer(
sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
arrayData.collect().sortBy(_.data(0)).toSeq)
@@ -672,4 +679,45 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
+
+ test("SPARK-3371 Renaming a function expression with group by gives error") {
+ registerFunction("len", (s: String) => s.length)
+ checkAnswer(
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ }
+
+ test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
+ checkAnswer(
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ }
+
+ test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
+ checkAnswer(
+ sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ }
+
+ test("throw errors for non-aggregate attributes with aggregation") {
+ def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
+ val logicalPlan = sql(query).queryExecution.logical
+
+ if (isInvalidQuery) {
+ val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
+ assert(
+ e.getMessage.startsWith("Expression not in GROUP BY"),
+ "Non-aggregate attribute(s) not detected\n" + logicalPlan)
+ } else {
+ // Should not throw
+ sql(query).queryExecution.analyzed
+ }
+ }
+
+ checkAggregation("SELECT key, COUNT(*) FROM testData")
+ checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
+
+ checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
+ checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
+
+ checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
+ checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index e24c521d24c7a..bfa9ea416266d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.scalatest.FunSuite
@@ -34,6 +34,7 @@ case class ReflectData(
byteField: Byte,
booleanField: Boolean,
decimalField: BigDecimal,
+ date: Date,
timestampField: Timestamp,
seqInt: Seq[Int])
@@ -76,7 +77,7 @@ case class ComplexReflectData(
class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
- BigDecimal(1), new Timestamp(12345), Seq(1,2,3))
+ BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectData")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index eb33a61c6e811..10b7979df7375 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -54,6 +54,16 @@ object TestData {
TestData2(3, 2) :: Nil)
testData2.registerTempTable("testData2")
+ case class BinaryData(a: Array[Byte], b: Int)
+ val binaryData: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ BinaryData("12".getBytes(), 1) ::
+ BinaryData("22".getBytes(), 5) ::
+ BinaryData("122".getBytes(), 3) ::
+ BinaryData("121".getBytes(), 2) ::
+ BinaryData("123".getBytes(), 4) :: Nil)
+ binaryData.registerTempTable("binaryData")
+
// TODO: There is no way to express null primitives as case classes currently...
val testData3 =
logical.LocalRelation('a.int, 'b.int).loadData(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 0cdbb3167ce36..6bdf741134e2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -30,6 +30,7 @@ class ColumnStatsSuite extends FunSuite {
testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
+ testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0))
testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
def testColumnStats[T <: NativeType, U <: ColumnStats](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 4fb1ecf1d532b..3f3f35d50188b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.scalatest.FunSuite
@@ -33,8 +33,8 @@ class ColumnTypeSuite extends FunSuite with Logging {
test("defaultSize") {
val checks = Map(
- INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
- BOOLEAN -> 1, STRING -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
+ INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1,
+ STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -64,6 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(BOOLEAN, true, 1)
checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+ checkActualSize(DATE, new Date(0L), 8)
checkActualSize(TIMESTAMP, new Timestamp(0L), 12)
val binary = Array.fill[Byte](4)(0: Byte)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index 38b04dd959f70..a1f21219eaf2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar
import scala.collection.immutable.HashSet
import scala.util.Random
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
@@ -50,6 +50,7 @@ object ColumnarTestUtils {
case STRING => Random.nextString(Random.nextInt(32))
case BOOLEAN => Random.nextBoolean()
case BINARY => randomBytes(Random.nextInt(32))
+ case DATE => new Date(Random.nextLong())
case TIMESTAMP =>
val timestamp = new Timestamp(Random.nextLong())
timestamp.setNanos(Random.nextInt(999999999))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index c1278248ef655..9775dd26b7773 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{QueryTest, TestData}
+import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
class InMemoryColumnarQuerySuite extends QueryTest {
import org.apache.spark.sql.TestData._
@@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("simple columnar query") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().toSeq)
}
@@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("projection") {
val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().map {
case Row(key: Int, value: String) => value -> key
@@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
- val scan = InMemoryRelation(useCompression = true, 5, plan)
+ val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan)
checkAnswer(scan, testData.collect().toSeq)
checkAnswer(scan, testData.collect().toSeq)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 6c9a9ab6c3418..21906e3fdcc6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -41,7 +41,9 @@ object TestNullableColumnAccessor {
class NullableColumnAccessorSuite extends FunSuite {
import ColumnarTestUtils._
- Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach {
+ Seq(
+ INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP
+ ).foreach {
testNullableColumnAccessor(_)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index f54a21eb4fbb1..cb73f3da81e24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -37,7 +37,9 @@ object TestNullableColumnBuilder {
class NullableColumnBuilderSuite extends FunSuite {
import ColumnarTestUtils._
- Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach {
+ Seq(
+ INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP
+ ).foreach {
testNullableColumnBuilder(_)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 69e0adbd3ee0d..f53acc8c9f718 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
+ checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2)
// With unsupported predicate
checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
- checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
+ checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10)
def checkBatchPruning(
filter: String,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index bfbf431a11913..f14ffca0e4d35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
+import org.apache.spark.sql.{SQLConf, execution}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.{SQLConf, execution}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
new file mode 100644
index 0000000000000..87c28c334d228
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.debug
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext._
+
+class DebuggingSuite extends FunSuite {
+ test("SchemaRDD.debug()") {
+ testData.debug()
+ }
+
+ test("SchemaRDD.typeCheck()") {
+ testData.typeCheck()
+ }
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
new file mode 100644
index 0000000000000..2aad01ded1acf
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.util.collection.CompactBuffer
+
+
+class HashedRelationSuite extends FunSuite {
+
+ // Key is simply the record itself
+ private val keyProjection = new Projection {
+ override def apply(row: Row): Row = row
+ }
+
+ test("GeneralHashedRelation") {
+ val data = Array(Row(0), Row(1), Row(2), Row(2))
+ val hashed = HashedRelation(data.iterator, keyProjection)
+ assert(hashed.isInstanceOf[GeneralHashedRelation])
+
+ assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
+ assert(hashed.get(Row(10)) === null)
+
+ val data2 = CompactBuffer[Row](data(2))
+ data2 += data(2)
+ assert(hashed.get(data(2)) == data2)
+ }
+
+ test("UniqueKeyHashedRelation") {
+ val data = Array(Row(0), Row(1), Row(2))
+ val hashed = HashedRelation(data.iterator, keyProjection)
+ assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
+
+ assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
+ assert(hashed.get(data(2)) == CompactBuffer[Row](data(2)))
+ assert(hashed.get(Row(10)) === null)
+
+ val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
+ assert(uniqHashed.getValue(data(0)) == data(0))
+ assert(uniqHashed.getValue(data(1)) == data(1))
+ assert(uniqHashed.getValue(data(2)) == data(2))
+ assert(uniqHashed.getValue(Row(10)) == null)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 685e788207725..7bb08f1b513ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,8 +21,12 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import java.sql.Timestamp
+
class JsonSuite extends QueryTest {
import TestJsonData._
TestJsonData
@@ -50,6 +54,12 @@ class JsonSuite extends QueryTest {
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
+
+ checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
+ checkTypePromotion(new Timestamp(intNumber.toLong),
+ enforceCorrectType(intNumber.toLong, TimestampType))
+ val strDate = "2014-09-30 12:34:56"
+ checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType))
}
test("Get compatible type") {
@@ -636,7 +646,65 @@ class JsonSuite extends QueryTest {
("str_a_1", null, null) ::
("str_a_2", null, null) ::
(null, "str_b_3", null) ::
- ("str_a_4", "str_b_4", "str_c_4") ::Nil
+ ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ )
+ }
+
+ test("Corrupt records") {
+ // Test if we can query corrupt records.
+ val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord
+ TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+
+ val jsonSchemaRDD = jsonRDD(corruptRecords)
+ jsonSchemaRDD.registerTempTable("jsonTable")
+
+ val schema = StructType(
+ StructField("_unparsed", StringType, true) ::
+ StructField("a", StringType, true) ::
+ StructField("b", StringType, true) ::
+ StructField("c", StringType, true) :: Nil)
+
+ assert(schema === jsonSchemaRDD.schema)
+
+ // In HiveContext, backticks should be used to access columns starting with an underscore.
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c, _unparsed
+ |FROM jsonTable
+ """.stripMargin),
+ (null, null, null, "{") ::
+ (null, null, null, "") ::
+ (null, null, null, """{"a":1, b:2}""") ::
+ (null, null, null, """{"a":{, b:3}""") ::
+ ("str_a_4", "str_b_4", "str_c_4", null) ::
+ (null, null, null, "]") :: Nil
)
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c
+ |FROM jsonTable
+ |WHERE _unparsed IS NULL
+ """.stripMargin),
+ ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ )
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT _unparsed
+ |FROM jsonTable
+ |WHERE _unparsed IS NOT NULL
+ """.stripMargin),
+ Seq("{") ::
+ Seq("") ::
+ Seq("""{"a":1, b:2}""") ::
+ Seq("""{"a":{, b:3}""") ::
+ Seq("]") :: Nil
+ )
+
+ TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index fc833b8b54e4c..eaca9f0508a12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -143,4 +143,13 @@ object TestJsonData {
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
+
+ val corruptRecords =
+ TestSQLContext.sparkContext.parallelize(
+ """{""" ::
+ """""" ::
+ """{"a":1, b:2}""" ::
+ """{"a":{, b:3}""" ::
+ """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
+ """]""" :: Nil)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 07adf731405af..25e41ecf28e2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result3(0)(1) === "the answer")
Utils.deleteRecursively(tmpdir)
}
-
+
test("Querying on empty parquet throws exception (SPARK-3536)") {
val tmpdir = Utils.createTempDir()
Utils.deleteRecursively(tmpdir)
@@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(result1.size === 0)
Utils.deleteRecursively(tmpdir)
}
+
+ test("DataType string parser compatibility") {
+ val schema = StructType(List(
+ StructField("c1", IntegerType, false),
+ StructField("c2", BinaryType, false)))
+
+ val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString)
+ val fromJson = ParquetTypesConverter.convertFromString(schema.json)
+
+ (fromCaseClassString, fromJson).zipped.foreach { (a, b) =>
+ assert(a.name == b.name)
+ assert(a.dataType === b.dataType)
+ }
+ }
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index bd3f68d92d8c7..accf61576b804 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -113,7 +113,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
case ByteType =>
to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
case ShortType =>
- to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+ to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal)))
case TimestampType =>
to.addColumnValue(
ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
@@ -145,7 +145,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
case ByteType =>
to.addColumnValue(ColumnValue.byteValue(null))
case ShortType =>
- to.addColumnValue(ColumnValue.intValue(null))
+ to.addColumnValue(ColumnValue.shortValue(null))
case TimestampType =>
to.addColumnValue(ColumnValue.timestampValue(null))
case BinaryType | _: ArrayType | _: StructType | _: MapType =>
@@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
result = hiveContext.sql(statement)
logDebug(result.queryExecution.toString())
result.queryExecution.logical match {
- case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) =>
+ case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) =>
sessionToActivePool(parentSession) = value
logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
case _ =>
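
The `SetCommand` match above now destructures a single `Option[(String, Option[String])]` instead of two separate options. A minimal standalone sketch of that matching shape, with a stand-in case class and an assumed pool-key constant (neither is Spark's real definition):

    object SetCommandMatchSketch {
      // Stand-in for Spark's SetCommand, for illustration only.
      case class SetCommand(kv: Option[(String, Option[String])])

      // Assumed constant standing in for SQLConf.THRIFTSERVER_POOL.
      val ThriftServerPool = "spark.sql.thriftserver.scheduler.pool"

      def poolFor(cmd: SetCommand): Option[String] = cmd match {
        // Matches "SET <pool key>=<value>" and binds the value, as in the diff above.
        case SetCommand(Some((ThriftServerPool, Some(value)))) => Some(value)
        case _ => None
      }

      def main(args: Array[String]): Unit = {
        println(poolFor(SetCommand(Some(ThriftServerPool -> Some("fair")))))  // Some(fair)
        println(poolFor(SetCommand(Some("other.key" -> Some("x")))))          // None
      }
    }
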
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 3475c2c9db080..8a72e9d2aef57 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.sql.catalyst.util.getTempFilePath
class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
@@ -62,9 +62,14 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
def captureOutput(source: String)(line: String) {
buffer += s"$source> $line"
- if (line.contains(expectedAnswers(next.get()))) {
- if (next.incrementAndGet() == expectedAnswers.size) {
- foundAllExpectedAnswers.trySuccess(())
+ // If we haven't found all expected answers...
+ if (next.get() < expectedAnswers.size) {
+ // If another expected answer is found...
+ if (line.startsWith(expectedAnswers(next.get()))) {
+ // If all expected answers have been found...
+ if (next.incrementAndGet() == expectedAnswers.size) {
+ foundAllExpectedAnswers.trySuccess(())
+ }
}
}
}
@@ -73,11 +78,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
val process = (Process(command) #< queryStream).run(
ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
- Future {
- val exitValue = process.exitValue()
- logInfo(s"Spark SQL CLI process exit value: $exitValue")
- }
-
try {
Await.result(foundAllExpectedAnswers.future, timeout)
} catch { case cause: Throwable =>
@@ -96,6 +96,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
|End CliSuite failure output
|===========================
""".stripMargin, cause)
+ throw cause
} finally {
warehousePath.delete()
metastorePath.delete()
@@ -107,7 +108,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
- runCliWithin(1.minute)(
+ runCliWithin(3.minute)(
"CREATE TABLE hive_test(key INT, val STRING);"
-> "OK",
"SHOW TABLES;"
@@ -118,7 +119,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
-> "Time taken: ",
"SELECT COUNT(*) FROM hive_test;"
-> "5",
- "DROP TABLE hive_test"
+ "DROP TABLE hive_test;"
-> "Time taken: "
)
}
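
The reworked `captureOutput` above only advances when the next expected answer appears as a line prefix, and stops counting once every answer has been seen. The same bookkeeping, stripped of the process wiring, looks roughly like this (a sketch, not the suite's actual helper):

    import java.util.concurrent.atomic.AtomicInteger
    import scala.concurrent.Promise

    class InOrderAnswerMatcher(expected: Seq[String]) {
      private val next = new AtomicInteger(0)
      val allFound = Promise[Unit]()

      def offer(line: String): Unit = {
        // Only look for the next expected answer while some are still outstanding.
        if (next.get() < expected.size && line.startsWith(expected(next.get()))) {
          if (next.incrementAndGet() == expected.size) {
            allFound.trySuccess(())
          }
        }
      }
    }
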
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 38977ff162097..e3b4e45a3d68c 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -17,17 +17,17 @@
package org.apache.spark.sql.hive.thriftserver
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Future, Promise}
-import scala.sys.process.{Process, ProcessLogger}
-
import java.io.File
import java.net.ServerSocket
import java.sql.{DriverManager, Statement}
import java.util.concurrent.TimeoutException
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+import scala.util.Try
+
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.jdbc.HiveDriver
import org.scalatest.FunSuite
@@ -41,25 +41,25 @@ import org.apache.spark.sql.catalyst.util.getTempFilePath
class HiveThriftServer2Suite extends FunSuite with Logging {
Class.forName(classOf[HiveDriver].getCanonicalName)
- private val listeningHost = "localhost"
- private val listeningPort = {
- // Let the system to choose a random available port to avoid collision with other parallel
- // builds.
- val socket = new ServerSocket(0)
- val port = socket.getLocalPort
- socket.close()
- port
- }
-
- private val warehousePath = getTempFilePath("warehouse")
- private val metastorePath = getTempFilePath("metastore")
- private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-
- def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) {
- val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+ def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) {
+ val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+ val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+
+ val warehousePath = getTempFilePath("warehouse")
+ val metastorePath = getTempFilePath("metastore")
+ val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+ val listeningHost = "localhost"
+ val listeningPort = {
+ // Let the system choose a random available port to avoid collisions with other parallel
+ // builds.
+ val socket = new ServerSocket(0)
+ val port = socket.getLocalPort
+ socket.close()
+ port
+ }
val command =
- s"""$serverScript
+ s"""$startScript
| --master local
| --hiveconf hive.root.logger=INFO,console
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
@@ -68,29 +68,40 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
| --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
""".stripMargin.split("\\s+").toSeq
- val serverStarted = Promise[Unit]()
+ val serverRunning = Promise[Unit]()
val buffer = new ArrayBuffer[String]()
+ val LOGGING_MARK =
+ s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to "
+ var logTailingProcess: Process = null
+ var logFilePath: String = null
- def captureOutput(source: String)(line: String) {
- buffer += s"$source> $line"
+ def captureLogOutput(line: String): Unit = {
+ buffer += line
if (line.contains("ThriftBinaryCLIService listening on")) {
- serverStarted.success(())
+ serverRunning.success(())
}
}
- val process = Process(command).run(
- ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
-
- Future {
- val exitValue = process.exitValue()
- logInfo(s"Spark SQL Thrift server process exit value: $exitValue")
+ def captureThriftServerOutput(source: String)(line: String): Unit = {
+ if (line.startsWith(LOGGING_MARK)) {
+ logFilePath = line.drop(LOGGING_MARK.length).trim
+ // Ensure that the log file is created so that the `tail' command won't fail
+ Try(new File(logFilePath).createNewFile())
+ logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath")
+ .run(ProcessLogger(captureLogOutput, _ => ()))
+ }
}
+ // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
+ Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger(
+ captureThriftServerOutput("stdout"),
+ captureThriftServerOutput("stderr")))
+
val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
val user = System.getProperty("user.name")
try {
- Await.result(serverStarted.future, timeout)
+ Await.result(serverRunning.future, timeout)
val connection = DriverManager.getConnection(jdbcUri, user, "")
val statement = connection.createStatement()
@@ -122,10 +133,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
|End HiveThriftServer2Suite failure output
|=========================================
""".stripMargin, cause)
+ throw cause
} finally {
warehousePath.delete()
metastorePath.delete()
- process.destroy()
+ Process(stopScript).run().exitValue()
+ // The `spark-daemon.sh' script uses kill, which is not synchronous, so we have to wait for a while.
+ Thread.sleep(3.seconds.toMillis)
+ Option(logTailingProcess).map(_.destroy())
+ Option(logFilePath).map(new File(_).delete())
}
}
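
One small pattern used above is worth isolating: picking a free port by binding to port 0 and letting the OS assign one. A dependency-free sketch (note there is a small race if another process grabs the port between close() and the server start):

    import java.net.ServerSocket

    object FreePortSketch {
      def randomFreePort(): Int = {
        val socket = new ServerSocket(0)
        try socket.getLocalPort finally socket.close()
      }

      def main(args: Array[String]): Unit =
        println(s"picked port ${randomFreePort()}")
    }
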
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 556c984ad392b..463888551a359 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
*/
override def whiteList = Seq(
"add_part_exist",
+ "dynamic_partition_skip_default",
+ "infer_bucket_sort_dyn_part",
+ "load_dyn_part1",
+ "load_dyn_part2",
+ "load_dyn_part3",
+ "load_dyn_part4",
+ "load_dyn_part5",
+ "load_dyn_part6",
+ "load_dyn_part7",
+ "load_dyn_part8",
+ "load_dyn_part9",
+ "load_dyn_part10",
+ "load_dyn_part11",
+ "load_dyn_part12",
+ "load_dyn_part13",
+ "load_dyn_part14",
+ "load_dyn_part14_win",
"add_part_multiple",
"add_partition_no_whitelist",
"add_partition_with_whitelist",
@@ -326,6 +343,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"ct_case_insensitive",
"database_location",
"database_properties",
+ "date_2",
+ "date_3",
+ "date_4",
+ "date_comparison",
+ "date_join1",
+ "date_serde",
+ "date_udf",
"decimal_1",
"decimal_4",
"decimal_join",
@@ -587,8 +611,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"part_inherit_tbl_props",
"part_inherit_tbl_props_empty",
"part_inherit_tbl_props_with_star",
+ "partition_date",
"partition_schema1",
"partition_serde_format",
+ "partition_type_check",
"partition_varchar1",
"partition_wise_fileformat4",
"partition_wise_fileformat5",
@@ -887,6 +913,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"union7",
"union8",
"union9",
+ "union_date",
"union_lateralview",
"union_ppr",
"union_remove_11",
diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
deleted file mode 100644
index ab7862f4f9e06..0000000000000
--- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.io.IOException
-import java.text.NumberFormat
-import java.util.Date
-
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities}
-import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
-import org.apache.hadoop.hive.ql.plan.FileSinkDesc
-import org.apache.hadoop.mapred._
-import org.apache.hadoop.io.Writable
-
-import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
-
-/**
- * Internal helper class that saves an RDD using a Hive OutputFormat.
- * It is based on [[SparkHadoopWriter]].
- */
-private[hive] class SparkHiveHadoopWriter(
- @transient jobConf: JobConf,
- fileSinkConf: FileSinkDesc)
- extends Logging
- with SparkHadoopMapRedUtil
- with Serializable {
-
- private val now = new Date()
- private val conf = new SerializableWritable(jobConf)
-
- private var jobID = 0
- private var splitID = 0
- private var attemptID = 0
- private var jID: SerializableWritable[JobID] = null
- private var taID: SerializableWritable[TaskAttemptID] = null
-
- @transient private var writer: FileSinkOperator.RecordWriter = null
- @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
- @transient private var committer: OutputCommitter = null
- @transient private var jobContext: JobContext = null
- @transient private var taskContext: TaskAttemptContext = null
-
- def preSetup() {
- setIDs(0, 0, 0)
- setConfParams()
-
- val jCtxt = getJobContext()
- getOutputCommitter().setupJob(jCtxt)
- }
-
-
- def setup(jobid: Int, splitid: Int, attemptid: Int) {
- setIDs(jobid, splitid, attemptid)
- setConfParams()
- }
-
- def open() {
- val numfmt = NumberFormat.getInstance()
- numfmt.setMinimumIntegerDigits(5)
- numfmt.setGroupingUsed(false)
-
- val extension = Utilities.getFileExtension(
- conf.value,
- fileSinkConf.getCompressed,
- getOutputFormat())
-
- val outputName = "part-" + numfmt.format(splitID) + extension
- val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
-
- getOutputCommitter().setupTask(getTaskContext())
- writer = HiveFileFormatUtils.getHiveRecordWriter(
- conf.value,
- fileSinkConf.getTableInfo,
- conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
- fileSinkConf,
- path,
- null)
- }
-
- def write(value: Writable) {
- if (writer != null) {
- writer.write(value)
- } else {
- throw new IOException("Writer is null, open() has not been called")
- }
- }
-
- def close() {
- // Seems the boolean value passed into close does not matter.
- writer.close(false)
- }
-
- def commit() {
- val taCtxt = getTaskContext()
- val cmtr = getOutputCommitter()
- if (cmtr.needsTaskCommit(taCtxt)) {
- try {
- cmtr.commitTask(taCtxt)
- logInfo (taID + ": Committed")
- } catch {
- case e: IOException =>
- logError("Error committing the output of task: " + taID.value, e)
- cmtr.abortTask(taCtxt)
- throw e
- }
- } else {
- logWarning ("No need to commit output of task: " + taID.value)
- }
- }
-
- def commitJob() {
- // always ? Or if cmtr.needsTaskCommit ?
- val cmtr = getOutputCommitter()
- cmtr.commitJob(getJobContext())
- }
-
- // ********* Private Functions *********
-
- private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
- if (format == null) {
- format = conf.value.getOutputFormat()
- .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
- }
- format
- }
-
- private def getOutputCommitter(): OutputCommitter = {
- if (committer == null) {
- committer = conf.value.getOutputCommitter
- }
- committer
- }
-
- private def getJobContext(): JobContext = {
- if (jobContext == null) {
- jobContext = newJobContext(conf.value, jID.value)
- }
- jobContext
- }
-
- private def getTaskContext(): TaskAttemptContext = {
- if (taskContext == null) {
- taskContext = newTaskAttemptContext(conf.value, taID.value)
- }
- taskContext
- }
-
- private def setIDs(jobId: Int, splitId: Int, attemptId: Int) {
- jobID = jobId
- splitID = splitId
- attemptID = attemptId
-
- jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId))
- taID = new SerializableWritable[TaskAttemptID](
- new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
- }
-
- private def setConfParams() {
- conf.value.set("mapred.job.id", jID.value.toString)
- conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
- conf.value.set("mapred.task.id", taID.value.toString)
- conf.value.setBoolean("mapred.task.is.map", true)
- conf.value.setInt("mapred.task.partition", splitID)
- }
-}
-
-private[hive] object SparkHiveHadoopWriter {
- def createPathFromString(path: String, conf: JobConf): Path = {
- if (path == null) {
- throw new IllegalArgumentException("Output path is null")
- }
- val outputPath = new Path(path)
- val fs = outputPath.getFileSystem(conf)
- if (outputPath == null || fs == null) {
- throw new IllegalArgumentException("Incorrectly formatted output path")
- }
- outputPath.makeQualified(fs)
- }
-}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
new file mode 100644
index 0000000000000..430ffb29989ea
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical}
+
+/**
+ * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
+ */
+private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
+ protected implicit def asParser(k: Keyword): Parser[String] =
+ lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+
+ protected val ADD = Keyword("ADD")
+ protected val DFS = Keyword("DFS")
+ protected val FILE = Keyword("FILE")
+ protected val JAR = Keyword("JAR")
+
+ private val reservedWords =
+ this
+ .getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+ override val lexical = new SqlLexical(reservedWords)
+
+ protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl
+
+ protected lazy val hiveQl: Parser[LogicalPlan] =
+ restInput ^^ {
+ case statement => HiveQl.createPlan(statement.trim)
+ }
+
+ protected lazy val dfs: Parser[LogicalPlan] =
+ DFS ~> wholeInput ^^ {
+ case command => NativeCommand(command.trim)
+ }
+
+ private lazy val addFile: Parser[LogicalPlan] =
+ ADD ~ FILE ~> restInput ^^ {
+ case input => AddFile(input.trim)
+ }
+
+ private lazy val addJar: Parser[LogicalPlan] =
+ ADD ~ JAR ~> restInput ^^ {
+ case input => AddJar(input.trim)
+ }
+}
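
The new `ExtendedHiveQlParser` tries the extension commands (DFS, ADD JAR, ADD FILE) first and otherwise hands the whole statement to `HiveQl.createPlan`. A self-contained illustration of the same "few prefixed commands, then pass-through fallback" shape using plain parser combinators (it does not reuse Spark's `AbstractSparkSQLParser` or its case-insensitive keyword machinery):

    import scala.util.parsing.combinator.RegexParsers

    object MiniExtensionParser extends RegexParsers {
      sealed trait Plan
      case class AddJar(path: String) extends Plan
      case class AddFile(path: String) extends Plan
      case class Dfs(command: String) extends Plan
      case class PassThrough(sql: String) extends Plan   // stands in for HiveQl.createPlan(...)

      private val rest: Parser[String] = ".*".r

      private val addJar: Parser[Plan]   = "(?i)ADD\\s+JAR".r  ~> rest ^^ (p => AddJar(p.trim))
      private val addFile: Parser[Plan]  = "(?i)ADD\\s+FILE".r ~> rest ^^ (p => AddFile(p.trim))
      private val dfs: Parser[Plan]      = "(?i)DFS".r ~ rest ^^ { case d ~ r => Dfs(s"$d $r".trim) }
      private val fallback: Parser[Plan] = rest ^^ (s => PassThrough(s.trim))

      // Extension commands are tried first; anything else falls through untouched.
      private val start: Parser[Plan] = addJar | addFile | dfs | fallback

      def apply(sql: String): Plan = parseAll(start, sql.trim).get
    }

    // MiniExtensionParser("ADD JAR /tmp/udfs.jar")  => AddJar(/tmp/udfs.jar)
    // MiniExtensionParser("SELECT 1")               => PassThrough(SELECT 1)
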
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 3e1a7b71528e0..8b5a90159e1bb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.hive
import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConversions._
@@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
import org.apache.hadoop.hive.serde2.io.TimestampWritable
+import org.apache.hadoop.hive.serde2.io.DateWritable
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
@@ -231,12 +232,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
@transient protected[hive] lazy val sessionState = {
val ss = new SessionState(hiveconf)
setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
+ SessionState.start(ss)
+ ss.err = new PrintStream(outputBuffer, true, "UTF-8")
+ ss.out = new PrintStream(outputBuffer, true, "UTF-8")
+
ss
}
- sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
- sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
-
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
runSqlHive(s"SET $key=$value")
@@ -267,13 +269,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
*/
protected[sql] def runSqlHive(sql: String): Seq[String] = {
val maxResults = 100000
- val results = runHive(sql, 100000)
+ val results = runHive(sql, maxResults)
// It is very confusing when you only get back some of the results...
if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED")
results
}
- SessionState.start(sessionState)
/**
* Execute the command using Hive and return the results as a sequence. Each element
@@ -281,13 +282,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
*/
protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = {
try {
+ // Session state must be initialized before the CommandProcessor is created.
+ SessionState.start(sessionState)
+
val cmd_trimmed: String = cmd.trim()
val tokens: Array[String] = cmd_trimmed.split("\\s+")
val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf)
- SessionState.start(sessionState)
-
proc match {
case driver: Driver =>
driver.init()
@@ -356,7 +358,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
- ShortType, DecimalType, TimestampType, BinaryType)
+ ShortType, DecimalType, DateType, TimestampType, BinaryType)
protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
@@ -371,6 +373,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
}.toSeq.sorted.mkString("{", ",", "}")
case (null, _) => "NULL"
+ case (d: Date, DateType) => new DateWritable(d).toString
case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
case (other, tpe) if primitiveTypes contains tpe => other.toString
@@ -404,7 +407,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// be similar with Hive.
describeHiveTableCommand.hiveString
case command: PhysicalCommand =>
- command.sideEffectResult.map(_.head.toString)
+ command.executeCollect().map(_.head.toString)
case other =>
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
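
The new `DateType` branch in `toHiveString` mirrors the existing `TimestampType` one: the JVM value is wrapped in the corresponding Hive writable purely to reuse Hive's textual rendering. A dependency-free approximation (java.sql.Date and Timestamp print essentially the same "yyyy-MM-dd" / "yyyy-MM-dd HH:mm:ss.f..." forms the writables do):

    import java.sql.{Date, Timestamp}

    object ToHiveStringSketch {
      // Simplified stand-in for the Date/Timestamp/null branches of toHiveString; the real code
      // also handles structs, maps, binary and decimal values.
      def render(value: Any): String = value match {
        case null         => "NULL"
        case d: Date      => d.toString        // e.g. "2014-10-01"
        case t: Timestamp => t.toString        // e.g. "2014-10-01 12:34:56.789"
        case other        => other.toString
      }

      def main(args: Array[String]): Unit =
        println(render(Date.valueOf("2014-10-01")))
    }
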
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index fa889ec104c6e..1977618b4c9f2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -39,6 +39,7 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+ case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
@@ -49,6 +50,7 @@ private[hive] trait HiveInspectors {
// java class
case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == classOf[java.sql.Date] => DateType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
@@ -93,6 +95,7 @@ private[hive] trait HiveInspectors {
System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
bytes
}
+ case d: hiveIo.DateWritable => d.get
case t: hiveIo.TimestampWritable => t.getTimestamp
case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
case list: java.util.List[_] => list.map(unwrap)
@@ -108,6 +111,7 @@ private[hive] trait HiveInspectors {
case str: String => str
case p: java.math.BigDecimal => p
case p: Array[Byte] => p
+ case p: java.sql.Date => p
case p: java.sql.Timestamp => p
}
@@ -147,6 +151,7 @@ private[hive] trait HiveInspectors {
case l: Byte => l: java.lang.Byte
case b: BigDecimal => new HiveDecimal(b.underlying())
case b: Array[Byte] => b
+ case d: java.sql.Date => d
case t: java.sql.Timestamp => t
case s: Seq[_] => seqAsJavaList(s.map(wrap))
case m: Map[_,_] =>
@@ -173,6 +178,7 @@ private[hive] trait HiveInspectors {
case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+ case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
@@ -211,8 +217,12 @@ private[hive] trait HiveInspectors {
case _: JavaBinaryObjectInspector => BinaryType
case _: WritableHiveDecimalObjectInspector => DecimalType
case _: JavaHiveDecimalObjectInspector => DecimalType
+ case _: WritableDateObjectInspector => DateType
+ case _: JavaDateObjectInspector => DateType
case _: WritableTimestampObjectInspector => TimestampType
case _: JavaTimestampObjectInspector => TimestampType
+ case _: WritableVoidObjectInspector => NullType
+ case _: JavaVoidObjectInspector => NullType
}
implicit class typeInfoConversions(dt: DataType) {
@@ -236,6 +246,7 @@ private[hive] trait HiveInspectors {
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
case DecimalType => decimalTypeInfo
+ case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 9a0b9b46ac4ee..75a19656af110 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -19,31 +19,28 @@ package org.apache.spark.sql.hive
import scala.util.parsing.combinator.RegexParsers
-import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo}
-import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable}
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
import org.apache.hadoop.hive.serde2.Deserializer
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog}
+import org.apache.spark.sql.catalyst.analysis.Catalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.util.Utils
/* Implicit conversions */
import scala.collection.JavaConversions._
private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
- import HiveMetastoreTypes._
+ import org.apache.spark.sql.hive.HiveMetastoreTypes._
/** Connection to hive metastore. Usages should lock on `this`. */
protected[hive] val client = Hive.get(hive.hiveconf)
@@ -96,10 +93,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
serDeInfo.setParameters(Map[String, String]())
sd.setSerdeInfo(serDeInfo)
- try client.createTable(table) catch {
- case e: org.apache.hadoop.hive.ql.metadata.HiveException
- if e.getCause.isInstanceOf[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] &&
- allowExisting => // Do nothing.
+ synchronized {
+ try client.createTable(table) catch {
+ case e: org.apache.hadoop.hive.ql.metadata.HiveException
+ if e.getCause.isInstanceOf[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] &&
+ allowExisting => // Do nothing.
+ }
}
}
@@ -131,18 +130,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
castChildOutput(p, table, child)
-
- case p @ logical.InsertIntoTable(
- InMemoryRelation(_, _, _,
- HiveTableScan(_, table, _)), _, child, _) =>
- castChildOutput(p, table, child)
}
def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = {
val childOutputDataTypes = child.output.map(_.dataType)
- // Only check attributes, not partitionKeys since they are always strings.
- // TODO: Fully support inserting into partitioned tables.
- val tableOutputDataTypes = table.attributes.map(_.dataType)
+ val tableOutputDataTypes =
+ (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType)
if (childOutputDataTypes == tableOutputDataTypes) {
p
@@ -193,6 +186,7 @@ object HiveMetastoreTypes extends RegexParsers {
"binary" ^^^ BinaryType |
"boolean" ^^^ BooleanType |
"decimal" ^^^ DecimalType |
+ "date" ^^^ DateType |
"timestamp" ^^^ TimestampType |
"varchar\\((\\d+)\\)".r ^^^ StringType
@@ -242,6 +236,7 @@ object HiveMetastoreTypes extends RegexParsers {
case LongType => "bigint"
case BinaryType => "binary"
case BooleanType => "boolean"
+ case DateType => "date"
case DecimalType => "decimal"
case TimestampType => "timestamp"
case NullType => "void"
@@ -303,14 +298,14 @@ private[hive] case class MetastoreRelation
HiveMetastoreTypes.toDataType(f.getType),
// Since data can be dumped in randomly with no validation, everything is nullable.
nullable = true
- )(qualifiers = tableName +: alias.toSeq)
+ )(qualifiers = Seq(alias.getOrElse(tableName)))
}
// Must be a stable value since new attributes are born here.
val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
/** Non-partitionKey attributes */
- val attributes = table.getSd.getCols.map(_.toAttribute)
+ val attributes = hiveQlTable.getCols.map(_.toAttribute)
val output = attributes ++ partitionKeys
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0aa6292c0184e..2b599157d15d3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.hive
+import java.sql.Date
+
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
+import org.apache.spark.sql.catalyst.SparkSQLParser
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -38,10 +41,6 @@ import scala.collection.JavaConversions._
*/
private[hive] case object NativePlaceholder extends Command
-private[hive] case class ShellCommand(cmd: String) extends Command
-
-private[hive] case class SourceCommand(filePath: String) extends Command
-
private[hive] case class AddFile(filePath: String) extends Command
private[hive] case class AddJar(path: String) extends Command
@@ -127,6 +126,11 @@ private[hive] object HiveQl {
"TOK_DESCTABLE"
) ++ nativeCommands
+ protected val hqlParser = {
+ val fallback = new ExtendedHiveQlParser
+ new SparkSQLParser(fallback(_))
+ }
+
/**
* A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
* similar to [[catalyst.trees.TreeNode]].
@@ -215,40 +219,19 @@ private[hive] object HiveQl {
def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql))
/** Returns a LogicalPlan for a given HiveQL string. */
- def parseSql(sql: String): LogicalPlan = {
+ def parseSql(sql: String): LogicalPlan = hqlParser(sql)
+
+ /** Creates LogicalPlan for a given HiveQL string. */
+ def createPlan(sql: String) = {
try {
- if (sql.trim.toLowerCase.startsWith("set")) {
- // Split in two parts since we treat the part before the first "="
- // as key, and the part after as value, which may contain other "=" signs.
- sql.trim.drop(3).split("=", 2).map(_.trim) match {
- case Array("") => // "set"
- SetCommand(None, None)
- case Array(key) => // "set key"
- SetCommand(Some(key), None)
- case Array(key, value) => // "set key=value"
- SetCommand(Some(key), Some(value))
- }
- } else if (sql.trim.toLowerCase.startsWith("cache table")) {
- sql.trim.drop(12).trim.split(" ").toSeq match {
- case Seq(tableName) =>
- CacheCommand(tableName, true)
- case Seq(tableName, _, select @ _*) =>
- CacheTableAsSelectCommand(tableName, createPlan(select.mkString(" ").trim))
- }
- } else if (sql.trim.toLowerCase.startsWith("uncache table")) {
- CacheCommand(sql.trim.drop(14).trim, false)
- } else if (sql.trim.toLowerCase.startsWith("add jar")) {
- AddJar(sql.trim.drop(8).trim)
- } else if (sql.trim.toLowerCase.startsWith("add file")) {
- AddFile(sql.trim.drop(9))
- } else if (sql.trim.toLowerCase.startsWith("dfs")) {
+ val tree = getAst(sql)
+ if (nativeCommands contains tree.getText) {
NativeCommand(sql)
- } else if (sql.trim.startsWith("source")) {
- SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
- } else if (sql.trim.startsWith("!")) {
- ShellCommand(sql.drop(1))
} else {
- createPlan(sql)
+ nodeToPlan(tree) match {
+ case NativePlaceholder => NativeCommand(sql)
+ case other => other
+ }
}
} catch {
case e: Exception => throw new ParseException(sql, e)
@@ -259,19 +242,6 @@ private[hive] object HiveQl {
""".stripMargin)
}
}
-
- /** Creates LogicalPlan for a given HiveQL string. */
- def createPlan(sql: String) = {
- val tree = getAst(sql)
- if (nativeCommands contains tree.getText) {
- NativeCommand(sql)
- } else {
- nodeToPlan(tree) match {
- case NativePlaceholder => NativeCommand(sql)
- case other => other
- }
- }
- }
def parseDdl(ddl: String): Seq[Attribute] = {
val tree =
@@ -349,6 +319,7 @@ private[hive] object HiveQl {
case Token("TOK_STRING", Nil) => StringType
case Token("TOK_FLOAT", Nil) => FloatType
case Token("TOK_DOUBLE", Nil) => DoubleType
+ case Token("TOK_DATE", Nil) => DateType
case Token("TOK_TIMESTAMP", Nil) => TimestampType
case Token("TOK_BINARY", Nil) => BinaryType
case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
@@ -670,7 +641,7 @@ private[hive] object HiveQl {
def nodeToRelation(node: Node): LogicalPlan = node match {
case Token("TOK_SUBQUERY",
query :: Token(alias, Nil) :: Nil) =>
- Subquery(alias, nodeToPlan(query))
+ Subquery(cleanIdentifier(alias), nodeToPlan(query))
case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
val Token("TOK_SELECT",
@@ -837,11 +808,6 @@ private[hive] object HiveQl {
cleanIdentifier(key.toLowerCase) -> None
}.toMap).getOrElse(Map.empty)
- if (partitionKeys.values.exists(p => p.isEmpty)) {
- throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" +
- s"dynamic partitioning.")
- }
-
InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite)
case a: ASTNode =>
@@ -855,7 +821,7 @@ private[hive] object HiveQl {
case Token("TOK_SELEXPR",
e :: Token(alias, Nil) :: Nil) =>
- Some(Alias(nodeToExpr(e), alias)())
+ Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
/* Hints are ignored */
case Token("TOK_HINTLIST", _) => None
@@ -961,6 +927,8 @@ private[hive] object HiveQl {
Cast(nodeToExpr(arg), DecimalType)
case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), TimestampType)
+ case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DateType)
/* Arithmetic */
case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
@@ -1084,6 +1052,9 @@ private[hive] object HiveQl {
case ast: ASTNode if ast.getType == HiveParser.StringLiteral =>
Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText))
+ case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL =>
+ Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1)))
+
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 8ac17f37201a8..5c66322f1ed99 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.StringType
-import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
import org.apache.spark.sql.hive
import org.apache.spark.sql.hive.execution._
@@ -161,20 +160,17 @@ private[hive] trait HiveStrategies {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
- case logical.InsertIntoTable(
- InMemoryRelation(_, _, _,
- HiveTableScan(_, table, _)), partition, child, overwrite) =>
- InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
+
case logical.CreateTableAsSelect(database, tableName, child) =>
val query = planLater(child)
CreateTableAsSelect(
database.get,
tableName,
query,
- InsertIntoHiveTable(_: MetastoreRelation,
- Map(),
- query,
- true)(hiveContext)) :: Nil
+ InsertIntoHiveTable(_: MetastoreRelation,
+ Map(),
+ query,
+ overwrite = true)(hiveContext)) :: Nil
case _ => Nil
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 84fafcde63d05..0de29d5cffd0e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._
import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
@@ -52,7 +53,8 @@ private[hive]
class HadoopTableReader(
@transient attributes: Seq[Attribute],
@transient relation: MetastoreRelation,
- @transient sc: HiveContext)
+ @transient sc: HiveContext,
+ @transient hiveExtraConf: HiveConf)
extends TableReader {
// Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
@@ -63,7 +65,7 @@ class HadoopTableReader(
// TODO: set aws s3 credentials.
private val _broadcastedHiveConf =
- sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf))
+ sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf))
def broadcastedHiveConf = _broadcastedHiveConf
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 70fb15259e7d7..9a9e2eda6bcd4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -31,8 +31,9 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.hive.serde2.avro.AvroSerDe
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.SQLConf
@@ -40,8 +41,10 @@ import org.apache.spark.sql.SQLConf
/* Implicit conversions */
import scala.collection.JavaConversions._
+// SPARK-3729: Test key required to check for initialization errors with config.
object TestHive
- extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf()))
+ extends TestHiveContext(
+ new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", "")))
/**
* A locally running test instance of Spark's Hive execution engine.
@@ -65,15 +68,18 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath
/** Sets up the system initially or after a RESET command */
- protected def configure() {
+ protected def configure(): Unit = {
setConf("javax.jdo.option.ConnectionURL",
s"jdbc:derby:;databaseName=$metastorePath;create=true")
setConf("hive.metastore.warehouse.dir", warehousePath)
+ Utils.registerShutdownDeleteDir(new File(warehousePath))
+ Utils.registerShutdownDeleteDir(new File(metastorePath))
}
val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp")
testTempDir.delete()
testTempDir.mkdir()
+ Utils.registerShutdownDeleteDir(testTempDir)
// For some hive test case which contain ${system:test.tmp.dir}
System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath)
@@ -119,8 +125,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "")
hiveFilesTemp.delete()
hiveFilesTemp.mkdir()
- hiveFilesTemp.deleteOnExit()
-
+ Utils.registerShutdownDeleteDir(hiveFilesTemp)
val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) {
new File("src" + File.separator + "test" + File.separator + "resources" + File.separator)
@@ -152,7 +157,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override lazy val analyzed = {
val describedTables = logical match {
case NativeCommand(describedTable(tbl)) => tbl :: Nil
- case CacheCommand(tbl, _) => tbl :: Nil
+ case CacheTableCommand(tbl, _, _) => tbl :: Nil
case _ => Nil
}
@@ -351,7 +356,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
var cacheTables: Boolean = false
def loadTestTable(name: String) {
if (!(loadedTables contains name)) {
- // Marks the table as loaded first to prevent infite mutually recursive table loading.
+ // Marks the table as loaded first to prevent infinite mutually recursive table loading.
loadedTables += name
logInfo(s"Loading test table $name")
val createCmds =
@@ -381,6 +386,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN)
}
+ clearCache()
loadedTables.clear()
catalog.client.getAllTables("default").foreach { t =>
logDebug(s"Deleting table $t")
@@ -426,7 +432,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
loadTestTable("srcpart")
} catch {
case e: Exception =>
- logError(s"FATAL ERROR: Failed to reset TestDB state. $e")
+ logError("FATAL ERROR: Failed to reset TestDB state.", e)
// At this point there is really no reason to continue, but the test framework traps exits.
// So instead we just pause forever so that at least the developer can see where things
// started to go wrong.
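
The switch from `deleteOnExit()` to `Utils.registerShutdownDeleteDir` matters because `File.deleteOnExit` only removes empty directories at JVM exit. A rough standalone equivalent built on a plain JVM shutdown hook (a sketch; Spark's helper additionally tracks registered paths and handles nesting):

    import java.io.File

    object ShutdownDeleteSketch {
      private def deleteRecursively(f: File): Unit = {
        if (f.isDirectory) {
          Option(f.listFiles()).getOrElse(Array.empty[File]).foreach(deleteRecursively)
        }
        f.delete()
      }

      // Register a directory to be removed (with its contents) when the JVM shuts down.
      def registerShutdownDelete(dir: File): Unit =
        Runtime.getRuntime.addShutdownHook(new Thread {
          override def run(): Unit = deleteRecursively(dir)
        })
    }
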
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 1017fe6d5396d..3625708d03175 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -30,23 +30,23 @@ import org.apache.spark.sql.hive.MetastoreRelation
* Create table and insert the query result into it.
* @param database the database name of the new relation
* @param tableName the table name of the new relation
- * @param insertIntoRelation function of creating the `InsertIntoHiveTable`
+ * @param insertIntoRelation function of creating the `InsertIntoHiveTable`
* by specifying the `MetaStoreRelation`, the data will be inserted into that table.
* TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc.
*/
@Experimental
case class CreateTableAsSelect(
- database: String,
- tableName: String,
- query: SparkPlan,
- insertIntoRelation: MetastoreRelation => InsertIntoHiveTable)
- extends LeafNode with Command {
+ database: String,
+ tableName: String,
+ query: SparkPlan,
+ insertIntoRelation: MetastoreRelation => InsertIntoHiveTable)
+ extends LeafNode with Command {
def output = Seq.empty
// A lazy computing of the metastoreRelation
private[this] lazy val metastoreRelation: MetastoreRelation = {
- // Create the table
+ // Create the table
val sc = sqlContext.asInstanceOf[HiveContext]
sc.catalog.createTable(database, tableName, query.output, false)
// Get the Metastore Relation
@@ -55,7 +55,7 @@ case class CreateTableAsSelect(
}
}
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
insertIntoRelation(metastoreRelation).execute
Seq.empty[Row]
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
index 317801001c7a4..106cede9788ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
@@ -48,7 +48,7 @@ case class DescribeHiveTableCommand(
.mkString("\t")
}
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
// Trying to mimic the format of Hive's output. But not exactly the same.
var results: Seq[(String, String, String)] = Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 577ca928b43b6..5b83b77d80a22 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -64,8 +64,14 @@ case class HiveTableScan(
BindReferences.bindReference(pred, relation.partitionKeys)
}
+ // Create a local copy of hiveconf, so that scan-specific modifications do not impact
+ // other queries.
@transient
- private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context)
+ private[this] val hiveExtraConf = new HiveConf(context.hiveconf)
+
+ @transient
+ private[this] val hadoopReader =
+ new HadoopTableReader(attributes, relation, context, hiveExtraConf)
private[this] def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
@@ -80,10 +86,14 @@ case class HiveTableScan(
ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name))
+ val tableDesc = relation.tableDesc
+ val deserializer = tableDesc.getDeserializerClass.newInstance
+ deserializer.initialize(hiveConf, tableDesc.getProperties)
+
// Specifies types and object inspectors of columns to be scanned.
val structOI = ObjectInspectorUtils
.getStandardObjectInspector(
- relation.tableDesc.getDeserializer.getObjectInspector,
+ deserializer.getObjectInspector,
ObjectInspectorCopyOption.JAVA)
.asInstanceOf[StructObjectInspector]
@@ -97,7 +107,7 @@ case class HiveTableScan(
hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(","))
}
- addColumnMetadataToConf(context.hiveconf)
+ addColumnMetadataToConf(hiveExtraConf)
/**
* Prunes partitions not involve the query plan.
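
`HiveTableScan` now applies its column-projection settings to a private copy of the session `HiveConf`, so one scan's `addColumnMetadataToConf` calls cannot bleed into other queries. The copy-then-mutate pattern, shown here with a plain Hadoop `Configuration` (HiveConf offers the same copy-constructor style, as the diff itself uses):

    import org.apache.hadoop.conf.Configuration

    object LocalConfCopySketch {
      def main(args: Array[String]): Unit = {
        val shared = new Configuration()
        // Copy constructor: settings made on the copy do not show up on the shared instance.
        val scanLocal = new Configuration(shared)
        scanLocal.set("hive.io.file.readcolumn.ids", "0,2")

        assert(shared.get("hive.io.file.readcolumn.ids") == null)
        assert(scanLocal.get("hive.io.file.readcolumn.ids") == "0,2")
      }
    }
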
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index a284a91a91e31..f0785d8882636 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConversions._
-import java.util.{HashMap => JHashMap}
-
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.metastore.MetaStoreUtils
-import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
+import org.apache.hadoop.hive.ql.{Context, ErrorMsg}
import org.apache.hadoop.hive.serde2.Serializer
-import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
-import org.apache.hadoop.io.Writable
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector}
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}
-import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
-import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter}
+import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode}
+import org.apache.spark.sql.hive._
+import org.apache.spark.{SerializableWritable, SparkException, TaskContext}
/**
* :: DeveloperApi ::
@@ -51,7 +49,7 @@ case class InsertIntoHiveTable(
child: SparkPlan,
overwrite: Boolean)
(@transient sc: HiveContext)
- extends UnaryNode {
+ extends UnaryNode with Command {
@transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass
@transient private lazy val hiveContext = new Context(sc.hiveconf)
@@ -71,96 +69,95 @@ case class InsertIntoHiveTable(
* Wraps with Hive types based on object inspector.
* TODO: Consolidate all hive OI/data interface code.
*/
- protected def wrap(a: (Any, ObjectInspector)): Any = a match {
- case (s: String, oi: JavaHiveVarcharObjectInspector) =>
- new HiveVarchar(s, s.size)
-
- case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) =>
- new HiveDecimal(bd.underlying())
-
- case (row: Row, oi: StandardStructObjectInspector) =>
- val struct = oi.create()
- row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach {
- case (data, field) =>
- oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector))
+ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
+ case _: JavaHiveVarcharObjectInspector =>
+ (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+
+ case _: JavaHiveDecimalObjectInspector =>
+ (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying())
+
+ case soi: StandardStructObjectInspector =>
+ val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
+ (o: Any) => {
+ val struct = soi.create()
+ (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+ (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+ }
+ struct
}
- struct
- case (s: Seq[_], oi: ListObjectInspector) =>
- val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector))
- seqAsJavaList(wrappedSeq)
+ case loi: ListObjectInspector =>
+ val wrapper = wrapperFor(loi.getListElementObjectInspector)
+ (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
- case (m: Map[_, _], oi: MapObjectInspector) =>
- val keyOi = oi.getMapKeyObjectInspector
- val valueOi = oi.getMapValueObjectInspector
- val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) }
- mapAsJavaMap(wrappedMap)
+ case moi: MapObjectInspector =>
+ val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
+ val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
+ (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
+ keyWrapper(key) -> valueWrapper(value)
+ })
- case (obj, _) =>
- obj
+ case _ =>
+ identity[Any]
}
def saveAsHiveFile(
- rdd: RDD[Writable],
+ rdd: RDD[Row],
valueClass: Class[_],
fileSinkConf: FileSinkDesc,
- conf: JobConf,
- isCompressed: Boolean) {
- if (valueClass == null) {
- throw new SparkException("Output value class not set")
- }
- conf.setOutputValueClass(valueClass)
- if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) {
- throw new SparkException("Output format class not set")
- }
- // Doesn't work in Scala 2.9 due to what may be a generics bug
- // TODO: Should we uncomment this for Scala 2.10?
- // conf.setOutputFormat(outputFormatClass)
- conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName)
- if (isCompressed) {
- // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
- // and "mapred.output.compression.type" have no impact on ORC because it uses table properties
- // to store compression information.
- conf.set("mapred.output.compress", "true")
- fileSinkConf.setCompressed(true)
- fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec"))
- fileSinkConf.setCompressType(conf.get("mapred.output.compression.type"))
- }
- conf.setOutputCommitter(classOf[FileOutputCommitter])
- FileOutputFormat.setOutputPath(
- conf,
- SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf))
+ conf: SerializableWritable[JobConf],
+ writerContainer: SparkHiveWriterContainer): Unit = {
+ assert(valueClass != null, "Output value class not set")
+ conf.value.setOutputValueClass(valueClass)
+ val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName
+ assert(outputFileFormatClassName != null, "Output format class not set")
+ conf.value.set("mapred.output.format.class", outputFileFormatClassName)
+ conf.value.setOutputCommitter(classOf[FileOutputCommitter])
+
+ FileOutputFormat.setOutputPath(
+ conf.value,
+ SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value))
log.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
- val writer = new SparkHiveHadoopWriter(conf, fileSinkConf)
- writer.preSetup()
+ writerContainer.driverSideSetup()
+ sc.sparkContext.runJob(rdd, writeToFile _)
+ writerContainer.commitJob()
+
+ // Note that this function is executed on executor side
+ def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = {
+ val serializer = newSerializer(fileSinkConf.getTableInfo)
+ val standardOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
+ val wrappers = fieldOIs.map(wrapperFor)
+ val outputData = new Array[Any](fieldOIs.length)
- def writeToFile(context: TaskContext, iter: Iterator[Writable]) {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber)
- writer.setup(context.stageId, context.partitionId, attemptNumber)
- writer.open()
+ iterator.foreach { row =>
+ var i = 0
+ while (i < fieldOIs.length) {
+ outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i))
+ i += 1
+ }
- var count = 0
- while(iter.hasNext) {
- val record = iter.next()
- count += 1
- writer.write(record)
+ writerContainer
+ .getLocalFileWriter(row)
+ .write(serializer.serialize(outputData, standardOI))
}
- writer.close()
- writer.commit()
+ writerContainer.close()
}
-
- sc.sparkContext.runJob(rdd, writeToFile _)
- writer.commitJob()
}
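
The reworked `saveAsHiveFile` splits responsibilities between the driver and the executors: job-level setup and the final commit run on the driver, while each task sets itself up, writes its rows, and commits its own attempt. A rough sketch of that control flow, under hypothetical `Container` and `runTasks` names rather than the real Spark/Hive classes:

    // Hypothetical writer container; the real one wraps a Hadoop OutputCommitter.
    trait Container {
      def driverSideSetup(): Unit
      def executorSideSetup(taskId: Int): Unit
      def write(record: String): Unit
      def close(): Unit      // commits this task attempt
      def commitJob(): Unit
    }

    // Stand-in for sparkContext.runJob: run the task body once per partition.
    def runTasks(partitions: Seq[Seq[String]])(task: (Int, Iterator[String]) => Unit): Unit =
      partitions.zipWithIndex.foreach { case (rows, i) => task(i, rows.iterator) }

    def save(partitions: Seq[Seq[String]], container: Container): Unit = {
      container.driverSideSetup()                // driver: job-level setup
      runTasks(partitions) { (taskId, rows) =>   // body runs once per task
        container.executorSideSetup(taskId)
        rows.foreach(container.write)
        container.close()                        // per-task commit
      }
      container.commitJob()                      // driver: job-level commit
    }
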
- override def execute() = result
-
/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
* `org.apache.hadoop.hive.serde2.SerDe` and the
@@ -168,50 +165,69 @@ case class InsertIntoHiveTable(
*
* Note: this is run once and then kept to avoid double insertions.
*/
- private lazy val result: RDD[Row] = {
- val childRdd = child.execute()
- assert(childRdd != null)
-
+ override protected[sql] lazy val sideEffectResult: Seq[Row] = {
// Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
// instances within the closure, since Serializer is not serializable while TableDesc is.
val tableDesc = table.tableDesc
val tableLocation = table.hiveQlTable.getDataLocation
val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation)
val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
- val rdd = childRdd.mapPartitions { iter =>
- val serializer = newSerializer(fileSinkConf.getTableInfo)
- val standardOI = ObjectInspectorUtils
- .getStandardObjectInspector(
- fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
- ObjectInspectorCopyOption.JAVA)
- .asInstanceOf[StructObjectInspector]
+ val isCompressed = sc.hiveconf.getBoolean(
+ ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal)
+ if (isCompressed) {
+ // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
+ // and "mapred.output.compression.type" have no impact on ORC because it uses table properties
+ // to store compression information.
+ sc.hiveconf.set("mapred.output.compress", "true")
+ fileSinkConf.setCompressed(true)
+ fileSinkConf.setCompressCodec(sc.hiveconf.get("mapred.output.compression.codec"))
+ fileSinkConf.setCompressType(sc.hiveconf.get("mapred.output.compression.type"))
+ }
- val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
- val outputData = new Array[Any](fieldOIs.length)
- iter.map { row =>
- var i = 0
- while (i < row.length) {
- // Casts Strings to HiveVarchars when necessary.
- outputData(i) = wrap(row(i), fieldOIs(i))
- i += 1
- }
+ val numDynamicPartitions = partition.values.count(_.isEmpty)
+ val numStaticPartitions = partition.values.count(_.nonEmpty)
+ val partitionSpec = partition.map {
+ case (key, Some(value)) => key -> value
+ case (key, None) => key -> ""
+ }
- serializer.serialize(outputData, standardOI)
+    // All partition column names in the format of "<column name 1>/<column name 2>/..."
+ val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns")
+ val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull
+
+    // Validate the partition spec if any dynamic partitions exist
+ if (numDynamicPartitions > 0) {
+ // Report error if dynamic partitioning is not enabled
+ if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) {
+ throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg)
+ }
+
+ // Report error if dynamic partition strict mode is on but no static partition is found
+ if (numStaticPartitions == 0 &&
+ sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) {
+ throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg)
+ }
+
+ // Report error if any static partition appears after a dynamic partition
+ val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty)
+ if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) {
+ throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
}
}
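
The last check above rejects a static partition column that follows a dynamic one: each column becomes a boolean (`true` = dynamic), and zipping the sequence with its own tail exposes any illegal `(dynamic, static)` adjacency. A small standalone version of the same check (assuming at least one partition column):

    // true = dynamic partition column (no value given), false = static column.
    def hasStaticAfterDynamic(isDynamic: Seq[Boolean]): Boolean =
      isDynamic.init.zip(isDynamic.tail).contains((true, false))

    // PARTITION (a='1', b, c) -> Seq(false, true, true): allowed.
    println(hasStaticAfterDynamic(Seq(false, true, true)))  // false
    // PARTITION (a, b='1') -> Seq(true, false): static after dynamic, rejected.
    println(hasStaticAfterDynamic(Seq(true, false)))        // true
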
- // ORC stores compression information in table properties. While, there are other formats
- // (e.g. RCFile) that rely on hadoop configurations to store compression information.
val jobConf = new JobConf(sc.hiveconf)
- saveAsHiveFile(
- rdd,
- outputClass,
- fileSinkConf,
- jobConf,
- sc.hiveconf.getBoolean("hive.exec.compress.output", false))
-
- // TODO: Handle dynamic partitioning.
+ val jobConfSer = new SerializableWritable(jobConf)
+
+ val writerContainer = if (numDynamicPartitions > 0) {
+ val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions)
+ new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames)
+ } else {
+ new SparkHiveWriterContainer(jobConf, fileSinkConf)
+ }
+
+ saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer)
+
val outputPath = FileOutputFormat.getOutputPath(jobConf)
// Have to construct the format of dbname.tablename.
val qualifiedTableName = s"${table.databaseName}.${table.tableName}"
@@ -220,10 +236,6 @@ case class InsertIntoHiveTable(
// holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint.
val holdDDLTime = false
if (partition.nonEmpty) {
- val partitionSpec = partition.map {
- case (key, Some(value)) => key -> value
- case (key, None) => key -> "" // Should not reach here right now.
- }
val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec)
db.validatePartitionNameCharacters(partVals)
// inheritTableSpecs is set to true. It should be set to false for a IMPORT query
@@ -231,14 +243,26 @@ case class InsertIntoHiveTable(
val inheritTableSpecs = true
// TODO: Correctly set isSkewedStoreAsSubdir.
val isSkewedStoreAsSubdir = false
- db.loadPartition(
- outputPath,
- qualifiedTableName,
- partitionSpec,
- overwrite,
- holdDDLTime,
- inheritTableSpecs,
- isSkewedStoreAsSubdir)
+ if (numDynamicPartitions > 0) {
+ db.loadDynamicPartitions(
+ outputPath,
+ qualifiedTableName,
+ partitionSpec,
+ overwrite,
+ numDynamicPartitions,
+ holdDDLTime,
+ isSkewedStoreAsSubdir
+ )
+ } else {
+ db.loadPartition(
+ outputPath,
+ qualifiedTableName,
+ partitionSpec,
+ overwrite,
+ holdDDLTime,
+ inheritTableSpecs,
+ isSkewedStoreAsSubdir)
+ }
} else {
db.loadTable(
outputPath,
@@ -247,10 +271,13 @@ case class InsertIntoHiveTable(
holdDDLTime)
}
+ // Invalidate the cache.
+ sqlContext.invalidateCache(table)
+
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
// however for now we return an empty list to simplify compatibility checks with hive, which
// does not return anything for insert operations.
// TODO: implement hive compatibility as rules.
- sc.sparkContext.makeRDD(Nil, 1)
+ Seq.empty[Row]
}
}
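
The move from a cached RDD to `sideEffectResult` leans on Scala's `lazy val` semantics: the body executes at most once, on first access, and later accesses return the memoized `Seq[Row]`, which is what guards against double insertions. A tiny plain-Scala illustration (not the actual Catalyst `Command` trait):

    class InsertLike {
      private var runs = 0

      // Evaluated on first access only; later accesses reuse the cached result.
      lazy val sideEffectResult: Seq[String] = {
        runs += 1
        println(s"performing insert, run #$runs")
        Seq.empty[String]
      }
    }

    val cmd = new InsertLike
    cmd.sideEffectResult  // prints "performing insert, run #1"
    cmd.sideEffectResult  // no output: the insert is not executed again
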
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
index 8f10e1ba7f426..6930c2babd117 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala
@@ -32,7 +32,7 @@ case class NativeCommand(
@transient context: HiveContext)
extends LeafNode with Command {
- override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_))
+ override protected lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_))
override def otherCopyArgs = context :: Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index d61c5e274a596..0fc674af31885 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -37,7 +37,7 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command {
def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
hiveContext.analyze(tableName)
Seq.empty[Row]
}
@@ -53,7 +53,7 @@ case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with
def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
val ifExistsClause = if (ifExists) "IF EXISTS " else ""
hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
hiveContext.catalog.unregisterTable(None, tableName)
@@ -70,7 +70,7 @@ case class AddJar(path: String) extends LeafNode with Command {
override def output = Seq.empty
- override protected[sql] lazy val sideEffectResult: Seq[Row] = {
+ override protected lazy val sideEffectResult: Seq[Row] = {
hiveContext.runSqlHive(s"ADD JAR $path")
hiveContext.sparkContext.addJar(path)
Seq.empty[Row]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 732e4976f6843..68f93f247d9bb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
@@ -57,7 +57,8 @@ private[hive] abstract class HiveFunctionRegistry
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(functionClassName, children)
-
+ } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
+ HiveUdaf(functionClassName, children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(functionClassName, Nil, children)
} else {
@@ -194,6 +195,37 @@ private[hive] case class HiveGenericUdaf(
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
}
+/** A wrapper for Hive functions that use the old-style UDAF interface. */
+private[hive] case class HiveUdaf(
+ functionClassName: String,
+ children: Seq[Expression]) extends AggregateExpression
+ with HiveInspectors
+ with HiveFunctionFactory {
+
+ type UDFType = UDAF
+
+ @transient
+ protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
+
+ @transient
+ protected lazy val objectInspector = {
+ resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
+ .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
+ }
+
+ @transient
+ protected lazy val inspectors = children.map(_.dataType).map(toInspector)
+
+ def dataType: DataType = inspectorToDataType(objectInspector)
+
+ def nullable: Boolean = true
+
+ override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+
+ def newInstance() =
+ new HiveUdafFunction(functionClassName, children, this, true)
+}
+
/**
* Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
* [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
@@ -275,14 +307,20 @@ private[hive] case class HiveGenericUdtf(
private[hive] case class HiveUdafFunction(
functionClassName: String,
exprs: Seq[Expression],
- base: AggregateExpression)
+ base: AggregateExpression,
+ isUDAFBridgeRequired: Boolean = false)
extends AggregateFunction
with HiveInspectors
with HiveFunctionFactory {
def this() = this(null, null, null)
- private val resolver = createFunction[AbstractGenericUDAFResolver]()
+ private val resolver =
+ if (isUDAFBridgeRequired) {
+ new GenericUDAFBridge(createFunction[UDAF]())
+ } else {
+ createFunction[AbstractGenericUDAFResolver]()
+ }
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
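
With `isUDAFBridgeRequired`, `HiveUdafFunction` handles both resolver styles: generic UDAFs are used directly, while old-style `UDAF` classes are adapted through `GenericUDAFBridge`. The branch amounts to a small factory; the sketch below mirrors that shape, with a simplified `createFunction` standing in for the `HiveFunctionFactory.createFunction` helper used in the real code:

    import org.apache.hadoop.hive.ql.exec.UDAF
    import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFBridge}

    // Simplified stand-in for HiveFunctionFactory.createFunction in the real code.
    def createFunction[T](className: String): T =
      Class.forName(className).newInstance().asInstanceOf[T]

    // Old-style UDAFs don't implement AbstractGenericUDAFResolver themselves,
    // so they are wrapped in a GenericUDAFBridge; generic UDAFs are used as-is.
    def resolverFor(className: String, isUDAFBridgeRequired: Boolean): AbstractGenericUDAFResolver =
      if (isUDAFBridgeRequired) {
        new GenericUDAFBridge(createFunction[UDAF](className))
      } else {
        createFunction[AbstractGenericUDAFResolver](className)
      }
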
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
new file mode 100644
index 0000000000000..6ccbc22a4acfb
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.IOException
+import java.text.NumberFormat
+import java.util.Date
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities}
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred._
+
+import org.apache.spark.sql.Row
+import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
+
+/**
+ * Internal helper class that saves an RDD using a Hive OutputFormat.
+ * It is based on [[SparkHadoopWriter]].
+ */
+private[hive] class SparkHiveWriterContainer(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
+
+ private val now = new Date()
+ protected val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: FileSinkOperator.RecordWriter = null
+ @transient protected lazy val committer = conf.value.getOutputCommitter
+ @transient protected lazy val jobContext = newJobContext(conf.value, jID.value)
+ @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value)
+ @transient private lazy val outputFormat =
+ conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
+
+ def driverSideSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+ committer.setupJob(jobContext)
+ }
+
+ def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) {
+ setIDs(jobId, splitId, attemptId)
+ setConfParams()
+ committer.setupTask(taskContext)
+ initWriters()
+ }
+
+ protected def getOutputName: String = {
+ val numberFormat = NumberFormat.getInstance()
+ numberFormat.setMinimumIntegerDigits(5)
+ numberFormat.setGroupingUsed(false)
+ val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat)
+ "part-" + numberFormat.format(splitID) + extension
+ }
+
+ def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer
+
+ def close() {
+    // The boolean value passed into close() does not seem to matter.
+ writer.close(false)
+ commit()
+ }
+
+ def commitJob() {
+ committer.commitJob(jobContext)
+ }
+
+ protected def initWriters() {
+    // NOTE: this method is executed on the executor side.
+ // For Hive tables without partitions or with only static partitions, only 1 writer is needed.
+ writer = HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ fileSinkConf,
+ FileOutputFormat.getTaskOutputPath(conf.value, getOutputName),
+ Reporter.NULL)
+ }
+
+ protected def commit() {
+ if (committer.needsTaskCommit(taskContext)) {
+ try {
+ committer.commitTask(taskContext)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException =>
+ logError("Error committing the output of task: " + taID.value, e)
+ committer.abortTask(taskContext)
+ throw e
+ }
+ } else {
+ logInfo("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ private def setIDs(jobId: Int, splitId: Int, attemptId: Int) {
+ jobID = jobId
+ splitID = splitId
+ attemptID = attemptId
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+private[hive] object SparkHiveWriterContainer {
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ }
+}
+
+private[spark] object SparkHiveDynamicPartitionWriterContainer {
+ val SUCCESSFUL_JOB_OUTPUT_DIR_MARKER = "mapreduce.fileoutputcommitter.marksuccessfuljobs"
+}
+
+private[spark] class SparkHiveDynamicPartitionWriterContainer(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc,
+ dynamicPartColNames: Array[String])
+ extends SparkHiveWriterContainer(jobConf, fileSinkConf) {
+
+ import SparkHiveDynamicPartitionWriterContainer._
+
+ private val defaultPartName = jobConf.get(
+ ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal)
+
+ @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _
+
+ override protected def initWriters(): Unit = {
+ // NOTE: This method is executed at the executor side.
+ // Actual writers are created for each dynamic partition on the fly.
+ writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter]
+ }
+
+ override def close(): Unit = {
+ writers.values.foreach(_.close(false))
+ commit()
+ }
+
+ override def commitJob(): Unit = {
+    // This is a hack to avoid writing a _SUCCESS marker file. In lower versions of Hadoop (e.g. 1.0.4),
+    // the semantics of FileSystem.globStatus() differ from those of higher versions (e.g. 2.4.1) and will
+    // include the _SUCCESS file when globbing for dynamic partition data files.
+    //
+    // A better solution is to add a step similar to what Hive's FileSinkOperator.jobCloseOp does:
+    // calling something like Utilities.mvFileToFinalPath to clean up the output directory and then
+    // loading it with loadDynamicPartitions/loadPartition/loadTable.
+ val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true)
+ jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false)
+ super.commitJob()
+ jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker)
+ }
+
+ override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = {
+ val dynamicPartPath = dynamicPartColNames
+ .zip(row.takeRight(dynamicPartColNames.length))
+ .map { case (col, rawVal) =>
+ val string = if (rawVal == null) null else String.valueOf(rawVal)
+ s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
+ }
+ .mkString
+
+ def newWriter = {
+ val newFileSinkDesc = new FileSinkDesc(
+ fileSinkConf.getDirName + dynamicPartPath,
+ fileSinkConf.getTableInfo,
+ fileSinkConf.getCompressed)
+ newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec)
+ newFileSinkDesc.setCompressType(fileSinkConf.getCompressType)
+
+ val path = {
+ val outputPath = FileOutputFormat.getOutputPath(conf.value)
+ assert(outputPath != null, "Undefined job output-path")
+ val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/"))
+ new Path(workPath, getOutputName)
+ }
+
+ HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ newFileSinkDesc,
+ path,
+ Reporter.NULL)
+ }
+
+ writers.getOrElseUpdate(dynamicPartPath, newWriter)
+ }
+}
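
`getLocalFileWriter` derives a per-row output directory from the dynamic partition values and keeps one record writer per distinct path. The path construction itself is plain string work and can be sketched in isolation; the `__HIVE_DEFAULT_PARTITION__` default below is an assumption standing in for whatever `ConfVars.DEFAULTPARTITIONNAME` resolves to:

    // Build "/col=value" segments for a row's dynamic partition columns,
    // substituting the default partition name for null or empty values.
    def dynamicPartPath(
        colNames: Seq[String],
        colValues: Seq[Any],
        defaultPartName: String = "__HIVE_DEFAULT_PARTITION__"): String =
      colNames.zip(colValues).map { case (col, rawVal) =>
        val string = if (rawVal == null) null else String.valueOf(rawVal)
        s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
      }.mkString

    println(dynamicPartPath(Seq("year", "month"), Seq(2014, null)))
    // /year=2014/month=__HIVE_DEFAULT_PARTITION__
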
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
new file mode 100644
index 0000000000000..6c4f378bc5471
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFIntegerToString extends UDF {
+ public String evaluate(Integer i) {
+ return i.toString();
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
new file mode 100644
index 0000000000000..d2d39a8c4dc28
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+
+public class UDFListListInt extends UDF {
+ /**
+ *
+ * @param obj
+   *   SQL schema: array<array<int>>
+   *   Java Type: List<List<Integer>>
+ * @return
+ */
+ public long evaluate(Object obj) {
+ if (obj == null) {
+ return 0l;
+ }
+ List listList = (List) obj;
+ long retVal = 0;
+ for (List aList : listList) {
+ @SuppressWarnings("unchecked")
+ List
+
+
+
+ hadoop-2.2
+
+ 1.9
+
+
+
+ org.mortbay.jetty
+ jetty
+ 6.1.26
+
+
+ org.mortbay.jetty
+ servlet-api
+
+
+ test
+
+
+ com.sun.jersey
+ jersey-core
+ ${jersey.version}
+ test
+
+
+ com.sun.jersey
+ jersey-json
+ ${jersey.version}
+ test
+
+
+ stax
+ stax-api
+
+
+
+
+ com.sun.jersey
+ jersey-server
+ ${jersey.version}
+ test
+
+
+
+
+
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 833be12982e71..0b5a92d87d722 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -47,6 +47,7 @@ class ExecutorRunnable(
hostname: String,
executorMemory: Int,
executorCores: Int,
+ appId: String,
securityMgr: SecurityManager)
extends Runnable with ExecutorRunnableUtil with Logging {
@@ -80,7 +81,7 @@ class ExecutorRunnable(
ctx.setTokens(ByteBuffer.wrap(dob.getData()))
val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores,
- localResources)
+ appId, localResources)
logInfo(s"Setting up executor with environment: $env")
logInfo("Setting up executor with commands: " + commands)
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index e44a8db41b97e..2bbf5d7db8668 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -41,7 +41,7 @@ private[yarn] class YarnAllocationHandler(
args: ApplicationMasterArguments,
preferredNodes: collection.Map[String, collection.Set[SplitInfo]],
securityMgr: SecurityManager)
- extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) {
+ extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) {
override protected def releaseContainer(container: Container) = {
amClient.releaseAssignedContainer(container.getId())
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
index 54bc6b14c44ce..8d4b96ed79933 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala
@@ -17,8 +17,13 @@
package org.apache.spark.deploy.yarn
+import java.util.{List => JList}
+
import scala.collection.{Map, Set}
+import scala.collection.JavaConversions._
+import scala.util._
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
@@ -40,6 +45,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC
private var amClient: AMRMClient[ContainerRequest] = _
private var uiHistoryAddress: String = _
+ private var registered: Boolean = false
override def register(
conf: YarnConfiguration,
@@ -54,13 +60,19 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC
this.uiHistoryAddress = uiHistoryAddress
logInfo("Registering the ApplicationMaster")
- amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
+ synchronized {
+ amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
+ registered = true
+ }
new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args,
preferredNodeLocations, securityMgr)
}
- override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") =
- amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
+ override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized {
+ if (registered) {
+ amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
+ }
+ }
override def getAttemptId() = {
val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
@@ -69,7 +81,28 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC
appAttemptId
}
- override def getProxyHostAndPort(conf: YarnConfiguration) = WebAppUtils.getProxyHostAndPort(conf)
+ override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = {
+ // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2,
+ // so not all stable releases have it.
+ val prefix = Try(classOf[WebAppUtils].getMethod("getHttpSchemePrefix", classOf[Configuration])
+ .invoke(null, conf).asInstanceOf[String]).getOrElse("http://")
+
+ // If running a new enough Yarn, use the HA-aware API for retrieving the RM addresses.
+ try {
+ val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter",
+ classOf[Configuration])
+ val proxies = method.invoke(null, conf).asInstanceOf[JList[String]]
+ val hosts = proxies.map { proxy => proxy.split(":")(0) }
+ val uriBases = proxies.map { proxy => prefix + proxy + proxyBase }
+ Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(","))
+ } catch {
+ case e: NoSuchMethodException =>
+ val proxy = WebAppUtils.getProxyHostAndPort(conf)
+ val parts = proxy.split(":")
+ val uriBase = prefix + proxy + proxyBase
+ Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase)
+ }
+ }
override def getMaxRegAttempts(conf: YarnConfiguration) =
conf.getInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
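
`getAmIpFilterParams` probes for YARN methods that only exist in newer releases by looking them up reflectively and falling back when they are missing. The pattern generalizes beyond YARN; here is a stripped-down sketch with a hypothetical `Service` exposing an `oldApi`/`newApi` pair:

    import scala.util.Try

    // Hypothetical target: oldApi is always present, newApi may not exist at runtime.
    class Service {
      def oldApi(): String = "old result"
      def newApi(): String = "new result"
    }

    val service = new Service

    // Prefer the newer method when the running version has it, otherwise fall back.
    val result = Try {
      classOf[Service].getMethod("newApi").invoke(service).asInstanceOf[String]
    }.getOrElse(service.oldApi())

    println(result)  // "new result" here; "old result" if newApi were absent
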
diff --git a/yarn/stable/src/test/resources/log4j.properties b/yarn/stable/src/test/resources/log4j.properties
index 26b73a1b39744..9dd05f17f012b 100644
--- a/yarn/stable/src/test/resources/log4j.properties
+++ b/yarn/stable/src/test/resources/log4j.properties
@@ -21,7 +21,7 @@ log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
-log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 4b6635679f053..a826b2a78a8f5 100644
--- a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.yarn
import java.io.File
+import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
@@ -32,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.Utils
-class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
// log4j configuration for the Yarn containers, so that their output is collected
// by Yarn instead of trying to overwrite unit-tests.log.
@@ -66,7 +67,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers {
yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
yarnCluster.init(new YarnConfiguration())
yarnCluster.start()
- yarnCluster.getConfig().foreach { e =>
+
+ // There's a race in MiniYARNCluster in which start() may return before the RM has updated
+ // its address in the configuration. You can see this in the logs by noticing that when
+ // MiniYARNCluster prints the address, it still has port "0" assigned, although later the
+ // test works sometimes:
+ //
+ // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0
+ //
+ // That log message prints the contents of the RM_ADDRESS config variable. If you check it
+ // later on, it looks something like this:
+ //
+ // INFO YarnClusterSuite: RM address in configuration is blah:42631
+ //
+ // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't
+ // done so in a timely manner (defined to be 10 seconds).
+ val config = yarnCluster.getConfig()
+ val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10)
+ while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") {
+ if (System.currentTimeMillis() > deadline) {
+ throw new IllegalStateException("Timed out waiting for RM to come up.")
+ }
+ logDebug("RM address still not set in configuration, waiting...")
+ TimeUnit.MILLISECONDS.sleep(100)
+ }
+
+ logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
+ config.foreach { e =>
sys.props += ("spark.hadoop." + e.getKey() -> e.getValue())
}
@@ -86,13 +113,13 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers {
super.afterAll()
}
- ignore("run Spark in yarn-client mode") {
+ test("run Spark in yarn-client mode") {
var result = File.createTempFile("result", null, tempDir)
YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath()))
checkResult(result)
}
- ignore("run Spark in yarn-cluster mode") {
+ test("run Spark in yarn-cluster mode") {
val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
var result = File.createTempFile("result", null, tempDir)