diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000..2b65f6fe3cc80 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/.rat-excludes b/.rat-excludes index ae9745673c87d..20e3372464386 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,5 +1,6 @@ target .gitignore +.gitattributes .project .classpath .mima-excludes diff --git a/LICENSE b/LICENSE index a7eee041129cb..f1732fb47afc0 100644 --- a/LICENSE +++ b/LICENSE @@ -712,18 +712,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -======================================================================== -For colt: -======================================================================== - -Copyright (c) 1999 CERN - European Organization for Nuclear Research. -Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose is hereby granted without fee, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. CERN makes no representations about the suitability of this software for any purpose. It is provided "as is" without expressed or implied warranty. - -Packages hep.aida.* - -Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, Andreas Pfeiffer, and others. Check the FreeHEP home page for more info. Permission to use and/or redistribute this work is granted under the terms of the LGPL License, with the exception that any usage related to military applications is expressly forbidden. The software and documentation made available under the terms of this license are provided with no warranty. - - ======================================================================== For SnapTree: ======================================================================== diff --git a/README.md b/README.md index dbf53dcd76b2d..9916ac7b1ae8e 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. See also ["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) diff --git a/assembly/pom.xml b/assembly/pom.xml index bfef95b8deb95..31a01e4d8e1de 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -197,12 +197,6 @@ spark-hive_${scala.binary.version} ${project.version} - - - - - hive-0.12.0 - org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 3cd0579aea8d3..a4c099fb45b14 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -1,117 +1,117 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. 
See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. -rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -rem script and the ExecutorRunner in standalone cluster mode. - -rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting -rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we -rem need to set it here because we use !datanucleus_jars! below. -if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion -setlocal enabledelayedexpansion -:skip_delayed_expansion - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% - -if not "x%SPARK_CONF_DIR%"=="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% -) else ( - set CLASSPATH=%CLASSPATH%;%FWDIR%conf -) - -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) else ( - for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) - -set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% - -rem When Hive support is needed, Datanucleus jars must be included on the classpath. -rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -rem built with Hive, so look for them there. 
-if exist "%FWDIR%RELEASE" ( - set datanucleus_dir=%FWDIR%lib -) else ( - set datanucleus_dir=%FWDIR%lib_managed\jars -) -set "datanucleus_jars=" -for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( - set datanucleus_jars=!datanucleus_jars!;%%d -) -set CLASSPATH=%CLASSPATH%;%datanucleus_jars% - -set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes - -set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes - -if "x%SPARK_TESTING%"=="x1" ( - rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH - rem so that local compilation takes precedence over assembled jar - set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% -) - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - -rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "%DONT_PRINT_CLASSPATH%"=="1" goto exit - -echo %CLASSPATH% - -:exit +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. 
You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" +rem script and the ExecutorRunner in standalone cluster mode. + +rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting +rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we +rem need to set it here because we use !datanucleus_jars! below. +if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion +setlocal enabledelayedexpansion +:skip_delayed_expansion + +set SCALA_VERSION=2.10 + +rem Figure out where the Spark framework is installed +set FWDIR=%~dp0..\ + +rem Load environment variables from conf\spark-env.cmd, if it exists +if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" + +rem Build up classpath +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% + +if not "x%SPARK_CONF_DIR%"=="x" ( + set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% +) else ( + set CLASSPATH=%CLASSPATH%;%FWDIR%conf +) + +if exist "%FWDIR%RELEASE" ( + for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) else ( + for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) + +set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% + +rem When Hive support is needed, Datanucleus jars must be included on the classpath. +rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. +rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is +rem built with Hive, so look for them there. 
+if exist "%FWDIR%RELEASE" ( + set datanucleus_dir=%FWDIR%lib +) else ( + set datanucleus_dir=%FWDIR%lib_managed\jars +) +set "datanucleus_jars=" +for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( + set datanucleus_jars=!datanucleus_jars!;%%d +) +set CLASSPATH=%CLASSPATH%;%datanucleus_jars% + +set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes + +set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes + +if "x%SPARK_TESTING%"=="x1" ( + rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH + rem so that local compilation takes precedence over assembled jar + set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% +) + +rem Add hadoop conf dir - else FileSystem.*, etc fail +rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +rem the configurtion files. +if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir + set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% +:no_hadoop_conf_dir + +if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir + set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% +:no_yarn_conf_dir + +rem A bit of a hack to allow calling this script within run2.cmd without seeing output +if "%DONT_PRINT_CLASSPATH%"=="1" goto exit + +echo %CLASSPATH% + +:exit diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a0e66abcc26c9..59415e9bdec2c 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -59,7 +59,12 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do ( ) if [%PYTHON_FILE%] == [] ( - %PYSPARK_PYTHON% + set PYSPARK_SHELL=1 + if [%IPYTHON%] == [1] ( + ipython %IPYTHON_OPTS% + ) else ( + %PYSPARK_PYTHON% + ) ) else ( echo. echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0. 
diff --git a/bin/spark-class b/bin/spark-class index 91d858bc063d0..925367b0dd187 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -81,7 +81,11 @@ case "$1" in OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then - OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + if [[ $OSTYPE == darwin* ]]; then + export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH" + else + export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH" + fi fi if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" diff --git a/core/pom.xml b/core/pom.xml index 7b68dbaea4789..41296e0eca330 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -44,6 +44,16 @@ + + org.apache.spark + spark-network-common_2.10 + ${project.version} + + + org.apache.spark + spark-network-shuffle_2.10 + ${project.version} + net.java.dev.jets3t jets3t @@ -85,8 +95,6 @@ org.apache.commons commons-math3 - 3.3 - test com.google.code.findbugs @@ -162,10 +170,6 @@ json4s-jackson_${scala.binary.version} 3.2.10 - - colt - colt - org.apache.mesos mesos @@ -247,6 +251,11 @@ + + org.seleniumhq.selenium + selenium-java + test + org.scalatest scalatest_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java new file mode 100644 index 0000000000000..6e161313702bb --- /dev/null +++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +public enum JobExecutionStatus { + RUNNING, + SUCCEEDED, + FAILED, + UNKNOWN +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/java/org/apache/spark/SparkJobInfo.java similarity index 67% rename from core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala rename to core/src/main/java/org/apache/spark/SparkJobInfo.java index 5b6d086630834..4e3c983b1170a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java @@ -15,18 +15,16 @@ * limitations under the License. */ -package org.apache.spark.storage - -import java.nio.ByteBuffer - +package org.apache.spark; /** - * An interface for providing data for blocks. - * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. + * Exposes information about Spark Jobs. * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. + * This interface is not designed to be implemented outside of Spark. 
We may add additional methods + * which may break binary compatibility with outside implementations. */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] +public interface SparkJobInfo { + int jobId(); + int[] stageIds(); + JobExecutionStatus status(); } diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java new file mode 100644 index 0000000000000..04e2247210ecc --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +/** + * Exposes information about Spark Stages. + * + * This interface is not designed to be implemented outside of Spark. We may add additional methods + * which may break binary compatibility with outside implementations. + */ +public interface SparkStageInfo { + int stageId(); + int currentAttemptId(); + String name(); + int numTasks(); + int numActiveTasks(); + int numCompletedTasks(); + int numFailedTasks(); +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 2d998d4c7a5d9..0d6973203eba1 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -71,7 +71,6 @@ static void unset() { /** * Add a (Java friendly) listener to be executed on task completion. * This will be called in all situation - success, failure, or cancellation. - *
* An example use is for HadoopRDD to register a callback to close the input stream. */ public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); @@ -79,7 +78,6 @@ static void unset() { /** * Add a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. - *
* An example use is for HadoopRDD to register a callback to close the input stream. */ public abstract TaskContext addTaskCompletionListener(final Function1 f); diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java index abd9bcc07ac61..99bf240a17225 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java @@ -22,7 +22,8 @@ import scala.Tuple2; /** - * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs. + * A function that returns key-value pairs (Tuple2<K, V>), and can be used to + * construct PairRDDs. */ public interface PairFunction extends Serializable { public Tuple2 call(T t) throws Exception; diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java similarity index 92% rename from core/src/main/java/org/apache/spark/util/collection/Sorter.java rename to core/src/main/java/org/apache/spark/util/collection/TimSort.java index 64ad18c0e463a..409e1a41c5d49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -20,18 +20,25 @@ import java.util.Comparator; /** - * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort." + * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort." * See the method comment on sort() for more details. * * This has been kept in Java with the original style in order to match very closely with the - * Anroid source code, and thus be easy to verify correctness. + * Android source code, and thus be easy to verify correctness. The class is package private. We put + * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to + * package org.apache.spark. * * The purpose of the port is to generalize the interface to the sort to accept input data formats * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap * uses this to sort an Array with alternating elements of the form [key, value, key, value]. * This generalization comes with minimal overhead -- see SortDataFormat for more information. + * + * We allow key reuse to prevent creating many key objects -- see SortDataFormat. + * + * @see org.apache.spark.util.collection.SortDataFormat + * @see org.apache.spark.util.collection.Sorter */ -class Sorter { +class TimSort { /** * This is the minimum sized sequence that will be merged. 
Shorter @@ -54,7 +61,7 @@ class Sorter { private final SortDataFormat s; - public Sorter(SortDataFormat sortDataFormat) { + public TimSort(SortDataFormat sortDataFormat) { this.s = sortDataFormat; } @@ -91,7 +98,7 @@ public Sorter(SortDataFormat sortDataFormat) { * * @author Josh Bloch */ - void sort(Buffer a, int lo, int hi, Comparator c) { + public void sort(Buffer a, int lo, int hi, Comparator c) { assert c != null; int nRemaining = hi - lo; @@ -162,10 +169,13 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator>> 1; - if (c.compare(pivot, s.getKey(a, mid)) < 0) + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; else left = mid + 1; @@ -235,13 +245,16 @@ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator= 0) + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; } @@ -468,11 +481,13 @@ private void mergeAt(int i) { } stackSize--; + K key0 = s.newKey(); + /* * Find where the first element of run2 goes in run1. Prior elements * in run1 can be ignored (because they're already in place). */ - int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c); + int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c); assert k >= 0; base1 += k; len1 -= k; @@ -483,7 +498,7 @@ private void mergeAt(int i) { * Find where the last element of run1 goes in run2. Subsequent elements * in run2 can be ignored (because they're already in place). */ - len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c); + len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; if (len2 == 0) return; @@ -517,10 +532,12 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< assert len > 0 && hint >= 0 && hint < len; int lastOfs = 0; int ofs = 1; - if (c.compare(key, s.getKey(a, base + hint)) > 0) { + K key0 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) { // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -535,7 +552,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< } else { // key <= a[base + hint] // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] final int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -560,7 +577,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator< while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m)) > 0) + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) lastOfs = m + 1; // a[base + m] < key else ofs = m; // key <= a[base + m] @@ -587,10 +604,12 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator int ofs = 1; int lastOfs = 0; - if (c.compare(key, s.getKey(a, base + hint)) < 0) { + K key1 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) { // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + 
hint - ofs)) < 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -606,7 +625,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator } else { // a[b + hint] <= key // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) { + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) // int overflow @@ -630,7 +649,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); - if (c.compare(key, s.getKey(a, base + m)) < 0) + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) ofs = m; // key < a[b + m] else lastOfs = m + 1; // a[b + m] <= key @@ -679,6 +698,9 @@ private void mergeLo(int base1, int len1, int base2, int len2) { return; } + K key0 = s.newKey(); + K key1 = s.newKey(); + Comparator c = this.c; // Use local variable for performance int minGallop = this.minGallop; // " " " " " outer: @@ -692,7 +714,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { */ do { assert len1 > 1 && len2 > 0; - if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) { + if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; @@ -714,7 +736,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { */ do { assert len1 > 1 && len2 > 0; - count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c); + count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c); if (count1 != 0) { s.copyRange(tmp, cursor1, a, dest, count1); dest += count1; @@ -727,7 +749,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) { if (--len2 == 0) break outer; - count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c); + count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { s.copyRange(a, cursor2, a, dest, count2); dest += count2; @@ -784,6 +806,9 @@ private void mergeHi(int base1, int len1, int base2, int len2) { int cursor2 = len2 - 1; // Indexes into tmp array int dest = base2 + len2 - 1; // Indexes into a + K key0 = s.newKey(); + K key1 = s.newKey(); + // Move last element of first run and deal with degenerate cases s.copyElement(a, cursor1--, a, dest--); if (--len1 == 0) { @@ -811,7 +836,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { */ do { assert len1 > 0 && len2 > 1; - if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) { + if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) { s.copyElement(a, cursor1--, a, dest--); count1++; count2 = 0; @@ -833,7 +858,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { */ do { assert len1 > 0 && len2 > 1; - count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c); + count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c); if (count1 != 0) { dest -= count1; cursor1 -= count1; @@ -846,7 +871,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) { if (--len2 == 1) break outer; - count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c); + count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, 
len2 - 1, c); if (count2 != 0) { dest -= count2; cursor2 -= count2; diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js new file mode 100644 index 0000000000000..c5936b5038ac9 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Register functions to show/hide columns based on checkboxes. These need + * to be registered after the page loads. */ +$(function() { + $("span.expand-additional-metrics").click(function(){ + // Expand the list of additional metrics. + var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); + $(additionalMetricsDiv).toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + // If clicking caused the metrics to expand, automatically check all options for additional + // metrics (don't trigger a click when collapsing metrics, because it leads to weird + // toggling behavior). + if (!$(additionalMetricsDiv).hasClass('collapsed')) { + $(this).parent().find('input:checkbox:not(:checked)').trigger('click'); + } + }); + + $("input:checkbox:not(:checked)").each(function() { + var column = "table ." + $(this).attr("name"); + $(column).hide(); + }); + + $("input:checkbox").click(function() { + var column = "table ." + $(this).attr("name"); + $(column).toggle(); + stripeTables(); + }); + + // Trigger a click on the checkbox if a user clicks the label next to it. + $("span.additional-metric-title").click(function() { + $(this).parent().find('input:checkbox').trigger('click'); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js new file mode 100644 index 0000000000000..32187ba6e8df0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Adds background colors to stripe table rows. This is necessary (instead of using css or the + * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */ +function stripeTables() { + $("table.table-striped-custom").each(function() { + $(this).find("tr:not(:hidden)").each(function (index) { + if (index % 2 == 1) { + $(this).css("background-color", "#f9f9f9"); + } else { + $(this).css("background-color", "#ffffff"); + } + }); + }); +} + +/* Stripe all tables after pages finish loading. */ +$(function() { + stripeTables(); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 152bde5f6994f..a2220e761ac98 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -120,7 +120,37 @@ pre { border: none; } +span.expand-additional-metrics { + cursor: pointer; +} + +span.additional-metric-title { + cursor: pointer; +} + +.additional-metrics.collapsed { + display: none; +} + .tooltip { font-weight: normal; } +.arrow-open { + width: 0; + height: 0; + border-left: 5px solid transparent; + border-right: 5px solid transparent; + border-top: 5px solid black; + float: left; + margin-top: 6px; +} + +.arrow-closed { + width: 0; + height: 0; + border-top: 5px solid transparent; + border-bottom: 5px solid transparent; + border-left: 5px solid black; + display: inline-block; +} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala new file mode 100644 index 0000000000000..c11f1db0064fd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +/** + * An agent that dynamically allocates and removes executors based on the workload. + * + * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If + * the scheduler queue is not drained in N seconds, then new executors are added. If the queue + * persists for another M seconds, then more executors are added and so on. 
The number added + * in each round increases exponentially from the previous round until an upper bound on the + * number of executors has been reached. + * + * The rationale for the exponential increase is twofold: (1) Executors should be added slowly + * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, + * we may add more executors than we need just to remove them later. (2) Executors should be added + * quickly over time in case the maximum number of executors is very high. Otherwise, it will take + * a long time to ramp up under heavy workloads. + * + * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not + * been scheduled to run any tasks, then it is removed. + * + * There is no retry logic in either case because we make the assumption that the cluster manager + * will eventually fulfill all requests it receives asynchronously. + * + * The relevant Spark properties include the following: + * + * spark.dynamicAllocation.enabled - Whether this feature is enabled + * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors + * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors + * + * spark.dynamicAllocation.schedulerBacklogTimeout (M) - + * If there are backlogged tasks for this duration, add new executors + * + * spark.dynamicAllocation.sustainedSchedulerBacklogTimeout (N) - + * If the backlog is sustained for this duration, add more executors + * This is used only after the initial backlog timeout is exceeded + * + * spark.dynamicAllocation.executorIdleTimeout (K) - + * If an executor has been idle for this duration, remove it + */ +private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging { + import ExecutorAllocationManager._ + + private val conf = sc.conf + + // Lower and upper bounds on the number of executors. These are required. 
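Since the min/max bounds are required (see verifyBounds below) while the timeouts have defaults, a minimal sketch of an application-side configuration that enables this allocation manager might look like the following. Only the property names are taken from the scaladoc above; every value, the app name, and the YARN master string are illustrative assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative values only: the two bounds are mandatory, the timeouts are optional
// and fall back to the defaults defined in this file (60s, 60s, 600s).
val conf = new SparkConf()
  .setAppName("dynamic-allocation-sketch")                 // hypothetical app name
  .setMaster("yarn-client")                                // assumes a YARN cluster
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.dynamicAllocation.minExecutors", "2")
  .set("spark.dynamicAllocation.maxExecutors", "20")
  .set("spark.dynamicAllocation.schedulerBacklogTimeout", "60")            // seconds
  .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "60")   // seconds
  .set("spark.dynamicAllocation.executorIdleTimeout", "600")               // seconds

val sc = new SparkContext(conf)
// While tasks remain backlogged, executors are requested in exponentially growing
// rounds (1, 2, 4, ...) up to maxExecutors; executors idle longer than the idle
// timeout are released, but never below minExecutors.
```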
+ private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) + private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) + verifyBounds() + + // How long there must be backlogged tasks for before an addition is triggered + private val schedulerBacklogTimeout = conf.getLong( + "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + + // Same as above, but used only after `schedulerBacklogTimeout` is exceeded + private val sustainedSchedulerBacklogTimeout = conf.getLong( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + + // How long an executor must be idle for before it is removed + private val removeThresholdSeconds = conf.getLong( + "spark.dynamicAllocation.executorIdleTimeout", 600) + + // Number of executors to add in the next round + private var numExecutorsToAdd = 1 + + // Number of executors that have been requested but have not registered yet + private var numExecutorsPending = 0 + + // Executors that have been requested to be removed but have not been killed yet + private val executorsPendingToRemove = new mutable.HashSet[String] + + // All known executors + private val executorIds = new mutable.HashSet[String] + + // A timestamp of when an addition should be triggered, or NOT_SET if it is not set + // This is set when pending tasks are added but not scheduled yet + private var addTime: Long = NOT_SET + + // A timestamp for each executor of when the executor should be removed, indexed by the ID + // This is set when an executor is no longer running a task, or when it first registers + private val removeTimes = new mutable.HashMap[String, Long] + + // Polling loop interval (ms) + private val intervalMillis: Long = 100 + + // Whether we are testing this class. This should only be used internally. + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + // Clock used to schedule when executors should be added and removed + private var clock: Clock = new RealClock + + /** + * Verify that the lower and upper bounds on the number of executors are valid. + * If not, throw an appropriate exception. + */ + private def verifyBounds(): Unit = { + if (minNumExecutors < 0 || maxNumExecutors < 0) { + throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") + } + if (minNumExecutors == 0 || maxNumExecutors == 0) { + throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!") + } + if (minNumExecutors > maxNumExecutors) { + throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") + } + } + + /** + * Use a different clock for this allocation manager. This is mainly used for testing. + */ + def setClock(newClock: Clock): Unit = { + clock = newClock + } + + /** + * Register for scheduler callbacks to decide when to add and remove executors. + */ + def start(): Unit = { + val listener = new ExecutorAllocationListener(this) + sc.addSparkListener(listener) + startPolling() + } + + /** + * Start the main polling thread that keeps track of when to add and remove executors. 
+ */ + private def startPolling(): Unit = { + val t = new Thread { + override def run(): Unit = { + while (true) { + try { + schedule() + } catch { + case e: Exception => logError("Exception in dynamic executor allocation thread!", e) + } + Thread.sleep(intervalMillis) + } + } + } + t.setName("spark-dynamic-executor-allocation") + t.setDaemon(true) + t.start() + } + + /** + * If the add time has expired, request new executors and refresh the add time. + * If the remove time for an existing executor has expired, kill the executor. + * This is factored out into its own method for testing. + */ + private def schedule(): Unit = synchronized { + val now = clock.getTimeMillis + if (addTime != NOT_SET && now >= addTime) { + addExecutors() + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeout seconds)") + addTime += sustainedSchedulerBacklogTimeout * 1000 + } + + removeTimes.foreach { case (executorId, expireTime) => + if (now >= expireTime) { + removeExecutor(executorId) + removeTimes.remove(executorId) + } + } + } + + /** + * Request a number of executors from the cluster manager. + * If the cap on the number of executors is reached, give up and reset the + * number of executors to add next round instead of continuing to double it. + * Return the number actually requested. + */ + private def addExecutors(): Int = synchronized { + // Do not request more executors if we have already reached the upper bound + val numExistingExecutors = executorIds.size + numExecutorsPending + if (numExistingExecutors >= maxNumExecutors) { + logDebug(s"Not adding executors because there are already ${executorIds.size} " + + s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") + numExecutorsToAdd = 1 + return 0 + } + + // Request executors with respect to the upper bound + val actualNumExecutorsToAdd = + if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) { + numExecutorsToAdd + } else { + maxNumExecutors - numExistingExecutors + } + val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd + val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd) + if (addRequestAcknowledged) { + logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + + s"tasks are backlogged (new desired total will be $newTotalExecutors)") + numExecutorsToAdd = + if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 + numExecutorsPending += actualNumExecutorsToAdd + actualNumExecutorsToAdd + } else { + logWarning(s"Unable to reach the cluster manager " + + s"to request $actualNumExecutorsToAdd executors!") + 0 + } + } + + /** + * Request the cluster manager to remove the given executor. + * Return whether the request is received. 
+ */ + private def removeExecutor(executorId: String): Boolean = synchronized { + // Do not kill the executor if we are not aware of it (should never happen) + if (!executorIds.contains(executorId)) { + logWarning(s"Attempted to remove unknown executor $executorId!") + return false + } + + // Do not kill the executor again if it is already pending to be killed (should never happen) + if (executorsPendingToRemove.contains(executorId)) { + logWarning(s"Attempted to remove executor $executorId " + + s"when it is already pending to be removed!") + return false + } + + // Do not kill the executor if we have already reached the lower bound + val numExistingExecutors = executorIds.size - executorsPendingToRemove.size + if (numExistingExecutors - 1 < minNumExecutors) { + logInfo(s"Not removing idle executor $executorId because there are only " + + s"$numExistingExecutors executor(s) left (limit $minNumExecutors)") + return false + } + + // Send a request to the backend to kill this executor + val removeRequestAcknowledged = testing || sc.killExecutor(executorId) + if (removeRequestAcknowledged) { + logInfo(s"Removing executor $executorId because it has been idle for " + + s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})") + executorsPendingToRemove.add(executorId) + true + } else { + logWarning(s"Unable to reach the cluster manager to kill executor $executorId!") + false + } + } + + /** + * Callback invoked when the specified executor has been added. + */ + private def onExecutorAdded(executorId: String): Unit = synchronized { + if (!executorIds.contains(executorId)) { + executorIds.add(executorId) + executorIds.foreach(onExecutorIdle) + logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") + if (numExecutorsPending > 0) { + numExecutorsPending -= 1 + logDebug(s"Decremented number of pending executors ($numExecutorsPending left)") + } + } else { + logWarning(s"Duplicate executor $executorId has registered") + } + } + + /** + * Callback invoked when the specified executor has been removed. + */ + private def onExecutorRemoved(executorId: String): Unit = synchronized { + if (executorIds.contains(executorId)) { + executorIds.remove(executorId) + removeTimes.remove(executorId) + logInfo(s"Existing executor $executorId has been removed (new total is ${executorIds.size})") + if (executorsPendingToRemove.contains(executorId)) { + executorsPendingToRemove.remove(executorId) + logDebug(s"Executor $executorId is no longer pending to " + + s"be removed (${executorsPendingToRemove.size} left)") + } + } else { + logWarning(s"Unknown executor $executorId has been removed!") + } + } + + /** + * Callback invoked when the scheduler receives new pending tasks. + * This sets a time in the future that decides when executors should be added + * if it is not already set. + */ + private def onSchedulerBacklogged(): Unit = synchronized { + if (addTime == NOT_SET) { + logDebug(s"Starting timer to add executors because pending tasks " + + s"are building up (to expire in $schedulerBacklogTimeout seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + } + } + + /** + * Callback invoked when the scheduler queue is drained. + * This resets all variables used for adding executors. 
+ */ + private def onSchedulerQueueEmpty(): Unit = synchronized { + logDebug(s"Clearing timer to add executors because there are no more pending tasks") + addTime = NOT_SET + numExecutorsToAdd = 1 + } + + /** + * Callback invoked when the specified executor is no longer running any tasks. + * This sets a time in the future that decides when this executor should be removed if + * the executor is not already marked as idle. + */ + private def onExecutorIdle(executorId: String): Unit = synchronized { + if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + logDebug(s"Starting idle timer for $executorId because there are no more tasks " + + s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)") + removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000 + } + } + + /** + * Callback invoked when the specified executor is now running a task. + * This resets all variables used for removing this executor. + */ + private def onExecutorBusy(executorId: String): Unit = synchronized { + logDebug(s"Clearing idle timer for $executorId because it is now running a task") + removeTimes.remove(executorId) + } + + /** + * A listener that notifies the given allocation manager of when to add and remove executors. + * + * This class is intentionally conservative in its assumptions about the relative ordering + * and consistency of events returned by the listener. For simplicity, it does not account + * for speculated tasks. + */ + private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager) + extends SparkListener { + + private val stageIdToNumTasks = new mutable.HashMap[Int, Int] + private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] + private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + synchronized { + val stageId = stageSubmitted.stageInfo.stageId + val numTasks = stageSubmitted.stageInfo.numTasks + stageIdToNumTasks(stageId) = numTasks + allocationManager.onSchedulerBacklogged() + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + synchronized { + val stageId = stageCompleted.stageInfo.stageId + stageIdToNumTasks -= stageId + stageIdToTaskIndices -= stageId + + // If this is the last stage with pending tasks, mark the scheduler queue as empty + // This is needed in case the stage is aborted for any reason + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } + } + } + + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + val stageId = taskStart.stageId + val taskId = taskStart.taskInfo.taskId + val taskIndex = taskStart.taskInfo.index + val executorId = taskStart.taskInfo.executorId + + // If this is the last pending task, mark the scheduler queue as empty + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + val numTasksScheduled = stageIdToTaskIndices(stageId).size + val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) + if (numTasksScheduled == numTasksTotal) { + // No more pending tasks for this stage + stageIdToNumTasks -= stageId + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } + } + + // Mark the executor on which this task is scheduled as busy + executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId + 
allocationManager.onExecutorBusy(executorId) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + val executorId = taskEnd.taskInfo.executorId + val taskId = taskEnd.taskInfo.taskId + + // If the executor is no longer running scheduled any tasks, mark it as idle + if (executorIdToTaskIds.contains(executorId)) { + executorIdToTaskIds(executorId) -= taskId + if (executorIdToTaskIds(executorId).isEmpty) { + executorIdToTaskIds -= executorId + allocationManager.onExecutorIdle(executorId) + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + val executorId = blockManagerAdded.blockManagerId.executorId + if (executorId != SparkContext.DRIVER_IDENTIFIER) { + allocationManager.onExecutorAdded(executorId) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { + allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + } + } + +} + +private object ExecutorAllocationManager { + val NOT_SET = Long.MaxValue +} + +/** + * An abstract clock for measuring elapsed time. + */ +private trait Clock { + def getTimeMillis: Long +} + +/** + * A clock backed by a monotonically increasing time source. + * The time returned by this clock does not correspond to any notion of wall-clock time. + */ +private class RealClock extends Clock { + override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) +} + +/** + * A clock that allows the caller to customize the time. + * This is used mainly for testing. + */ +private class TestClock(startTimeMillis: Long) extends Clock { + private var time: Long = startTimeMillis + override def getTimeMillis: Long = time + def tick(ms: Long): Unit = { time += ms } +} diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index d5c8f9d76c476..e97a7375a267b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -210,7 +210,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { } catch { case e: Exception => p.failure(e) } finally { - thread = null + // This lock guarantees when calling `thread.interrupt()` in `cancel`, + // thread won't be set to null. + ComplexFutureAction.this.synchronized { + thread = null + } } } this diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4cb0bd4142435..7d96962c4acd7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } else { + logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } @@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr new ConcurrentHashMap[Int, Array[MapStatus]] } -private[spark] object MapOutputTracker { +private[spark] object MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. 
They will @@ -381,6 +382,7 @@ private[spark] object MapOutputTracker { statuses.map { status => if (status == null) { + logError("Missing an output location for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index dbbcc23305c50..ad0a9017afead 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -244,6 +244,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val executorClasspathKey = "spark.executor.extraClassPath" val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" + val driverLibraryPathKey = "spark.driver.extraLibraryPath" + + // Used by Yarn in 1.1 and before + sys.props.get("spark.driver.libraryPath").foreach { value => + val warning = + s""" + |spark.driver.libraryPath was detected (set to '$value'). + |This is deprecated in Spark 1.2+. + | + |Please instead use: $driverLibraryPathKey + """.stripMargin + logWarning(warning) + } // Validate spark.executor.extraJavaOptions settings.get(executorOptsKey).map { javaOpts => diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4565832334420..40444c237b738 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,12 +21,10 @@ import scala.language.implicitConversions import java.io._ import java.net.URI -import java.util.Arrays +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -42,7 +40,8 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.WholeTextFileInputFormat +import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -51,7 +50,8 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.ui.jobs.JobProgressListener +import org.apache.spark.util._ /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -61,7 +61,7 @@ import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, Metadat * this config overrides the default configs as well as system properties. 
*/ -class SparkContext(config: SparkConf) extends Logging { +class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It @@ -224,10 +224,15 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) - // Initialize the Spark UI, registering all associated listeners + + private[spark] val jobProgressListener = new JobProgressListener(conf) + listenerBus.addListener(jobProgressListener) + + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(new SparkUI(this)) + Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, + env.securityManager,appName)) } else { // For tests, do not enable the UI None @@ -289,7 +294,8 @@ class SparkContext(config: SparkConf) extends Logging { executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private[spark] var (schedulerBackend, taskScheduler) = + SparkContext.createTaskScheduler(this, master) private val heartbeatReceiver = env.actorSystem.actorOf( Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ @@ -324,6 +330,15 @@ class SparkContext(config: SparkConf) extends Logging { } else None } + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = + if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + Some(new ExecutorAllocationManager(this)) + } else { + None + } + executorAllocationManager.foreach(_.start()) + // At this point, all relevant SparkListeners have been registered, so begin releasing events listenerBus.start() @@ -346,6 +361,29 @@ class SparkContext(config: SparkConf) extends Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. 
+ */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { @@ -518,6 +556,69 @@ class SparkContext(config: SparkConf) extends Logging { minPartitions).setName(path) } + + /** + * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file + * (useful for binary data) + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. + */ + @Experimental + def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): + RDD[(String, PortableDataStream)] = { + val job = new NewHadoopJob(hadoopConfiguration) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updateConf = job.getConfiguration + new BinaryFileRDD( + this, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + updateConf, + minPartitions).setName(path) + } + + /** + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @param recordLength The length at which to split the records + * @return An RDD of data with values, represented as byte arrays + */ + @Experimental + def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) + : RDD[Array[Byte]] = { + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, + classOf[FixedLengthBinaryInputFormat], + classOf[LongWritable], + classOf[BytesWritable], + conf=conf) + val data = br.map{ case (k, v) => v.getBytes} + data + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), @@ -851,71 +952,48 @@ class SparkContext(config: SparkConf) extends Logging { listenerBus.addListener(listener) } - /** The version of Spark on which this application is running. */ - def version = SPARK_VERSION - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. 
- */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * :: DeveloperApi :: - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - @DeveloperApi - def getRDDStorageInfo: Array[RDDInfo] = { - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) - rddInfos.filter(_.isCached) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. - */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves + * Request an additional number of executors from the cluster manager. + * This is currently only supported in Yarn mode. Return whether the request is received. */ @DeveloperApi - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus + def requestExecutors(numAdditionalExecutors: Int): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestExecutors(numAdditionalExecutors) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } } /** * :: DeveloperApi :: - * Return pools for fair scheduler + * Request that the cluster manager kill the specified executors. + * This is currently only supported in Yarn mode. Return whether the request is received. */ @DeveloperApi - def getAllPools: Seq[Schedulable] = { - // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + def killExecutors(executorIds: Seq[String]): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(executorIds) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } } /** * :: DeveloperApi :: - * Return the pool associated with the given name, if one exists + * Request that cluster manager the kill the specified executor. + * This is currently only supported in Yarn mode. Return whether the request is received. */ @DeveloperApi - def getPoolForName(pool: String): Option[Schedulable] = { - Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) - } + def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } + /** The version of Spark on which this application is running. */ + def version = SPARK_VERSION /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to @@ -1341,6 +1419,8 @@ object SparkContext extends Logging { private[spark] val SPARK_UNKNOWN_USER = "" + private[spark] val DRIVER_IDENTIFIER = "" + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 @@ -1496,8 +1576,13 @@ object SparkContext extends Logging { res } - /** Creates a task scheduler based on a given master URL. Extracted for testing. 
*/ - private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { + /** + * Create a task scheduler based on a given master URL. + * Return a 2-tuple of the scheduler backend and the task scheduler. + */ + private def createTaskScheduler( + sc: SparkContext, + master: String): (SchedulerBackend, TaskScheduler) = { // Regular expression used for local[N] and local[*] master formats val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -1519,7 +1604,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, 1) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_N_REGEX(threads) => def localCpuCount = Runtime.getRuntime.availableProcessors() @@ -1528,7 +1613,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => def localCpuCount = Runtime.getRuntime.availableProcessors() @@ -1538,14 +1623,14 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) - scheduler + (backend, scheduler) case SPARK_REGEX(sparkUrl) => val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
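The hunks above add a small developer-facing API for elastic scaling: `SparkContext.requestExecutors` and `SparkContext.killExecutor(s)` delegate to `CoarseGrainedSchedulerBackend` and simply log a warning and return `false` on other backends, while the `ExecutorAllocationManager` wired up via `spark.dynamicAllocation.enabled` scales the executor count automatically based on workload. The following is a minimal usage sketch of the manual API only, assuming a coarse-grained (e.g. YARN) deployment; it is not part of the patch, and the app name and executor id are placeholders.

import org.apache.spark.{SparkConf, SparkContext}

object ExecutorScalingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("executor-scaling-sketch"))

    // Ask the cluster manager for two additional executors. Returns false
    // (after logging a warning) when the backend is not coarse-grained.
    val requested = sc.requestExecutors(2)

    // Later, ask the cluster manager to kill one executor by id. The id "1"
    // is a placeholder; real ids are assigned by the cluster manager.
    val killed = sc.killExecutor("1")

    println(s"requested=$requested, killed=$killed")
    sc.stop()
  }
}

Note that both calls only report whether the request was received by the cluster manager, not whether executors were actually added or removed.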
@@ -1565,7 +1650,7 @@ object SparkContext extends Logging { backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() } - scheduler + (backend, scheduler) case "yarn-standalone" | "yarn-cluster" => if (master == "yarn-standalone") { @@ -1594,7 +1679,7 @@ object SparkContext extends Logging { } } scheduler.initialize(backend) - scheduler + (backend, scheduler) case "yarn-client" => val scheduler = try { @@ -1621,7 +1706,7 @@ object SparkContext extends Logging { } scheduler.initialize(backend) - scheduler + (backend, scheduler) case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() @@ -1634,13 +1719,13 @@ object SparkContext extends Logging { new MesosSchedulerBackend(scheduler, sc, url) } scheduler.initialize(backend) - scheduler + (backend, scheduler) case SIMR_REGEX(simrUrl) => val scheduler = new TaskSchedulerImpl(sc) val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) scheduler.initialize(backend) - scheduler + (backend, scheduler) case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5c076e5f1c11d..e2f13accdfab5 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -155,7 +156,7 @@ object SparkEnv extends Logging { assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") val hostname = conf.get("spark.driver.host") val port = conf.get("spark.driver.port").toInt - create(conf, "", hostname, port, true, isLocal, listenerBus) + create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus) } /** @@ -272,7 +273,13 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + val blockTransferService = + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + case "netty" => + new NettyBlockTransferService(conf) + case "nio" => + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala index 65003b6ac6a0a..a954fcc0c31fa 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -17,7 +17,6 @@ package org.apache.spark -import java.io.IOException import javax.security.auth.callback.Callback import javax.security.auth.callback.CallbackHandler import javax.security.auth.callback.NameCallback @@ -31,6 +30,8 @@ import javax.security.sasl.SaslException import scala.collection.JavaConversions.mapAsJavaMap +import com.google.common.base.Charsets.UTF_8 + /** * Implements SASL Client logic for Spark */ @@ -111,10 +112,10 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg 
CallbackHandler { private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8")) + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) private val secretKey = securityMgr.getSecretKey() private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8")) + if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8)) /** * Implementation used to respond to SASL request from the server. diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala index f6b0a9132aca4..7c2afb364661f 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -28,6 +28,8 @@ import javax.security.sasl.Sasl import javax.security.sasl.SaslException import javax.security.sasl.SaslServer import scala.collection.JavaConversions.mapAsJavaMap + +import com.google.common.base.Charsets.UTF_8 import org.apache.commons.net.util.Base64 /** @@ -89,7 +91,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi extends CallbackHandler { private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8")) + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8)) override def handle(callbacks: Array[Callback]) { logDebug("In the sasl server callback handler") @@ -101,7 +103,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi case pc: PasswordCallback => { logDebug("handle: SASL server callback: setting userPassword") val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8")) + SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8)) pc.setPassword(password) } case rc: RealmCallback => { @@ -159,7 +161,7 @@ private[spark] object SparkSaslServer { * @return Base64-encoded string */ def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), "utf-8") + new String(Base64.encodeBase64(identifier), UTF_8) } /** @@ -168,7 +170,7 @@ private[spark] object SparkSaslServer { * @return password as a char array. */ def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), "utf-8").toCharArray() + new String(Base64.encodeBase64(password), UTF_8).toCharArray() } } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala new file mode 100644 index 0000000000000..1982499c5e1d3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.Map +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SchedulingMode, Schedulable} +import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo} + +/** + * Trait that implements Spark's status APIs. This trait is designed to be mixed into + * SparkContext; it allows the status API code to live in its own file. + */ +private[spark] trait SparkStatusAPI { this: SparkContext => + + /** + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.host + ":" + blockManagerId.port, mem) + } + } + + /** + * :: DeveloperApi :: + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + @DeveloperApi + def getRDDStorageInfo: Array[RDDInfo] = { + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) + } + + /** + * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. + */ + def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap + + /** + * :: DeveloperApi :: + * Return information about blocks stored in all of the slaves + */ + @DeveloperApi + def getExecutorStorageStatus: Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + + /** + * :: DeveloperApi :: + * Return pools for fair scheduler + */ + @DeveloperApi + def getAllPools: Seq[Schedulable] = { + // TODO(xiajunluan): We should take nested pools into account + taskScheduler.rootPool.schedulableQueue.toSeq + } + + /** + * :: DeveloperApi :: + * Return the pool associated with the given name, if one exists + */ + @DeveloperApi + def getPoolForName(pool: String): Option[Schedulable] = { + Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) + } + + /** + * Return current scheduling mode + */ + def getSchedulingMode: SchedulingMode.SchedulingMode = { + taskScheduler.schedulingMode + } + + + /** + * Return a list of all known jobs in a particular job group. The returned list may contain + * running, failed, and completed jobs, and may vary across invocations of this method. This + * method does not guarantee the order of the elements in its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = { + jobProgressListener.synchronized { + val jobData = jobProgressListener.jobIdToData.valuesIterator + jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray + } + } + + /** + * Returns job information, or `None` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): Option[SparkJobInfo] = { + jobProgressListener.synchronized { + jobProgressListener.jobIdToData.get(jobId).map { data => + new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) + } + } + } + + /** + * Returns stage information, or `None` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): Option[SparkStageInfo] = { + jobProgressListener.synchronized { + for ( + info <- jobProgressListener.stageIdToInfo.get(stageId); + data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) + ) yield { + new SparkStageInfoImpl( + stageId, + info.attemptId, + info.name, + info.numTasks, + data.numActiveTasks, + data.numCompleteTasks, + data.numFailedTasks) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala new file mode 100644 index 0000000000000..90b47c847fbca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +private class SparkJobInfoImpl ( + val jobId: Int, + val stageIds: Array[Int], + val status: JobExecutionStatus) + extends SparkJobInfo + +private class SparkStageInfoImpl( + val stageId: Int, + val currentAttemptId: Int, + val name: String, + val numTasks: Int, + val numActiveTasks: Int, + val numCompletedTasks: Int, + val numFailedTasks: Int) + extends SparkStageInfo diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8f0c5e78416c2..f45b463fb6f62 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -69,11 +69,13 @@ case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, - reduceId: Int) + reduceId: Int, + message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + + s"message=\n$message\n)" } } @@ -117,8 +119,8 @@ case object TaskKilled extends TaskFailedReason { * the task crashed the JVM. 
*/ @DeveloperApi -case object ExecutorLostFailure extends TaskFailedReason { - override def toErrorString: String = "ExecutorLostFailure (executor lost)" +case class ExecutorLostFailure(execId: String) extends TaskFailedReason { + override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index e72826dc25f41..34078142f5385 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,8 +23,8 @@ import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConversions._ +import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} -import com.google.common.io.Files import org.apache.spark.util.Utils @@ -64,12 +64,7 @@ private[spark] object TestUtils { jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) - val buffer = new Array[Byte](10240) - var nRead = 0 - while (nRead <= 0) { - nRead = in.read(buffer, 0, buffer.length) - jarStream.write(buffer, 0, nRead) - } + ByteStreams.copy(in, jarStream) in.close() } jarStream.close() diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index a6123bd108c11..8e8f7f6c4fda2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -114,7 +114,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.subtract(other)) @@ -233,11 +233,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja * to the left except for the last which is closed * e.g. for the array * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched - * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index c38b96528d037..e37f3acaf6e30 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -392,7 +392,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. 
*/ def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = fromRDD(rdd.subtract(other)) @@ -413,7 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return an RDD with the pairs from `this` whose keys are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 791d853a015a1..e3aeba7e6c39d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,6 +21,11 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} +import java.io.DataInputStream + +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} + import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -32,7 +37,8 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} @@ -132,6 +138,25 @@ class JavaSparkContext(val sc: SparkContext) /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions + + /** + * Return a list of all known jobs in a particular job group. The returned list may contain + * running, failed, and completed jobs, and may vary across invocations of this method. This + * method does not guarantee the order of the elements in its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup) + + /** + * Returns job information, or `null` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull + + /** + * Returns stage information, or `null` if the stage info could not be found or was + * garbage collected. + */ + def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull + /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag @@ -183,6 +208,8 @@ class JavaSparkContext(val sc: SparkContext) def textFile(path: String, minPartitions: Int): JavaRDD[String] = sc.textFile(path, minPartitions) + + /** * Read a directory of text files from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI. Each file is read as a single record and returned in a @@ -196,7 +223,10 @@ class JavaSparkContext(val sc: SparkContext) * hdfs://a-hdfs-path/part-nnnnn * }}} * - * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`, + * Do + * {{{ + * JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path") + * }}} * *

then `rdd` contains * {{{ @@ -223,6 +253,78 @@ class JavaSparkContext(val sc: SparkContext) def wholeTextFiles(path: String): JavaPairRDD[String, String] = new JavaPairRDD(sc.wholeTextFiles(path)) + /** + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + */ + def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, minPartitions)) + + /** + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + */ + def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) + + /** + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @return An RDD of data with values, represented as byte arrays + */ + def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { + new JavaRDD(sc.binaryRecords(path, recordLength)) + } + /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349eac..5ba66178e2b78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. 
*/ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 163dca6cade5a..e94ccdcd47bb7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,15 +19,13 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials -import net.razorvine.pickle.{Pickler, Unpickler} +import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -134,7 +132,7 @@ private[spark] class PythonRDD( val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, "utf-8"), + throw new PythonException(new String(obj, UTF_8), writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still @@ -318,7 +316,6 @@ private object SpecialLengths { } private[spark] object PythonRDD extends Logging { - val UTF8 = Charset.forName("UTF-8") // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() @@ -443,7 +440,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -469,7 +466,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -495,7 +492,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -538,7 +535,7 
@@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -564,7 +561,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -586,7 +583,7 @@ private[spark] object PythonRDD extends Logging { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes(UTF8) + val bytes = str.getBytes(UTF_8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -747,109 +744,11 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. - */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. 
- */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8) + override def call(arr: Array[Byte]) : String = new String(arr, UTF_8) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e0992..a4153aaa926f8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. 
+ */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. */ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - initialize() - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index d11db978b842e..c0cbd28a845be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,7 +18,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} -import java.nio.charset.Charset + +import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat @@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator { 
sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) } ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) @@ -175,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 15fd30e65761d..87f5cf944ed85 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -20,6 +20,8 @@ package org.apache.spark.broadcast import java.io.Serializable import org.apache.spark.SparkException +import org.apache.spark.Logging +import org.apache.spark.util.Utils import scala.reflect.ClassTag @@ -52,7 +54,7 @@ import scala.reflect.ClassTag * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. */ -abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging { /** * Flag signifying whether the broadcast variable is valid @@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { */ @volatile private var _isValid = true + private var _destroySite = "" + /** Get the broadcasted value. */ def value: T = { assertValid() @@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { doUnpersist(blocking) } + + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + * This method blocks until destroy has completed + */ + def destroy() { + destroy(blocking = true) + } + /** * Destroy all data and metadata related to this broadcast variable. Use this with caution; * once a broadcast variable has been destroyed, it cannot be used again. + * @param blocking Whether to block until destroy has completed */ private[spark] def destroy(blocking: Boolean) { assertValid() _isValid = false + _destroySite = Utils.getCallSite().shortForm + logInfo("Destroying %s (from %s)".format(toString, _destroySite)) doDestroy(blocking) } @@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** Check if this broadcast is valid. If not valid, exception is thrown. 
*/ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + throw new SparkException( + "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite)) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 75e64c1bf401e..94142d33369c7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) extends Broadcast[T](id) with Logging with Serializable { /** - * Value of the broadcast object. On driver, this is set directly by the constructor. - * On executors, this is reconstructed by [[readObject]], which builds this value by reading - * blocks from the driver and/or other executors. + * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]], + * which builds this value by reading blocks from the driver and/or other executors. + * + * On the driver, if the value is required, it is read lazily from the block manager. */ - @transient private var _value: T = obj + @transient private lazy val _value: T = readBroadcastBlock() + /** The compression codec to use, or None if compression is disabled */ @transient private var compressionCodec: Option[CompressionCodec] = _ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ @@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private val broadcastId = BroadcastBlockId(id) /** Total number of blocks this broadcast variable contains. */ - private val numBlocks: Int = writeBlocks() + private val numBlocks: Int = writeBlocks(obj) - override protected def getValue() = _value + override protected def getValue() = { + _value + } /** * Divide the object into multiple blocks and put those blocks in the block manager. - * + * @param value the object to divide * @return number of blocks this broadcast variable is divided into */ - private def writeBlocks(): Int = { + private def writeBlocks(value: T): Int = { // Store a copy of the broadcast variable in the driver so that tasks run on the driver // do not create a duplicate copy of the broadcast variable's value. - SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK, + SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val blocks = - TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec) + TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) blocks.zipWithIndex.foreach { case (block, i) => SparkEnv.get.blockManager.putBytes( BroadcastBlockId(id, "piece" + i), @@ -157,31 +161,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) out.defaultWriteObject() } - /** Used by the JVM when deserializing this object. 
*/ - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() + private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - _value = x.asInstanceOf[T] + x.asInstanceOf[T] case None => logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime() + val startTimeMs = System.currentTimeMillis() val blocks = readBlocks() - val time = (System.nanoTime() - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - _value = - TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec) + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. SparkEnv.get.blockManager.putSingle( - broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + obj } } } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index fe0ad9ebbca12..e28eaad8a5180 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -20,12 +20,15 @@ package org.apache.spark.deploy import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils import scala.collection.JavaConversions._ @@ -121,6 +124,33 @@ class SparkHadoopUtil extends Logging { UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } + /** + * Returns a function that can be called to find Hadoop FileSystem bytes read. If + * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes read on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. 
+ */ + private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + try { + val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead") + val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesRead = f() + Some(() => f() - baselineBytesRead) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) + None + } + } + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f97bf67fa5a3b..b43e68e40f791 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,8 +158,9 @@ object SparkSubmit { args.files = mergeFileLists(args.files, args.primaryResource) } args.files = mergeFileLists(args.files, args.pyFiles) - // Format python file paths properly before adding them to the PYTHONPATH - sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } } // Special flag to avoid deprecation warnings at the client @@ -273,15 +274,32 @@ object SparkSubmit { } } - // Properties given with --conf are superceded by other options, but take precedence over - // properties in the defaults file. + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Read from default spark properties, if any - for ((k, v) <- args.defaultSparkProperties) { - sysProps.getOrElseUpdate(k, v) + // Resolve paths in certain spark properties + val pathConfigs = Seq( + "spark.jars", + "spark.files", + "spark.yarn.jar", + "spark.yarn.dist.files", + "spark.yarn.dist.archives") + pathConfigs.foreach { config => + // Replace old URIs with resolved URIs, if they exist + sysProps.get(config).foreach { oldValue => + sysProps(config) = Utils.resolveURIs(oldValue) + } + } + + // Resolve and format python file paths properly before adding them to the PYTHONPATH. + // The resolving part is redundant in the case of --py-files, but necessary if the user + // explicitly sets `spark.submit.pyFiles` in his/her default properties file. 
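The new `getFSBytesReadOnThreadCallback` above captures a baseline at call time and returns a closure that reports only the bytes read since then; the reflection exists only because thread-level FileSystem statistics appear in Hadoop 2.5+. The baseline-and-delta part can be sketched independently of Hadoop; `currentBytesRead` here is a hypothetical counter function, not a Hadoop API:

    /** Returns a callback that reports bytes read since the moment this method was called. */
    def bytesReadSince(currentBytesRead: () => Long): () => Long = {
      val baseline = currentBytesRead()     // snapshot taken once, when the callback is created
      () => currentBytesRead() - baseline   // each invocation reports the delta from that snapshot
    }

    // Usage sketch: val cb = bytesReadSince(readCounter); /* do I/O */ ; val delta = cb()
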
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + val resolvedPyFiles = Utils.resolveURIs(pyFiles) + val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + sysProps("spark.submit.pyFiles") = formattedPyFiles } (childArgs, childClasspath, sysProps, childMainClass) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 72a452e0aefb5..f0e9ee67f6a67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.util.jar.JarFile -import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.util.Utils @@ -72,39 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St defaultProperties } - // Respect SPARK_*_MEMORY for cluster mode - driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull - executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull - + // Set parameters from command line arguments parseOpts(args.toList) - mergeSparkProperties() + // Populate `sparkProperties` map from properties file + mergeDefaultSparkProperties() + // Use `sparkProperties` map along with env vars to fill in any missing parameters + loadEnvironmentArguments() + checkRequiredArguments() /** - * Fill in any undefined values based on the default properties file or options passed in through - * the '--conf' flag. + * Merge values from the default properties file with those specified through --conf. + * When this is called, `sparkProperties` is already filled with configs from the latter. */ - private def mergeSparkProperties(): Unit = { + private def mergeDefaultSparkProperties(): Unit = { // Use common defaults file, if not specified by user propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) + // Honor --conf before the defaults file + defaultSparkProperties.foreach { case (k, v) => + if (!sparkProperties.contains(k)) { + sparkProperties(k) = v + } + } + } - val properties = HashMap[String, String]() - properties.putAll(defaultSparkProperties) - properties.putAll(sparkProperties) - - // Use properties file as fallback for values which have a direct analog to - // arguments in this script. - master = Option(master).orElse(properties.get("spark.master")).orNull - executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull - executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull + /** + * Load arguments from environment variables, Spark properties etc. 
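With the SparkSubmitArguments refactor above, each setting is resolved through a fixed precedence chain: the explicit command-line flag first, then `sparkProperties` (which already merges `--conf` over the defaults file), then a legacy environment variable. The chain is just nested `Option#orElse` calls; a self-contained sketch with placeholder inputs:

    def resolveMaster(
        cliMaster: String,                      // value from --master, may be null
        sparkProperties: Map[String, String],   // --conf merged over the defaults file
        env: Map[String, String]): String = {
      Option(cliMaster)
        .orElse(sparkProperties.get("spark.master"))
        .orElse(env.get("MASTER"))
        .orNull                                 // still unset -> null, validated later
    }
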
+ */ + private def loadEnvironmentArguments(): Unit = { + master = Option(master) + .orElse(sparkProperties.get("spark.master")) + .orElse(env.get("MASTER")) + .orNull + driverMemory = Option(driverMemory) + .orElse(sparkProperties.get("spark.driver.memory")) + .orElse(env.get("SPARK_DRIVER_MEMORY")) + .orNull + executorMemory = Option(executorMemory) + .orElse(sparkProperties.get("spark.executor.memory")) + .orElse(env.get("SPARK_EXECUTOR_MEMORY")) + .orNull + executorCores = Option(executorCores) + .orElse(sparkProperties.get("spark.executor.cores")) + .orNull totalExecutorCores = Option(totalExecutorCores) - .orElse(properties.get("spark.cores.max")) + .orElse(sparkProperties.get("spark.cores.max")) .orNull - name = Option(name).orElse(properties.get("spark.app.name")).orNull - jars = Option(jars).orElse(properties.get("spark.jars")).orNull - - // This supports env vars in older versions of Spark - master = Option(master).orElse(env.get("MASTER")).orNull + name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull + jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given @@ -131,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } /** Ensure that required fields exists. Call this only once all defaults are loaded. */ - private def checkRequiredArguments() = { + private def checkRequiredArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -166,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + override def toString = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -174,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile - | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -193,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | jars $jars | verbose $verbose | - |Default properties from $propertiesFile: - |${defaultSparkProperties.mkString(" ", "\n ", "\n")} + |Spark properties used, including those specified through + | --conf and those from the properties file $propertiesFile: + |${sparkProperties.mkString(" ", "\n ", "\n")} """.stripMargin } @@ -327,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 0125330589da5..2b894a796c8c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper { .orElse(confDriverMemory) .getOrElse(defaultDriverMemory) - val newLibraryPath = - if 
(submitLibraryPath.isDefined) { - // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS - "" - } else { - confLibraryPath.map("-Djava.library.path=" + _).getOrElse("") - } - val newClasspath = if (submitClasspath.isDefined) { - // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH classpath } else { classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") @@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper { val command: Seq[String] = Seq(runner) ++ Seq("-cp", newClasspath) ++ - Seq(newLibraryPath) ++ filteredJavaOpts ++ Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ Seq("org.apache.spark.deploy.SparkSubmit") ++ @@ -130,6 +120,13 @@ private[spark] object SparkSubmitDriverBootstrapper { // Start the driver JVM val filteredCommand = command.filter(_.nonEmpty) val builder = new ProcessBuilder(filteredCommand) + val env = builder.environment() + + if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) { + val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName) + env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator"))) + } + val process = builder.start() // Redirect stdout and stderr from the child JVM diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 481f6c93c6a8d..2d1609b973607 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -112,7 +112,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, s"${HistoryServer.UI_PATH_PREFIX}/$appId") // Do not call ui.bind() to avoid creating a new server for each application } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d25c29113d6da..0e249e51a77d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -84,11 +84,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { {info.id} {info.name} - {startTime} - {endTime} - {duration} + {startTime} + {endTime} + {duration} {info.sparkUser} - {lastUpdated} + {lastUpdated} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3b6bb9fe128a4..2f81d472d7b78 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -721,8 +721,8 @@ private[spark] class Master( try { val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", - HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), + appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") replayBus.replay() appIdToUI(app.id) = ui webUi.attachSparkUI(ui) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala 
b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 2e9be2a180c68..28e9662db5da9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} import java.lang.System._ +import scala.collection.Map + import org.apache.spark.Logging import org.apache.spark.deploy.Command import org.apache.spark.util.Utils @@ -29,7 +31,29 @@ import org.apache.spark.util.Utils */ private[spark] object CommandUtils extends Logging { - def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { + + /** + * Build a ProcessBuilder based on the given parameters. + * The `env` argument is exposed for testing. + */ + def buildProcessBuilder( + command: Command, + memory: Int, + sparkHome: String, + substituteArguments: String => String, + classPaths: Seq[String] = Seq[String](), + env: Map[String, String] = sys.env): ProcessBuilder = { + val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) + val builder = new ProcessBuilder(commandSeq: _*) + val environment = builder.environment() + for ((key, value) <- localCommand.environment) { + environment.put(key, value) + } + builder + } + + private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") // SPARK-698: do not call the run.cmd script, as process.destroy() @@ -38,11 +62,41 @@ object CommandUtils extends Logging { command.arguments } + /** + * Build a command based on the given one, taking into account the local environment + * of where this command is expected to run, substitute any placeholders, and append + * any extra class paths. + */ + private def buildLocalCommand( + command: Command, + substituteArguments: String => String, + classPath: Seq[String] = Seq[String](), + env: Map[String, String]): Command = { + val libraryPathName = Utils.libraryPathEnvName + val libraryPathEntries = command.libraryPathEntries + val cmdLibraryPath = command.environment.get(libraryPathName) + + val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) + command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) + } else { + command.environment + } + + Command( + command.mainClass, + command.arguments.map(substituteArguments), + newEnvironment, + command.classPathEntries ++ classPath, + Seq[String](), // library path already captured in environment variable + command.javaOpts) + } + /** * Attention: this must always be aligned with the environment variables in the run scripts and * the way the JAVA_OPTS are assembled there. 
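Instead of passing `-Djava.library.path=...` on the command line, the new `buildLocalCommand` above folds the command's library path entries, any library path already in the command's environment, and the launching process's own setting into a single environment variable (for example LD_LIBRARY_PATH on Linux). A small sketch of that merge, with the variable name passed in explicitly:

    import java.io.File

    def mergeLibraryPath(
        varName: String,                   // e.g. "LD_LIBRARY_PATH" (platform dependent)
        entries: Seq[String],              // the command's libraryPathEntries
        cmdEnv: Map[String, String],       // environment carried by the Command
        processEnv: Map[String, String]    // environment of the launching process
      ): Map[String, String] = {
      if (entries.nonEmpty) {
        val merged = (entries ++ cmdEnv.get(varName) ++ processEnv.get(varName))
          .mkString(File.pathSeparator)
        cmdEnv + (varName -> merged)       // library path travels via the environment, not -D flags
      } else {
        cmdEnv
      }
    }
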
*/ - def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { + private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") // Exists for backwards compatibility with older Spark versions @@ -53,14 +107,6 @@ object CommandUtils extends Logging { logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.") } - val libraryOpts = - if (command.libraryPathEntries.size > 0) { - val joined = command.libraryPathEntries.mkString(File.pathSeparator) - Seq(s"-Djava.library.path=$joined") - } else { - Seq() - } - // Figure out our classpath with the external compute-classpath script val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( @@ -71,7 +117,7 @@ object CommandUtils extends Logging { val javaVersion = System.getProperty("java.version") val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts + permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 9f9911762505a..28cab36c7b9e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.Map import akka.actor.ActorRef -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileUtil, Path} @@ -76,17 +76,9 @@ private[spark] class DriverRunner( // Make sure user application jar is on the classpath // TODO: If we add ability to submit multiple jars they should also be added here - val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename") - val newCommand = Command( - driverDesc.command.mainClass, - driverDesc.command.arguments.map(substituteVariables), - driverDesc.command.environment, - classPath, - driverDesc.command.libraryPathEntries, - driverDesc.command.javaOpts) - val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, - sparkHome.getAbsolutePath) - launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, + sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + launchDriver(builder, driverDir, driverDesc.supervise) } catch { case e: Exception => finalException = Some(e) @@ -165,11 +157,8 @@ private[spark] class DriverRunner( localJarFilename } - private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File, - supervise: Boolean) { - val builder = new ProcessBuilder(command: _*).directory(baseDir) - envVars.map{ case(k,v) => builder.environment().put(k, v) } - + private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { + builder.directory(baseDir) def initialize(process: Process) = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") @@ -177,8 +166,8 
@@ private[spark] class DriverRunner( val stderr = new File(baseDir, "stderr") val header = "Launch Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) - Files.append(header, stderr, Charsets.UTF_8) + builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 71d7385b08eb9..8ba6a01bbcb97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,8 +19,10 @@ package org.apache.spark.deploy.worker import java.io._ +import scala.collection.JavaConversions._ + import akka.actor.ActorRef -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} @@ -115,33 +117,21 @@ private[spark] class ExecutorRunner( case other => other } - def getCommandSeq = { - val command = Command( - appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables), - appDesc.command.environment, - appDesc.command.classPathEntries, - appDesc.command.libraryPathEntries, - appDesc.command.javaOpts) - CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) - } - /** * Download and run the executor described in our ApplicationDescription */ def fetchAndRunExecutor() { try { // Launch the process - val command = getCommandSeq + val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, + sparkHome.getAbsolutePath, substituteVariables) + val command = builder.command() logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) - val builder = new ProcessBuilder(command: _*).directory(executorDir) - val env = builder.environment() - for ((key, value) <- appDesc.command.environment) { - env.put(key, value) - } + + builder.directory(executorDir) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command - env.put("SPARK_LAUNCH_WITH_SCALA", "0") + builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( command.mkString("\"", "\" \"", "\""), "=" * 40) @@ -151,7 +141,7 @@ private[spark] class ExecutorRunner( stdoutAppender = FileAppender(process.getInputStream, stdout, conf) val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, Charsets.UTF_8) + Files.write(header, stderr, UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) state = ExecutorState.RUNNING diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c4a8ec2e5e7b0..f1f66d0903f1c 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -186,11 +186,11 @@ private[spark] class Worker( private def retryConnectToMaster() { Utils.tryOrExit { connectionAttemptCount += 1 - logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount") if (registered) { registrationRetryTimer.foreach(_.cancel()) registrationRetryTimer = None } else if (connectionAttemptCount <= 
TOTAL_REGISTRATION_RETRIES) { + logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") tryRegisterAllMasters() if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { registrationRetryTimer.foreach(_.cancel()) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 697154d762d41..3711824a40cfc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2889f59e33e84..8b095e23f32ff 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.ActorSystem +import akka.actor.{Props, ActorSystem} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -78,7 +78,7 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - conf.set("spark.executor.id", "executor." + executorId) + conf.set("spark.executor.id", executorId) private val env = { if (!isLocal) { val port = conf.getInt("spark.executor.port", 0) @@ -92,6 +92,10 @@ private[spark] class Executor( } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -104,6 +108,9 @@ private[spark] class Executor( // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Limit of bytes for total size of results (default is 1GB) + private val maxResultSize = Utils.getMaxResultSize(conf) + // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -128,6 +135,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -210,25 +218,27 @@ private[spark] class Executor( val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val serializedResult = { + if (resultSize > maxResultSize) { + logWarning(s"Finished $taskName (TID $taskId). 
Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { - (serializedDirectResult, true) + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + serializedDirectResult } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 3e49b6235aff3..57bc2b40cec44 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -169,7 +169,6 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { var bytesRead: Long = 0L } - /** * :: DeveloperApi :: * Metrics pertaining to shuffle data read in a given task. 
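The rewritten result path in Executor above makes a three-way decision: results above `spark.driver.maxResultSize` are dropped (only their size is reported back), results too large for an Akka frame are written to the block manager and referenced indirectly, and everything else is sent inline to the driver. A hedged sketch of just the decision, with the two thresholds as parameters (the real code compares against the frame size minus a reserved overhead):

    sealed trait ResultRoute
    case object Dropped extends ResultRoute          // over maxResultSize: report size only
    case object ViaBlockManager extends ResultRoute  // too big for the RPC frame: store it, send a handle
    case object Direct extends ResultRoute           // small enough to send inline

    def routeResult(resultSize: Long, maxResultSize: Long, maxFrameSize: Long): ResultRoute = {
      if (resultSize > maxResultSize) Dropped
      else if (resultSize >= maxFrameSize) ViaBlockManager
      else Direct
    }
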
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala new file mode 100644 index 0000000000000..89b29af2000c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +/** + * Custom Input Format for reading and splitting flat binary files that contain records, + * each of which are a fixed size in bytes. The fixed record size is specified through + * a parameter recordLength in the Hadoop configuration. + */ +private[spark] object FixedLengthBinaryInputFormat { + /** Property name to set in Hadoop JobConfs for record length */ + val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength" + + /** Retrieves the record length property from a Hadoop configuration */ + def getRecordLength(context: JobContext): Int = { + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt + } +} + +private[spark] class FixedLengthBinaryInputFormat + extends FileInputFormat[LongWritable, BytesWritable] { + + private var recordLength = -1 + + /** + * Override of isSplitable to ensure initial computation of the record length + */ + override def isSplitable(context: JobContext, filename: Path): Boolean = { + if (recordLength == -1) { + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + } + if (recordLength <= 0) { + println("record length is less than 0, file cannot be split") + false + } else { + true + } + } + + /** + * This input format overrides computeSplitSize() to make sure that each split + * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader + * will start at the first byte of a record, and the last byte will the last byte of a record. 
+ */ + override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = { + val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize) + // If the default size is less than the length of a record, make it equal to it + // Otherwise, make sure the split size is as close to possible as the default size, + // but still contains a complete set of records, with the first record + // starting at the first byte in the split and the last record ending with the last byte + if (defaultSize < recordLength) { + recordLength.toLong + } else { + (Math.floor(defaultSize / recordLength) * recordLength).toLong + } + } + + /** + * Create a FixedLengthBinaryRecordReader + */ + override def createRecordReader(split: InputSplit, context: TaskAttemptContext) + : RecordReader[LongWritable, BytesWritable] = { + new FixedLengthBinaryRecordReader + } +} diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala new file mode 100644 index 0000000000000..5164a74bec4e9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.IOException + +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileSplit + +/** + * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. + * It uses the record length set in FixedLengthBinaryInputFormat to + * read one record at a time from the given InputSplit. + * + * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value. 
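`computeSplitSize` above rounds the framework's default split size down to a whole number of records so no record straddles two splits, and never goes below one record. For example, with a 10,000,000-byte default split and 300-byte records the split becomes 9,999,900 bytes (33,333 records). The rounding in isolation (integer division of positive longs already floors, matching the `Math.floor` in the patch):

    def splitSizeFor(defaultSize: Long, recordLength: Int): Long = {
      if (defaultSize < recordLength) {
        recordLength.toLong                           // never split below a single record
      } else {
        (defaultSize / recordLength) * recordLength   // round down to a whole number of records
      }
    }

    // splitSizeFor(10000000L, 300) == 9999900L
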
+ * + * key = record index (Long) + * value = the record itself (BytesWritable) + */ +private[spark] class FixedLengthBinaryRecordReader + extends RecordReader[LongWritable, BytesWritable] { + + private var splitStart: Long = 0L + private var splitEnd: Long = 0L + private var currentPosition: Long = 0L + private var recordLength: Int = 0 + private var fileInputStream: FSDataInputStream = null + private var recordKey: LongWritable = null + private var recordValue: BytesWritable = null + + override def close() { + if (fileInputStream != null) { + fileInputStream.close() + } + } + + override def getCurrentKey: LongWritable = { + recordKey + } + + override def getCurrentValue: BytesWritable = { + recordValue + } + + override def getProgress: Float = { + splitStart match { + case x if x == splitEnd => 0.0.toFloat + case _ => Math.min( + ((currentPosition - splitStart) / (splitEnd - splitStart)).toFloat, 1.0 + ).toFloat + } + } + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) { + // the file input + val fileSplit = inputSplit.asInstanceOf[FileSplit] + + // the byte position this fileSplit starts at + splitStart = fileSplit.getStart + + // splitEnd byte marker that the fileSplit ends at + splitEnd = splitStart + fileSplit.getLength + + // the actual file we will be reading from + val file = fileSplit.getPath + // job configuration + val job = context.getConfiguration + // check compression + val codec = new CompressionCodecFactory(job).getCodec(file) + if (codec != null) { + throw new IOException("FixedLengthRecordReader does not support reading compressed files") + } + // get the record length + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + // get the filesystem + val fs = file.getFileSystem(job) + // open the File + fileInputStream = fs.open(file) + // seek to the splitStart position + fileInputStream.seek(splitStart) + // set our current position + currentPosition = splitStart + } + + override def nextKeyValue(): Boolean = { + if (recordKey == null) { + recordKey = new LongWritable() + } + // the key is a linear index of the record, given by the + // position the record starts divided by the record length + recordKey.set(currentPosition / recordLength) + // the recordValue to place the bytes into + if (recordValue == null) { + recordValue = new BytesWritable(new Array[Byte](recordLength)) + } + // read a record if the currentPosition is less than the split end + if (currentPosition < splitEnd) { + // setup a buffer to store the record + val buffer = recordValue.getBytes + fileInputStream.read(buffer, 0, recordLength) + // update our current position + currentPosition = currentPosition + recordLength + // return true + return true + } + false + } +} diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala new file mode 100644 index 0000000000000..457472547fcbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.JavaConversions._ + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} + +import org.apache.spark.annotation.Experimental + +/** + * A general format for reading whole files in as streams, byte arrays, + * or other functions to be added + */ +private[spark] abstract class StreamFileInputFormat[T] + extends CombineFileInputFormat[String, T] +{ + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + /** + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API + * which is set through setMaxSplitSize + */ + def setMinPartitions(context: JobContext, minPartitions: Int) { + val files = listStatus(context) + val totalLen = files.map { file => + if (file.isDir) 0L else file.getLen + }.sum + + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + super.setMaxSplitSize(maxSplitSize) + } + + def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T] + +} + +/** + * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] + * to reading files out as streams + */ +private[spark] abstract class StreamBasedRecordReader[T]( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, T] { + + // True means the current file has been processed, then skip it. 
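`StreamBasedRecordReader` treats the whole file as a single key/value pair: a `processed` flag makes the first `nextKeyValue()` produce the record and every later call report end-of-input. The control flow in isolation, with a plain function standing in for the Hadoop RecordReader machinery:

    class SingleRecordReader[T](readOnce: () => (String, T)) {
      private var processed = false
      private var key: String = ""
      private var value: Option[T] = None

      /** Returns true exactly once; afterwards the input is considered exhausted. */
      def nextKeyValue(): Boolean = {
        if (!processed) {
          val (k, v) = readOnce()   // e.g. open the stream, parse it, close it
          key = k
          value = Some(v)
          processed = true
          true
        } else {
          false
        }
      }

      def currentKey: String = key
      def currentValue: T = value.get
    }
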
+ private var processed = false + + private var key = "" + private var value: T = null.asInstanceOf[T] + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = new PortableDataStream(split, context, index) + value = parseStream(fileIn) + fileIn.close() // if it has not been open yet, close does nothing + key = fileIn.getPath + processed = true + true + } else { + false + } + } + + /** + * Parse the stream (and close it afterwards) and return the value as in type T + * @param inStream the stream to be read in + * @return the data formatted as + */ + def parseStream(inStream: PortableDataStream): T +} + +/** + * Reads the record in directly as a stream for other objects to manipulate and handle + */ +private[spark] class StreamRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends StreamBasedRecordReader[PortableDataStream](split, context, index) { + + def parseStream(inStream: PortableDataStream): PortableDataStream = inStream +} + +/** + * The format for the PortableDataStream files + */ +private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + new CombineFileRecordReader[String, PortableDataStream]( + split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) + } +} + +/** + * A class that allows DataStreams to be serialized and moved around by not creating them + * until they need to be read + * @note TaskAttemptContext is not serializable resulting in the confBytes construct + * @note CombineFileSplit is not serializable resulting in the splitBytes construct + */ +@Experimental +class PortableDataStream( + @transient isplit: CombineFileSplit, + @transient context: TaskAttemptContext, + index: Integer) + extends Serializable { + + // transient forces file to be reopened after being serialization + // it is also used for non-serializable classes + + @transient private var fileIn: DataInputStream = null + @transient private var isOpen = false + + private val confBytes = { + val baos = new ByteArrayOutputStream() + context.getConfiguration.write(new DataOutputStream(baos)) + baos.toByteArray + } + + private val splitBytes = { + val baos = new ByteArrayOutputStream() + isplit.write(new DataOutputStream(baos)) + baos.toByteArray + } + + @transient private lazy val split = { + val bais = new ByteArrayInputStream(splitBytes) + val nsplit = new CombineFileSplit() + nsplit.readFields(new DataInputStream(bais)) + nsplit + } + + @transient private lazy val conf = { + val bais = new ByteArrayInputStream(confBytes) + val nconf = new Configuration() + nconf.readFields(new DataInputStream(bais)) + nconf + } + /** + * Calculate the path name independently of opening the file + */ + @transient private lazy val path = { + val pathp = split.getPath(index) + pathp.toString + } + + /** + * Create a new DataInputStream from the split and context + */ + def open(): DataInputStream = { + if (!isOpen) { + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fileIn = fs.open(pathp) + isOpen = true + } + fileIn + } + + /** + * Read the file as a byte array + */ + def toArray(): Array[Byte] = { + open() + val innerBuffer = ByteStreams.toByteArray(fileIn) + 
close() + innerBuffer + } + + /** + * Close the file (if it is currently open) + */ + def close() = { + if (isOpen) { + try { + fileIn.close() + isOpen = false + } catch { + case ioe: java.io.IOException => // do nothing + } + } + } + + def getPath(): String = path +} + diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 4cb450577796a..183bce3d8d8d3 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str } /** - * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API. + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API, + * which is set through setMaxSplitSize */ - def setMaxSplitSize(context: JobContext, minPartitions: Int) { + def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context) val totalLen = files.map { file => if (file.isDir) 0L else file.getLen diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c8..1745d52c81923 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,20 +17,20 @@ package org.apache.spark.network -import org.apache.spark.storage.StorageLevel - +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.{BlockId, StorageLevel} +private[spark] trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. 
*/ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 84d991fa6808c..210a581db466e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,13 +17,19 @@ package org.apache.spark.network -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.Duration +import java.io.Closeable +import java.nio.ByteBuffer -import org.apache.spark.storage.StorageLevel +import scala.concurrent.{Promise, Await, Future} +import scala.concurrent.duration.Duration +import org.apache.spark.Logging +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} +import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} -abstract class BlockTransferService { +private[spark] +abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -34,7 +40,7 @@ abstract class BlockTransferService { /** * Tear down the transfer service. */ - def stop(): Unit + def close(): Unit /** * Port number the service is listening on, available only after [[init]] is invoked. @@ -50,17 +56,15 @@ abstract class BlockTransferService { * Fetch a sequence of blocks from a remote node asynchronously, * available only after [[init]] is invoked. * - * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). - * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - def fetchBlocks( - hostName: String, + override def fetchBlocks( + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit /** @@ -69,7 +73,7 @@ abstract class BlockTransferService { def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -78,40 +82,23 @@ abstract class BlockTransferService { * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { + def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { // A monitor for the thread to wait on. 
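The old `fetchBlockSync` blocked on a monitor with wait/notify; the replacement bridges the callback-style `fetchBlocks` API to a blocking call with a `Promise`, completing it from whichever callback fires and then awaiting the future. The bridging pattern on its own (the listener trait here is an illustrative stand-in, not the Spark interface):

    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration.Duration

    trait FetchListener[T] {
      def onSuccess(data: T): Unit
      def onFailure(error: Throwable): Unit
    }

    def fetchSync[T](startFetch: FetchListener[T] => Unit): T = {
      val result = Promise[T]()
      startFetch(new FetchListener[T] {
        override def onSuccess(data: T): Unit = result.success(data)          // complete normally
        override def onFailure(error: Throwable): Unit = result.failure(error) // propagate the error
      })
      Await.result(result.future, Duration.Inf)   // block the caller until one callback fires
    }
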
- val lock = new Object - @volatile var result: Either[ManagedBuffer, Throwable] = null - fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - lock.synchronized { - result = Right(exception) - lock.notify() + val result = Promise[ManagedBuffer]() + fetchBlocks(host, port, execId, Array(blockId), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + result.failure(exception) } - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - lock.synchronized { - result = Left(data) - lock.notify() - } - } - }) - - // Sleep until result is no longer null - lock.synchronized { - while (result == null) { - try { - lock.wait() - } catch { - case e: InterruptedException => + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) } - } - } + }) - result match { - case Left(data) => data - case Right(e) => throw e - } + Await.result(result.future, Duration.Inf) } /** @@ -123,7 +110,7 @@ abstract class BlockTransferService { def uploadBlockSync( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala deleted file mode 100644 index 4c9ca97a2a6b7..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.io._ -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode - -import scala.util.Try - -import com.google.common.io.ByteStreams -import io.netty.buffer.{ByteBufInputStream, ByteBuf} - -import org.apache.spark.util.{ByteBufferInputStream, Utils} - - -/** - * This interface provides an immutable view for data in the form of bytes. The implementation - * should specify how the data is provided: - * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf - */ -sealed abstract class ManagedBuffer { - // Note that all the methods are defined with parenthesis because their implementations can - // have side effects (io operations). - - /** Number of bytes of the data. 
*/ - def size: Long - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - def nioByteBuffer(): ByteBuffer - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - def inputStream(): InputStream -} - - -/** - * A [[ManagedBuffer]] backed by a segment in a file - */ -final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) - extends ManagedBuffer { - - /** - * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). - * Avoid unless there's a good reason not to. - */ - private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; - - override def size: Long = length - - override def nioByteBuffer(): ByteBuffer = { - var channel: FileChannel = null - try { - channel = new RandomAccessFile(file, "r").getChannel - // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. - if (length < MIN_MEMORY_MAP_BYTES) { - val buf = ByteBuffer.allocate(length.toInt) - channel.read(buf, offset) - buf.flip() - buf - } else { - channel.map(MapMode.READ_ONLY, offset, length) - } - } catch { - case e: IOException => - Try(channel.size).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - } finally { - if (channel != null) { - Utils.tryLog(channel.close()) - } - } - } - - override def inputStream(): InputStream = { - var is: FileInputStream = null - try { - is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) - } catch { - case e: IOException => - if (is != null) { - Utils.tryLog(is.close()) - } - Try(file.length).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - case e: Throwable => - if (is != null) { - Utils.tryLog(is.close()) - } - throw e - } - } - - override def toString: String = s"${getClass.getName}($file, $offset, $length)" -} - - -/** - * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. - */ -final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - - override def size: Long = buf.remaining() - - override def nioByteBuffer() = buf.duplicate() - - override def inputStream() = new ByteBufferInputStream(buf) -} - - -/** - * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. - */ -final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - - override def size: Long = buf.readableBytes() - - override def nioByteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) - - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. 
- def release(): Unit = buf.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala new file mode 100644 index 0000000000000..1950e7bd634ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions._ + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.shuffle.ShuffleStreamHandle +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{BlockId, StorageLevel} + +object NettyMessages { + /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ + case class OpenBlocks(blockIds: Seq[BlockId]) + + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ + case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) +} + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + * Handles opening and uploading arbitrary BlockManager blocks. + * + * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one Spark-level shuffle block. 
+ */ +class NettyBlockRpcServer( + serializer: Serializer, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + import NettyMessages._ + + private val streamManager = new OneForOneStreamManager() + + override def receive( + client: TransportClient, + messageBytes: Array[Byte], + responseContext: RpcResponseCallback): Unit = { + val ser = serializer.newInstance() + val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + logTrace(s"Received request: $message") + + message match { + case OpenBlocks(blockIds) => + val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") + responseContext.onSuccess( + ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + + case UploadBlock(blockId, blockData, level) => + blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + responseContext.onSuccess(new Array[Byte](0)) + } + } + + override def getStreamManager(): StreamManager = streamManager +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala new file mode 100644 index 0000000000000..1c4327cf13b51 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import scala.concurrent.{Future, Promise} + +import org.apache.spark.SparkConf +import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} +import org.apache.spark.network.server._ +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * A BlockTransferService that uses Netty to fetch a set of blocks at at time. + */ +class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
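// (Editor's sketch, not part of this patch) How a request reaches NettyBlockRpcServer above:
// the client serializes one of the NettyMessages with the same JavaSerializer and sends the
// raw bytes as an RPC; the server answers with a serialized ShuffleStreamHandle. Assumes a
// SparkConf named `conf` is in scope; the ShuffleBlockId values are only illustrative.
import org.apache.spark.storage.ShuffleBlockId
val ser = new JavaSerializer(conf).newInstance()
val requestBytes: Array[Byte] =
  ser.serialize(NettyMessages.OpenBlocks(Seq(ShuffleBlockId(0, 0, 0)))).array()
// `requestBytes` is what TransportClient.sendRpc puts on the wire and what arrives as
// `messageBytes` in NettyBlockRpcServer.receive.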
+ val serializer = new JavaSerializer(conf) + + private[this] var transportContext: TransportContext = _ + private[this] var server: TransportServer = _ + private[this] var clientFactory: TransportClientFactory = _ + + override def init(blockDataManager: BlockDataManager): Unit = { + val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler) + clientFactory = transportContext.createClientFactory() + server = transportContext.createServer() + logInfo("Server created on " + server.getPort) + } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + try { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, blockIds.toArray, listener) + .start(OpenBlocks(blockIds.map(BlockId.apply))) + } catch { + case e: Exception => + logError("Exception while beginning fetchBlocks", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def hostName: String = Utils.localHostName() + + override def port: Int = server.getPort + + override def uploadBlock( + hostname: String, + port: Int, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + val result = Promise[Unit]() + val client = clientFactory.createClient(hostname, port) + + // Convert or copy nio buffer into array in order to serialize it. + val nioBuffer = blockData.nioByteBuffer() + val array = if (nioBuffer.hasArray) { + nioBuffer.array() + } else { + val data = new Array[Byte](nioBuffer.remaining()) + nioBuffer.get(data) + data + } + + val ser = serializer.newInstance() + client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + logTrace(s"Successfully uploaded block $blockId") + result.success() + } + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading block $blockId", e) + result.failure(e) + } + }) + + result.future + } + + override def close(): Unit = { + server.close() + clientFactory.close() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala deleted file mode 100644 index b5870152c5a64..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf - -/** - * A central location that tracks all the settings we exposed to users. - */ -private[spark] -class NettyConfig(conf: SparkConf) { - - /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) - - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) - - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) - - /** Send buffer size (SO_SNDBUF). */ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) -} diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala similarity index 62% rename from core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala rename to core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 34acaa563ca58..9fa4fa77b8817 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -15,23 +15,18 @@ * limitations under the License. */ -package org.apache.spark.network - -import java.util.EventListener +package org.apache.spark.network.netty +import org.apache.spark.SparkConf +import org.apache.spark.network.util.{TransportConf, ConfigProvider} /** - * Listener callback interface for [[BlockTransferService.fetchBlocks]]. + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. */ -trait BlockFetchingListener extends EventListener { - - /** - * Called once per successfully fetched block. - */ - def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit - - /** - * Called upon failures. For each failure, this is called only once (i.e. not once per block). 
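 * (Editor's note, not part of the original scaladoc) The replacement interface used by the
 * rest of this patch, org.apache.spark.network.shuffle.BlockFetchingListener, instead reports
 * failures once per block id, as both transfer services now do on error:
 *   blockIds.foreach { id => listener.onBlockFetchFailure(id, exception) }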
- */ - def onBlockFetchFailure(exception: Throwable): Unit +object SparkTransportConf { + def fromSparkConf(conf: SparkConf): TransportConf = { + new TransportConf(new ConfigProvider { + override def get(name: String): String = conf.get(name) + }) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala deleted file mode 100644 index 5aea7ba2f3673..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.util.concurrent.TimeoutException - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil - -import org.apache.spark.Logging - -/** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. - * - * The constructor blocks until a connection is successfully established. - * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -@throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) - extends Logging { - - private val handler = new BlockFetchingClientHandler - - /** Netty Bootstrap for creating the TCP connection. 
*/ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) - .addLast("handler", handler) - } - }) - b - } - - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. - */ - def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. - var startTime: Long = 0 - logTrace { - startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" - } - - blockIds.foreach { blockId => - handler.addRequest(blockId, listener) - } - - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" - } - } else { - // Fail all blocks. 
- val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - listener.onFetchFailure(blockId, errorMsg) - handler.removeRequest(blockId) - } - } - } - }) - } - - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala deleted file mode 100644 index 2b28402c52b49..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.util.Utils - -/** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. - */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") - - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. 
- try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) - } - - def stop(): Unit = { - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index 83265b164299d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockClientListener] - } - - def addRequest(blockId: String, listener: BlockClientListener): Unit = { - outstandingRequests.put(blockId, listener) - } - - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) - - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onFetchFailure(entry.getKey, errorMsg) - } - outstandingRequests.clear() - } - - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. - if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchFailure(blockId, errorMsg) - } - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc26..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef8595..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. 
- */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala deleted file mode 100644 index 7b2f9a8d4dfd0..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider -import org.apache.spark.util.Utils - - -/** - * Server for serving Spark data blocks. - * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. 
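 * (Editor's illustration, not part of the original scaladoc) For block id "b1" (2 bytes)
 * followed by 100 bytes of data, BlockHeaderEncoder writes
 *   frame-length = 4 + 2 + 100 = 106, block-id-length = 2, the bytes of "b1", then the data.
 * For an error reply it writes block-id-length = -2 and the error string instead of the data.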
- * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * - */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { - - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { - this(new NettyConfig(sparkConf), dataProvider) - } - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. 
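// (Editor's note, not part of this patch) The buffer sizing hint from the removed NettyConfig
// comment is a bandwidth-delay product: 10 Gbps ~= 1.25 GB/s, so with 1 ms of latency roughly
// 1.25 GB/s * 0.001 s ~= 1.25 MB can be in flight, hence the suggested ~1.25 MB for the
// SO_RCVBUF / SO_SNDBUF options applied below.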
- conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - _hostName = addr.getHostName - } - - /** Shutdown the server. */ - def stop(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala deleted file mode 100644 index cc70bd0c5c477..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - - -/** Channel initializer that sets up the pipeline for the BlockServer. 
*/ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) - extends ChannelInitializer[SocketChannel] { - - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. 
- - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. - ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. 
- try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index bda4bf50932c3..8408b75bb4d65 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -31,6 +31,8 @@ import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps +import com.google.common.base.Charsets.UTF_8 + import org.apache.spark._ import org.apache.spark.util.Utils @@ -923,7 +925,7 @@ private[nio] class ConnectionManager( val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, "utf-8") + val errorMsg = new String(errorMsgBytes, UTF_8) val e = new IOException( s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") if (!promise.tryFailure(e)) { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 3ad04591da658..fb4a979b824c3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import com.google.common.base.Charsets.UTF_8 + import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { @@ -92,7 +94,7 @@ private[nio] object Message { */ def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8")) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) val errorMessage = createBufferMessage(serializedExceptionString, ackId) errorMessage.hasError = true errorMessage diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 5add4fc433fb3..f56d165daba55 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -19,12 +19,14 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Future - -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} + +import scala.concurrent.Future /** @@ -71,20 +73,21 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa /** * Tear down the 
transfer service. */ - override def stop(): Unit = { + override def close(): Unit = { if (cm != null) { cm.stop() } } override def fetchBlocks( - hostName: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { checkInit() - val cmId = new ConnectionManagerId(hostName, port) + val cmId = new ConnectionManagerId(host, port) val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) }) @@ -96,21 +99,33 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - listener.onBlockFetchFailure( - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. + if (blockMessageArray.isEmpty) { + blockIds.foreach { id => + listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) + } + } else { + for (blockMessage: BlockMessage <- blockMessageArray) { + val msgType = blockMessage.getType + if (msgType != BlockMessage.TYPE_GOT_BLOCK) { + if (blockMessage.getId != null) { + listener.onBlockFetchFailure(blockMessage.getId.toString, + new SparkException(s"Unexpected message $msgType received from $cmId")) + } + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioManagedBuffer(blockMessage.getData)) + } } } }(cm.futureExecContext) future.onFailure { case exception => - listener.onBlockFetchFailure(exception) + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, exception) + } }(cm.futureExecContext) } @@ -122,12 +137,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, - blockId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) : Future[Unit] = { checkInit() - val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) @@ -149,10 +164,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => { + case e: Exception => logError("Exception handling buffer message", e) Some(Message.createErrorMessage(e, msg.id)) - } } case otherMessage: Any => @@ -167,13 +181,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case BlockMessage.TYPE_PUT_BLOCK => val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + msg + "]") - putBlock(msg.id.toString, 
msg.data, msg.level) + putBlock(msg.id, msg.data, msg.level) None case BlockMessage.TYPE_GET_BLOCK => val msg = new GetBlock(blockMessage.getId) logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id.toString) + val buffer = getBlock(msg.id) if (buffer == null) { return None } @@ -183,20 +197,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } } - private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " with data size: " + bytes.limit) } - private def getBlock(blockId: String): ByteBuffer = { + private def getBlock(blockId: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - if (buffer == null) null else buffer.nioByteBuffer() + buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 3155dfe165664..637492a97551b 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.NormalDistribution /** * An ApproximateEvaluator for counts. @@ -46,7 +46,8 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) val mean = (sum + 1 - p) / p val variance = (sum + 1) * (1 - p) / (p * p) val stdev = math.sqrt(variance) - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val confFactor = new NormalDistribution(). + inverseCumulativeProbability(1 - (1 - confidence) / 2) val low = mean - confFactor * stdev val high = mean + confFactor * stdev new BoundedDouble(mean, confidence, low, high) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 8bb78123e3c9c..3ef3cc219dec6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -24,7 +24,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.NormalDistribution import org.apache.spark.util.collection.OpenHashMap @@ -55,7 +55,8 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf new HashMap[T, BoundedDouble] } else { val p = outputsMerged.toDouble / totalOutputs - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val confFactor = new NormalDistribution(). 
+ inverseCumulativeProbability(1 - (1 - confidence) / 2) val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => val mean = (sum + 1 - p) / p diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala index d24959cba8727..787a21a61fdcf 100644 --- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} import org.apache.spark.util.StatCounter @@ -45,9 +45,10 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) val stdev = math.sqrt(counter.sampleVariance / counter.count) val confFactor = { if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + val degreesOfFreedom = (counter.count - 1).toInt + new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } } val low = mean - confFactor * stdev diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala index 92915ee66d29f..828bf96c2c0bd 100644 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} /** * A utility class for caching Student's T distribution values for a given confidence level @@ -25,8 +25,10 @@ import cern.jet.stat.Probability * confidence intervals for many keys. 
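 * (Editor's illustration, not part of the original scaladoc) The commons-math3 calls that
 * replace Colt's Probability throughout these evaluators, assuming commons-math3 is on the
 * classpath; the confidence factor is the 1 - (1 - confidence) / 2 quantile:
 *   new NormalDistribution().inverseCumulativeProbability(0.975)  // ~= 1.96 (standard normal)
 *   new TDistribution(10).inverseCumulativeProbability(0.975)     // ~= 2.23 (t, 10 degrees of freedom)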
*/ private[spark] class StudentTCacher(confidence: Double) { + val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) + + val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) def get(sampleSize: Long): Double = { @@ -35,7 +37,8 @@ private[spark] class StudentTCacher(confidence: Double) { } else { val size = sampleSize.toInt if (cache(size) < 0) { - cache(size) = Probability.studentTInverse(1 - confidence, size - 1) + val tDist = new TDistribution(size - 1) + cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2) } cache(size) } diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index d5336284571d2..1753c2561b678 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} import org.apache.spark.util.StatCounter @@ -55,9 +55,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) val sumStdev = math.sqrt(sumVar) val confFactor = { if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + val degreesOfFreedom = (counter.count - 1).toInt + new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } } val low = sumEstimate - confFactor * sumStdev diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala new file mode 100644 index 0000000000000..6e66ddbdef788 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{ Configurable, Configuration } +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.spark.input.StreamFileInputFormat +import org.apache.spark.{ Partition, SparkContext } + +private[spark] class BinaryFileRDD[T]( + sc: SparkContext, + inputFormatClass: Class[_ <: StreamFileInputFormat[T]], + keyClass: Class[String], + valueClass: Class[T], + @transient conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 2673ec22509e9..fffa1911f5bc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds "Attempted to use %s after its blocks have been removed!".format(toString)) } } + + protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = { + locations_ + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 775141775e06c..a157e36e2286e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -46,7 +46,6 @@ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} - /** * A Spark split class that wraps around a Hadoop InputSplit. */ @@ -212,8 +211,22 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() + + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( + split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) + } else { + None + } + if (bytesReadCallback.isDefined) { + context.taskMetrics.inputMetrics = Some(inputMetrics) + } + + var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), context.stageId, theSplit.index, context.attemptId.toInt, jobConf) @@ -224,18 +237,7 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - // Set the task input metrics. 
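The input-metrics rewrite in this hunk (and the matching NewHadoopRDD hunk further down) swaps the old one-shot estimate based on split length for a callback that reports the FileSystem bytes actually read by the current thread, polled every 256 records and refreshed once more on close. A self-contained sketch of that cadence follows; the object name, the simulated counter, and the 100-byte records are invented for illustration, and only the polling interval, the Option[() => Long] callback shape, and the split-length fallback mirror the code in these hunks.

object BytesReadMetricSketch {
  val RecordsBetweenUpdates = 256  // same cadence as HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES

  def main(args: Array[String]): Unit = {
    var fsBytesRead = 0L                                   // stand-in for the FileSystem thread-local statistics
    val bytesReadCallback: Option[() => Long] = Some(() => fsBytesRead)
    val splitLength = 123456L                              // stand-in for split.getLength, the fallback value

    var reportedBytesRead = 0L                             // stand-in for inputMetrics.bytesRead
    var recordsSinceUpdate = 0

    for (_ <- 1 to 1000) {                                 // pretend to read 1000 records of 100 bytes each
      fsBytesRead += 100
      if (recordsSinceUpdate == RecordsBetweenUpdates && bytesReadCallback.isDefined) {
        recordsSinceUpdate = 0
        val bytesReadFn = bytesReadCallback.get
        reportedBytesRead = bytesReadFn()
      } else {
        recordsSinceUpdate += 1
      }
    }

    // On close: take a final reading, or fall back to the (possibly inaccurate) split size.
    bytesReadCallback match {
      case Some(fn) => reportedBytesRead = fn()
      case None     => reportedBytesRead = splitLength
    }
    println(s"reported $reportedBytesRead bytes read")     // 100000 with the callback present
  }
}

Polling every 256 records rather than per record keeps the metric reasonably fresh without paying for a statistics lookup on every getNext() call.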
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - try { - /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't - * always at record boundaries, so tasks may need to read into other splits to complete - * a record. */ - inputMetrics.bytesRead = split.inputSplit.value.getLength() - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - context.taskMetrics.inputMetrics = Some(inputMetrics) + var recordsSinceMetricsUpdate = 0 override def getNext() = { try { @@ -244,12 +246,36 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } + + // Update bytes read metric every few records + if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES + && bytesReadCallback.isDefined) { + recordsSinceMetricsUpdate = 0 + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else { + recordsSinceMetricsUpdate += 1 + } (key, value) } override def close() { try { reader.close() + if (bytesReadCallback.isDefined) { + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.bytesRead = split.inputSplit.value.getLength + context.taskMetrics.inputMetrics = Some(inputMetrics) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } } catch { case e: Exception => { if (!Utils.inShutdown()) { @@ -302,6 +328,9 @@ private[spark] object HadoopRDD extends Logging { */ val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** Update the input bytes read metric each time this number of records has been read */ + val RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES = 256 + /** * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 0cccdefc5ee09..351e145f96f9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -36,6 +37,7 @@ import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils +import org.apache.spark.deploy.SparkHadoopUtil private[spark] class NewHadoopPartition( rddId: Int, @@ -105,6 +107,20 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value + + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( + split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) + } else { + None + } + if (bytesReadCallback.isDefined) { + context.taskMetrics.inputMetrics = Some(inputMetrics) + } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance @@ -117,22 +133,11 @@ class NewHadoopRDD[K, V]( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - try { - /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't - * always at record boundaries, so tasks may need to read into other splits to complete - * a record. */ - inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength() - } catch { - case e: Exception => - logWarning("Unable to get input split size in order to set task input bytes", e) - } - context.taskMetrics.inputMetrics = Some(inputMetrics) - // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) var havePair = false var finished = false + var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { @@ -147,12 +152,39 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false + + // Update bytes read metric every few records + if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES + && bytesReadCallback.isDefined) { + recordsSinceMetricsUpdate = 0 + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else { + recordsSinceMetricsUpdate += 1 + } + (reader.getCurrentKey, reader.getCurrentValue) } private def close() { try { reader.close() + + // Update metrics with final amount + if (bytesReadCallback.isDefined) { + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. 
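Before continuing with this close() fallback, a brief aside on the partial-result evaluators near the top of this section: MeanEvaluator, SumEvaluator and StudentTCacher now obtain their confidence factors from commons-math3 instead of colt. A minimal standalone equivalent is sketched below; the object and method names are mine, while the count > 100 Gaussian cutoff and the two-sided probability 1 - (1 - confidence) / 2 are taken from those hunks.

import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

object ConfidenceFactorSketch {
  // Two-sided confidence factor: the z (or t) value such that the central
  // `confidence` mass lies within +/- that many standard deviations of the mean.
  def confFactor(confidence: Double, sampleCount: Long): Double = {
    val p = 1 - (1 - confidence) / 2
    if (sampleCount > 100) {
      // Large samples: Gaussian approximation, mirroring the `count > 100` branch above.
      new NormalDistribution().inverseCumulativeProbability(p)
    } else {
      // Small samples: Student's t with (count - 1) degrees of freedom.
      new TDistribution((sampleCount - 1).toInt).inverseCumulativeProbability(p)
    }
  }

  def main(args: Array[String]): Unit = {
    println(confFactor(0.95, 10000))  // ~1.96, the familiar 95% z value
    println(confFactor(0.95, 10))     // ~2.26, wider because t(9) has fatter tails
  }
}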
+ try { + inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength + context.taskMetrics.inputMetrics = Some(inputMetrics) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } } catch { case e: Exception => { if (!Utils.inShutdown()) { @@ -233,7 +265,7 @@ private[spark] class WholeTextFileRDD( case _ => } val jobContext = newJobContext(conf, jobId) - inputFormat.setMaxSplitSize(jobContext, minPartitions) + inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b7f125d01dfaf..c169b2d3fe97f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -43,7 +43,8 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, + SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -375,7 +376,8 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed) + new PartitionwiseSampledRDD[T, T]( + this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) }.toArray } diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index b097c30f8c231..9e8cee5331cf8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -21,8 +21,7 @@ import java.util.Random import scala.reflect.ClassTag -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.{Partition, TaskContext} @@ -53,9 +52,11 @@ private[spark] class SampledRDD[T: ClassTag]( if (withReplacement) { // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new Poisson(frac, new DRand(split.seed)) + val poisson = new PoissonDistribution(frac) + poisson.reseedRandomGenerator(split.seed) + firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.nextInt() + val count = poisson.sample() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f81fa6d8089fc..96114c0423a9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -124,6 +124,9 @@ class DAGScheduler( /** If enabled, we may run certain actions like take() and first() locally. 
*/ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ + private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -1050,7 +1053,7 @@ class DAGScheduler( logInfo("Resubmitted " + task + ", so marking it as still running") stage.pendingTasks += task - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) @@ -1060,11 +1063,13 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure")) + markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage)) runningStages -= failedStage } - if (failedStages.isEmpty && eventProcessActor != null) { + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty && eventProcessActor != null) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. @@ -1086,7 +1091,7 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } case ExceptionFailure(className, description, stackTrace, metrics) => @@ -1106,25 +1111,35 @@ class DAGScheduler( * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed + * occurred, in which case we presume all shuffle data related to this executor to be lost. + * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. 
*/ - private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { + private[scheduler] def handleExecutorLost( + execId: String, + fetchFailed: Boolean, + maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) - } - if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + + if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnExecutor(execId) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementEpoch() + } + clearCacheLocs() } - clearCacheLocs() } else { logDebug("Additional executor lost message for " + execId + "(epoch " + currentEpoch + ")") @@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId) + dagScheduler.handleExecutorLost(execId, fetchFailed = false) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 100c9ba9b7809..597dbc884913c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging { val SPARK_VERSION_PREFIX = "SPARK_VERSION_" val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort) + val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 54904bffdf10b..4e3d9de540783 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -215,7 +215,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => taskStatus += " STATUS=FETCHFAILED TID=" + 
taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 071568cdfb429..cc13f57a49b89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -102,6 +102,11 @@ private[spark] class Stage( } } + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { @@ -131,4 +136,9 @@ private[spark] class Stage( override def toString = "Stage " + id override def hashCode(): Int = id + + override def equals(other: Any): Boolean = other match { + case stage: Stage => stage != null && stage.id == id + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 11c19eeb6e42c..1f114a0207f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.Utils private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] -case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) + extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. 
*/ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 3f345ceeaaf7a..819b51e12ad8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => + val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => + if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + return + } + (directResult, serializedData.limit()) + case IndirectTaskResult(blockId, size) => + if (!taskSetManager.canFetchMoreResults(size)) { + // dropped by executor if size is larger than maxResultSize + sparkEnv.blockManager.master.removeBlock(blockId) + return + } logDebug("Fetching indirect task result for TID %s".format(tid)) scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) @@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult + (deserializedResult, size) } - result.metrics.resultSize = serializedData.limit() + + result.metrics.resultSize = size scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => @@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } catch { case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastropic if we can't + // Log an error but keep going here -- the task failed, so not catastrophic if we can't // deserialize the reason. val loader = Utils.getContextOrSparkClassLoader logError( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index a129a434c9a1a..f095915352b17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. - * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks * for a single SparkContext. These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running * them, retrying if there are failures, and mitigating stragglers. They return events to the @@ -41,7 +41,7 @@ private[spark] trait TaskScheduler { // Invoked after system has successfully initialized (typically in spark context). 
// Yarn uses this to bootstrap allocation of resources based on preferred locations, - // wait for slave registerations, etc. + // wait for slave registrations, etc. def postStartHook() { } // Disconnect from the cluster. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 2b39c7fc872da..cd3c015321e85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -34,7 +34,6 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId -import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a6c23fc85a1b0..d8fb640350343 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -23,13 +23,12 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min +import scala.math.{min, max} import org.apache.spark._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -68,6 +67,9 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) + // Limit of bytes for total size of results (default is 1GB) + val maxResultSize = Utils.getMaxResultSize(conf) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -89,6 +91,8 @@ private[spark] class TaskSetManager( var stageId = taskSet.stageId var name = "TaskSet_" + taskSet.stageId.toString var parent: Pool = null + var totalResultSize = 0L + var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] override def runningTasks = runningTasksSet.size @@ -515,12 +519,33 @@ private[spark] class TaskSetManager( index } + /** + * Marks the task as getting result and notifies the DAG Scheduler + */ def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) } + /** + * Check whether has enough quota to fetch the result with `size` bytes + */ + def canFetchMoreResults(size: Long): Boolean = synchronized { + totalResultSize += size + calculatedTasks += 1 + if (maxResultSize > 0 && totalResultSize > maxResultSize) { + val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " + + s"(${Utils.bytesToString(maxResultSize)})" + logError(msg) + abort(msg) + false + } else { + true + } + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. 
*/ @@ -687,10 +712,11 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, + // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask]) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -706,7 +732,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index fb8160abc59db..1da6fe976da5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,7 +66,19 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String) + // Exchanged between the driver and the AM in Yarn client mode + case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage + // Messages exchanged between the driver and the cluster manager for executor allocation + // In Yarn mode, these are exchanged between the driver and the AM + + case object RegisterClusterManager extends CoarseGrainedClusterMessage + + // Request executors by specifying the new total number of executors desired + // This includes executors already pending or running + case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 59aed6b72fe42..7a6ee56f81689 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -31,7 +31,6 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} -import org.apache.spark.ui.JettyUtils /** * A scheduler backend that waits for coarse grained 
executors to connect to it through Akka. @@ -42,7 +41,7 @@ import org.apache.spark.ui.JettyUtils * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -61,10 +60,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() + private val executorDataMap = new HashMap[String, ExecutorData] + + // Number of executors requested from the cluster manager that have not registered yet + private var numPendingExecutors = 0 + + // Executors we have requested the cluster manager to kill that have not died yet + private val executorsPendingToRemove = new HashSet[String] + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[Address, String] - private val executorDataMap = new HashMap[String, ExecutorData] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -84,12 +90,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredExecutor - executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address, - Utils.parseHostPort(hostPort)._1, cores, cores)) addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } makeOffers() } @@ -128,10 +143,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A removeExecutor(executorId, reason) sender ! true - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! 
true - case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -183,13 +194,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: String): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => - executorDataMap -= executorId + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap -= executorId + executorsPendingToRemove -= executorId + } totalCoreCount.addAndGet(-executorInfo.totalCores) scheduler.executorLost(executorId, SlaveLost(reason)) - case None => logError(s"Asked to remove non existant executor $executorId") + case None => logError(s"Asked to remove non-existent executor $executorId") } } } @@ -274,21 +290,62 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A false } - // Add filters to the SparkUI - def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) { - if (proxyBase != null && proxyBase.nonEmpty) { - System.setProperty("spark.ui.proxyBase", proxyBase) - } + /** + * Return the number of executors currently registered with this backend. + */ + def numExistingExecutors: Int = executorDataMap.size + + /** + * Request an additional number of executors from the cluster manager. + * Return whether the request is acknowledged. + */ + final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") + logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors + // Account for executors pending to be added or removed + val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size + doRequestTotalExecutors(newTotal) + } - val hasFilter = (filterName != null && filterName.nonEmpty && - filterParams != null && filterParams.nonEmpty) - if (hasFilter) { - logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) - filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } - scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + /** + * Request executors from the cluster manager by specifying the total number desired, + * including existing pending and running executors. + * + * The semantics here guarantee that we do not over-allocate executors for this application, + * since a later request overrides the value of any prior request. The alternative interface + * of requesting a delta of executors risks double counting new executors when there are + * insufficient resources to satisfy the first request. We make the assumption here that the + * cluster manager will eventually fulfill all requests when resources free up. + * + * Return whether the request is acknowledged. + */ + protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false + + /** + * Request that the cluster manager kill the specified executors. + * Return whether the kill request is acknowledged. 
+ */ + final def killExecutors(executorIds: Seq[String]): Boolean = { + logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") + val filteredExecutorIds = new ArrayBuffer[String] + executorIds.foreach { id => + if (executorDataMap.contains(id)) { + filteredExecutorIds += id + } else { + logWarning(s"Executor to kill $id does not exist!") + } } + executorsPendingToRemove ++= filteredExecutorIds + doKillExecutors(filteredExecutorIds) } + + /** + * Kill the given list of executors through the cluster manager. + * Return whether the kill request is acknowledged. + */ + protected def doKillExecutors(executorIds: Seq[String]): Boolean = false + } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala new file mode 100644 index 0000000000000..50721b9d6cd6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import akka.actor.{Actor, ActorRef, Props} +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.ui.JettyUtils +import org.apache.spark.util.AkkaUtils + +/** + * Abstract Yarn scheduler backend that contains common logic + * between the client and cluster Yarn scheduler backends. + */ +private[spark] abstract class YarnSchedulerBackend( + scheduler: TaskSchedulerImpl, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + minRegisteredRatio = 0.8 + } + + protected var totalExpectedExecutors = 0 + + private val yarnSchedulerActor: ActorRef = + actorSystem.actorOf( + Props(new YarnSchedulerActor), + name = YarnSchedulerBackend.ACTOR_NAME) + + private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + + /** + * Request executors from the ApplicationMaster by specifying the total number desired. + * This includes executors already pending or running. + */ + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + AkkaUtils.askWithReply[Boolean]( + RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + } + + /** + * Request that the ApplicationMaster kill the specified executors. 
+ */ + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + AkkaUtils.askWithReply[Boolean]( + KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } + + /** + * Add filters to the SparkUI. + */ + private def addWebUIFilter( + filterName: String, + filterParams: Map[String, String], + proxyBase: String): Unit = { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + val hasFilter = + filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty + if (hasFilter) { + logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + } + } + + /** + * An actor that communicates with the ApplicationMaster. + */ + private class YarnSchedulerActor extends Actor { + private var amActor: Option[ActorRef] = None + + override def preStart(): Unit = { + // Listen for disassociation events + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive = { + case RegisterClusterManager => + logInfo(s"ApplicationMaster registered as $sender") + amActor = Some(sender) + + case r: RequestExecutors => + amActor match { + case Some(actor) => + sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + case None => + logWarning("Attempted to request executors before the AM has registered!") + sender ! false + } + + case k: KillExecutors => + amActor match { + case Some(actor) => + sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + case None => + logWarning("Attempted to kill executors before the AM has registered!") + sender ! false + } + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + sender ! 
true + + case d: DisassociatedEvent => + if (amActor.isDefined && sender == amActor.get) { + logWarning(s"ApplicationMaster has disassociated: $d") + } + } + } +} + +private[spark] object YarnSchedulerBackend { + val ACTOR_NAME = "YarnScheduler" +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d7f88de4b40aa..d8c0e2f66df01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -31,6 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -120,16 +121,18 @@ private[spark] class CoarseMesosSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions") + val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") - val libraryPathOption = "spark.executor.extraLibraryPath" - val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p") - val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + // Set the environment variable through a command prefix + // to append to the existing value of the variable + val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p => + Utils.libraryPathEnvPrefix(Seq(p)) + }.getOrElse("") environment.addVariables( Environment.Variable.newBuilder() .setName("SPARK_EXECUTOR_OPTS") - .setValue(extraOpts) + .setValue(extraJavaOpts) .build()) sc.executorEnvs.foreach { case (key, value) => @@ -150,16 +153,17 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( + prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, + offer.getHostname, numCores, appId)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
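The "grab everything to the first '.'" comment above is worth a worked example: the Mesos backends only know the URI of the Spark archive, not the name of the directory it unpacks into, so they take the file name up to the first dot and rely on a shell glob to cd into the extracted directory. The URI below is made up.

object BasenameGlobSketch {
  def main(args: Array[String]): Unit = {
    val uri = "hdfs://namenode:8020/dist/spark-1.2.0-bin-hadoop2.4.tgz"   // hypothetical executor URI
    // Same expression as in the scheduler backends: file name, then everything before the first '.'.
    val basename = uri.split('/').last.split('.').head
    println(basename)                                    // spark-1
    println(s"cd $basename*; ./bin/spark-class ...")     // the glob matches spark-1.2.0-bin-hadoop2.4/
  }
}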
val basename = uri.split('/').last.split('.').head command.setValue( - ("cd %s*; " + + ("cd %s*; %s " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") - .format(basename, driverUrl, offer.getSlaveId.getValue, + .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index e0f2fd622f54c..8e2faff90f9b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -98,15 +98,16 @@ private[spark] class MesosSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") - val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp => - s"-Djava.library.path=$lp" - } - val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + + val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => + Utils.libraryPathEnvPrefix(Seq(p)) + }.getOrElse("") + environment.addVariables( Environment.Variable.newBuilder() .setName("SPARK_EXECUTOR_OPTS") - .setValue(extraOpts) + .setValue(extraJavaOpts) .build()) sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() @@ -118,12 +119,13 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) val uri = sc.conf.get("spark.executor.uri", null) if (uri == null) { - command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath) + val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath + command.setValue("%s %s".format(prefixEnv, executorPath)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
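Stepping back to the executor-allocation hooks added to CoarseGrainedSchedulerBackend earlier in this section: requestExecutors deliberately converts a delta into a new desired total, so that a later request simply overrides an earlier one and a slow cluster manager can never double count pending executors. A toy model of that bookkeeping (all state values invented; only the newTotal arithmetic is mirrored from requestExecutors):

object ExecutorRequestSketch {
  var numExistingExecutors = 4                      // registered with the backend
  var numPendingExecutors = 0                       // requested but not yet registered
  var executorsPendingToRemove = Set.empty[String]  // kill requested but not yet dead

  // Mirrors requestExecutors: fold a delta into the new desired total.
  def requestExecutors(numAdditional: Int): Int = {
    numPendingExecutors += numAdditional
    val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
    newTotal                                        // this total is what gets sent to the cluster manager
  }

  def main(args: Array[String]): Unit = {
    println(requestExecutors(2))                    // 6: ask for a total of 6, not "+2"
    executorsPendingToRemove += "exec-3"
    println(requestExecutors(1))                    // 6: one more wanted, but one is being killed
    // If the first message is still in flight, the second simply supersedes it,
    // so the executors that are still pending are never counted twice.
  }
}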
val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./sbin/spark-executor".format(basename)) + command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } val cpus = Resource.newBuilder() diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 58b78f041cd85..c0264836de738 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import akka.actor.{Actor, ActorRef, Props} -import org.apache.spark.{Logging, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} @@ -47,7 +47,7 @@ private[spark] class LocalActor( private var freeCores = totalCores - private val localExecutorId = "localhost" + private val localExecutorId = SparkContext.DRIVER_IDENTIFIER private val localExecutorHostname = "localhost" val executor = new Executor( diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index a9144cdd97b8c..ca6e971d227fb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -17,14 +17,14 @@ package org.apache.spark.serializer -import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 71c08e9d5a8c3..0c1b6f4defdb3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.util.Utils /** * Failed to fetch a shuffle block. 
The executor catches this exception and propagates it @@ -30,13 +31,11 @@ private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, - reduceId: Int) - extends Exception { - - override def getMessage: String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + reduceId: Int, + message: String) + extends Exception(message) { - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) } /** @@ -46,7 +45,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId) { - - override def getMessage: String = message -} + extends FetchFailedException(null, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 439981d232349..f03e8e4bf1b7e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -24,9 +24,9 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.{SparkEnv, SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -62,7 +62,8 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ - +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 4ab34336d3f01..a48f0c9eceb5e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -20,8 +20,10 @@ package org.apache.spark.shuffle import java.io._ import java.nio.ByteBuffer +import com.google.common.io.ByteStreams + import org.apache.spark.SparkEnv -import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.storage._ /** @@ -33,6 +35,8 @@ import org.apache.spark.storage._ * as the filename postfix for data file, and ".index" as the filename postfix for index file. * */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). 
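The index lookup changed just below (in.skip becomes Guava's ByteStreams.skipFully, since InputStream.skip is allowed to skip fewer bytes than requested) is small enough to reproduce in isolation. The sketch writes a toy index of cumulative offsets and then locates one reducer's segment the same way; the file name and offsets are invented.

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}
import com.google.common.io.ByteStreams

object ShuffleIndexSketch {
  def main(args: Array[String]): Unit = {
    val indexFile = File.createTempFile("shuffle_0_0_0", ".index")
    indexFile.deleteOnExit()

    // The index holds (numReduces + 1) longs: cumulative offsets into the single data file.
    val offsets = Array(0L, 100L, 250L, 400L)
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    offsets.foreach(o => out.writeLong(o))
    out.close()

    // Look up reducer 1's segment: skip exactly reduceId * 8 bytes, then read two offsets.
    val reduceId = 1
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      ByteStreams.skipFully(in, reduceId * 8)   // loops internally until all bytes are skipped
      val offset = in.readLong()
      val nextOffset = in.readLong()
      println(s"reducer $reduceId data: bytes [$offset, $nextOffset) of the data file")  // [100, 250)
    } finally {
      in.close()
    }
  }
}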
private[spark] class IndexShuffleBlockManager extends ShuffleBlockManager { @@ -101,7 +105,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { val in = new DataInputStream(new FileInputStream(indexFile)) try { - in.skip(blockId.reduceId * 8) + ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() new FileSegmentManagedBuffer( diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 63863cc0250a3..b521f0c7fc77e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer - -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala index b30e366d06006..292e48314ee10 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -24,6 +24,10 @@ private[spark] trait ShuffleReader[K, C] { /** Read the combined key-values for this reduce task */ def read(): Iterator[Product2[K, C]] - /** Close this reader */ - def stop(): Unit + /** + * Close this reader. + * TODO: Add this back when we make the ShuffleReader a developer API that others can implement + * (at which point this will likely be necessary). + */ + // def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 6cf9305977a3c..0d5247f4176d4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -19,12 +19,13 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.{CompletionIterator, Utils} private[hash] object BlockStoreShuffleFetcher extends Logging { def fetch[T]( @@ -52,21 +53,22 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Some(block) => { + case Success(block) => { block.asInstanceOf[Iterator[T]] } - case None => { + case Failure(e) => { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, + Utils.exceptionString(e)) case _ => throw new 
SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") + "Failed to get block " + blockId + ", which is not a shuffle block", e) } } } @@ -74,7 +76,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockTransferService, + SparkEnv.get.blockManager.shuffleClient, blockManager, blocksByAddress, serializer, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 88a5f1e5ddf58..5baf45db45c17 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -66,7 +66,4 @@ private[spark] class HashShuffleReader[K, C]( aggregatedIter } } - - /** Close this reader */ - override def stop(): Unit = ??? } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 746ed33b54c00..183a30373b28c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } - MapStatus(blockManager.blockManagerId, sizes) + MapStatus(blockManager.shuffleServerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 927481b72cf4f..d75f9d7311fad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 8df5ec6bde184..1f012941c85ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { def name = "rdd_" + rddId + "_" + splitIndex } +// Format of the shuffle block ids (including data and index) should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). 
@DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 4cc97923658bc..5f5dd0dc1c63f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,14 +17,12 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} +import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.concurrent.ExecutionContext.Implicits.global - -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random @@ -35,11 +33,16 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -87,9 +90,38 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } + private[spark] + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337) + // Check that we're not using external shuffle service with consolidated shuffle files. + if (externalShuffleServiceEnabled + && conf.getBoolean("spark.shuffle.consolidateFiles", false) + && shuffleManager.isInstanceOf[HashShuffleManager]) { + throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" + + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " + + " switch to sort-based shuffle.") + } + val blockManagerId = BlockManagerId( executorId, blockTransferService.hostName, blockTransferService.port) + // Address of the server that serves this executor's shuffle files. This is either an external + // service, or just our own Executor's BlockManager. + private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + + // Client to read other executors' shuffle files. This is either an external service, or just the + // standard BlockTranserService to directly connect to other Executors. 
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { + val appId = conf.get("spark.app.id", "unknown-app-id") + new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId) + } else { + blockTransferService + } + // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored @@ -145,10 +177,41 @@ private[spark] class BlockManager( /** * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. + * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured. */ private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + + // Register Executors' configuration with the local shuffle service, if one should exist. + if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + registerWithExternalShuffleServer() + } + } + + private def registerWithExternalShuffleServer() { + logInfo("Registering executor with local external shuffle service.") + val shuffleConfig = new ExecutorShuffleInfo( + diskBlockManager.localDirs.map(_.toString), + diskBlockManager.subDirsPerLocalDir, + shuffleManager.getClass.getName) + + val MAX_ATTEMPTS = 3 + val SLEEP_TIME_SECS = 5 + + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. + shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer( + shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + val attemptsRemaining = + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000) + } + } } /** @@ -212,21 +275,20 @@ private[spark] class BlockManager( } /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + override def getBlockData(blockId: BlockId): ManagedBuffer = { + if (blockId.isShuffle) { + shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) + .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId.toString) } } } @@ -234,8 +296,8 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. 
*/ - override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.nioByteBuffer(), level) + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(blockId, data.nioByteBuffer(), level) } /** @@ -340,17 +402,6 @@ private[spark] class BlockManager( locations } - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) - val is = wrapForCompression(blockId, buf.inputStream()) - Some(serializer.newInstance().deserializeStream(is).asIterator) - } - /** * Get block from local block manager. */ @@ -520,7 +571,7 @@ private[spark] class BlockManager( for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { @@ -869,9 +920,9 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) - logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %d ms" - .format((System.currentTimeMillis - onePeerStartTime))) + peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" + .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer peersForReplication -= peer replicationFailed = false @@ -1126,7 +1177,11 @@ private[spark] class BlockManager( } def stop(): Unit = { - blockTransferService.stop() + blockTransferService.close() + if (shuffleClient ne blockTransferService) { + // Closing should be idempotent, but maybe not for the NioBlockTransferService. 
+ shuffleClient.close() + } diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 259f423c73e6b..b177a59c721df 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -59,7 +60,7 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = (executorId == "") + def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 5e375a2553979..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. 
+ */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 291ddfcc113ac..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala index 9ef453605f4f1..81f5f2d31dbd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -17,5 +17,4 @@ package org.apache.spark.storage - class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 6633a1db57e59..58fba54710510 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private[spark] + val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. */ - val localDirs: Array[File] = createLocalDirs(conf) + private[spark] val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) @@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon addShutdownHook() + /** Looks up a file by hashing it into one of our local subdirectories. */ + // This method should be kept in sync with + // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). 
def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -149,7 +153,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run(): Unit = Utils.logUncaughtExceptions { logDebug("Shutdown hook called") @@ -160,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { - localDirs.foreach { localDir => - if (localDir.isDirectory() && localDir.exists()) { - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case e: Exception => - logError(s"Exception while deleting local spark dir: $localDir", e) + // Only perform cleanup if an external service is not serving our shuffle files. + if (!blockManager.externalShuffleServiceEnabled) { + localDirs.foreach { localDir => + if (localDir.isDirectory() && localDir.exists()) { + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case e: Exception => + logError(s"Exception while deleting local spark dir: $localDir", e) + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index bac459e835a3f..8dadf6794039e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{File, FileOutputStream, RandomAccessFile} +import java.io.{IOException, File, FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode @@ -110,7 +110,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) - channel.read(buf, offset) + channel.position(offset) + while (buf.remaining() != 0) { + if (channel.read(buf) == -1) { + throw new IOException("Reached EOF before filling buffer\n" + + s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + } + } buf.flip() Some(buf) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index edbc729c17ade..71305a46bf570 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -56,6 +56,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) (maxMemory * unrollFraction).toLong } + // Initial memory to request before unrolling any block + private val unrollMemoryThreshold: Long = + conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + + if (maxMemory < unrollMemoryThreshold) { + logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + + s"memory. 
Please configure Spark with more memory.") + } + logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */ @@ -213,7 +223,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. - val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 // Memory currently reserved by this thread for this particular unrolling operation @@ -228,6 +238,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Request enough memory to begin unrolling keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + if (!keepUnrolling) { + logWarning(s"Failed to reserve initial memory threshold of " + + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } + // Unroll this block safely, checking whether we have exceeded our threshold periodically try { while (values.hasNext && keepUnrolling) { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 71b276b5f18e4..1e579187e4193 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,15 +19,15 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.util.{Failure, Success, Try} -import org.apache.spark.{TaskContext, Logging} -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - +import org.apache.spark.util.{CompletionIterator, Utils} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -40,8 +40,8 @@ import org.apache.spark.util.Utils * using too much memory. * * @param context [[TaskContext]], used for metrics update - * @param blockTransferService [[BlockTransferService]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. 
@@ -51,12 +51,12 @@ import org.apache.spark.util.Utils private[spark] final class ShuffleBlockFetcherIterator( context: TaskContext, - blockTransferService: BlockTransferService, + shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { import ShuffleBlockFetcherIterator._ @@ -88,17 +88,53 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + currentResult match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => + } + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -108,26 +144,26 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val blockIds = req.blocks.map(_._1.toString) - blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + val address = req.address + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += data.size - shuffleMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. 
+ if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(e: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - // Note that there is a chance that some blocks have been fetched successfully, but we - // still add them to the failed queue. This is fine because when the caller see a - // FetchFailedException, it is going to fail the entire task anyway. - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } + results.put(new FailureFetchResult(BlockId(blockId), e)) } } ) @@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -148,7 +184,7 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - if (address == blockManager.blockManagerId) { + if (address.executorId == blockManager.blockManagerId.executorId) { // Filter out zero-sized blocks localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) numBlocksToFetch += localBlocks.size @@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FailureFetchResult(blockId, e)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. 
+ context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -226,21 +270,39 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Option[Iterator[Any]]) = { + override def next(): (BlockId, Try[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (!result.failed) { - bytesInFlight -= result.size + + result match { + case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case _ => } // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorTry: Try[Iterator[Any]] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => { + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Success(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + })) + } + } + + (result.blockId, iteratorTry) } } @@ -254,18 +316,35 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } /** - * Result of a fetch from a remote block. A failure is represented as size == -1. + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * @param buf [[ManagedBuffer]] for the content. */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 + private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + extends FetchResult { + require(buf != null) + require(size >= 0) } + + /** + * Result of a fetch from a remote block unsuccessfully. 
+ * @param blockId block id + * @param e the failure exception + */ + private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index d9066f766476e..def49e80a3605 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import scala.collection.mutable +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { - val execId = formatExecutorId(info.executorId) val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) if (updatedBlocks.length > 0) { - updateStorageStatus(execId, updatedBlocks) + updateStorageStatus(info.executorId, updatedBlocks) } } } @@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener { } } - /** - * In the local mode, there is a discrepancy between the executor ID according to the - * task ("localhost") and that according to SparkEnv (""). In the UI, this - * results in duplicate rows for the same executor. Thus, in this mode, we aggregate - * these two rows and use the executor ID of "" to be consistent. - */ - def formatExecutorId(execId: String): String = { - if (execId == "localhost") "" else execId - } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index 932b5616043b4..6dbad5ff0518e 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.IOException import java.nio.ByteBuffer +import com.google.common.io.ByteStreams import tachyon.client.{ReadType, WriteType} import org.apache.spark.Logging @@ -105,25 +106,17 @@ private[spark] class TachyonStore( return None } val is = file.getInStream(ReadType.CACHE) - var buffer: ByteBuffer = null + assert (is != null) try { - if (is != null) { - val size = file.length - val bs = new Array[Byte](size.asInstanceOf[Int]) - val fetchSize = is.read(bs, 0, size.asInstanceOf[Int]) - buffer = ByteBuffer.wrap(bs) - if (fetchSize != size) { - logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " + - s"is not equal to fetched size $fetchSize") - return None - } - } + val size = file.length + val bs = new Array[Byte](size.asInstanceOf[Int]) + ByteStreams.readFully(is, bs) + Some(ByteBuffer.wrap(bs)) } catch { case ioe: IOException => logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) - return None + None } - Some(buffer) } override def contains(blockId: BlockId): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cccd59d122a92..049938f827291 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -21,47 +21,30 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import 
org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.env.EnvironmentTab -import org.apache.spark.ui.exec.ExecutorsTab -import org.apache.spark.ui.jobs.JobProgressTab -import org.apache.spark.ui.storage.StorageTab +import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} +import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} +import org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab} +import org.apache.spark.ui.storage.{StorageListener, StorageTab} /** * Top level user interface for a Spark application. */ -private[spark] class SparkUI( - val sc: SparkContext, +private[spark] class SparkUI private ( + val sc: Option[SparkContext], val conf: SparkConf, val securityManager: SecurityManager, - val listenerBus: SparkListenerBus, + val environmentListener: EnvironmentListener, + val storageStatusListener: StorageStatusListener, + val executorsListener: ExecutorsListener, + val jobProgressListener: JobProgressListener, + val storageListener: StorageListener, var appName: String, - val basePath: String = "") + val basePath: String) extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { - def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName) - def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) = - this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath) - - def this( - conf: SparkConf, - securityManager: SecurityManager, - listenerBus: SparkListenerBus, - appName: String, - basePath: String) = - this(null, conf, securityManager, listenerBus, appName, basePath) - - // If SparkContext is not provided, assume the associated application is not live - val live = sc != null - - // Maintain executor storage status through Spark events - val storageStatusListener = new StorageStatusListener - - initialize() - /** Initialize all components of the server. */ def initialize() { - listenerBus.addListener(storageStatusListener) val jobProgressTab = new JobProgressTab(this) attachTab(jobProgressTab) attachTab(new StorageTab(this)) @@ -71,10 +54,10 @@ private[spark] class SparkUI( attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) attachHandler( createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) - if (live) { - sc.env.metricsSystem.getServletHandlers.foreach(attachHandler) - } + // If the UI is live, then serve + sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } + initialize() def getAppName = appName @@ -83,11 +66,6 @@ private[spark] class SparkUI( appName = name } - /** Register the given listener with the listener bus. */ - def registerListener(listener: SparkListener) { - listenerBus.addListener(listener) - } - /** Stop the server behind this web interface. Only valid after bind(). 
*/ override def stop() { super.stop() @@ -116,4 +94,60 @@ private[spark] object SparkUI { def getUIPort(conf: SparkConf): Int = { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) } + + def createLiveUI( + sc: SparkContext, + conf: SparkConf, + listenerBus: SparkListenerBus, + jobProgressListener: JobProgressListener, + securityManager: SecurityManager, + appName: String): SparkUI = { + create(Some(sc), conf, listenerBus, securityManager, appName, + jobProgressListener = Some(jobProgressListener)) + } + + def createHistoryUI( + conf: SparkConf, + listenerBus: SparkListenerBus, + securityManager: SecurityManager, + appName: String, + basePath: String): SparkUI = { + create(None, conf, listenerBus, securityManager, appName, basePath) + } + + /** + * Create a new Spark UI. + * + * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs. + * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the + * web UI will create and register its own JobProgressListener. + */ + private def create( + sc: Option[SparkContext], + conf: SparkConf, + listenerBus: SparkListenerBus, + securityManager: SecurityManager, + appName: String, + basePath: String = "", + jobProgressListener: Option[JobProgressListener] = None): SparkUI = { + + val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { + val listener = new JobProgressListener(conf) + listenerBus.addListener(listener) + listener + } + + val environmentListener = new EnvironmentListener + val storageStatusListener = new StorageStatusListener + val executorsListener = new ExecutorsListener(storageStatusListener) + val storageListener = new StorageListener(storageStatusListener) + + listenerBus.addListener(environmentListener) + listenerBus.addListener(storageStatusListener) + listenerBus.addListener(executorsListener) + listenerBus.addListener(storageListener) + + new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, + executorsListener, _jobProgressListener, storageListener, appName, basePath) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 9ced9b8107ebf..f02904df31fcf 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -31,4 +31,16 @@ private[spark] object ToolTips { val SHUFFLE_READ = """Bytes read from remote executors. Typically less than shuffle write bytes because this does not include shuffle data read locally.""" + + val GETTING_RESULT_TIME = + """Time that the driver spends fetching task results from workers. 
If this is large, consider + decreasing the amount of data returned from each task.""" + + val RESULT_SERIALIZATION_TIME = + """Time spent serializing the task result on the executor before sending it back to the + driver.""" + + val GC_TIME = + """Time that the executor spent paused for Java garbage collection while the task was + running.""" } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 32e6b15bb0999..3312671b6f885 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,13 +20,13 @@ package org.apache.spark.ui import java.text.SimpleDateFormat import java.util.{Locale, Date} -import scala.xml.Node +import scala.xml.{Node, Text} import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable" + val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -160,6 +160,8 @@ private[spark] object UIUtils extends Logging { + + } /** Returns a spark page with correctly formatted headers */ @@ -239,7 +241,9 @@ private[spark] object UIUtils extends Logging { headers: Seq[String], generateDataRow: T => Seq[Node], data: Iterable[T], - fixedWidth: Boolean = false): Seq[Node] = { + fixedWidth: Boolean = false, + id: Option[String] = None, + headerClasses: Seq[String] = Seq.empty): Seq[Node] = { var listingTableClass = TABLE_CLASS if (fixedWidth) { @@ -247,23 +251,32 @@ private[spark] object UIUtils extends Logging { } val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" - val headerRow: Seq[Node] = { - // if none of the headers have "\n" in them - if (headers.forall(!_.contains("\n"))) { - // represent header as simple text - headers.map(h => {h}) + + def getClass(index: Int): String = { + if (index < headerClasses.size) { + headerClasses(index) } else { - // represent header text as list while respecting "\n" - headers.map { case h => - -

-          <th width={colWidthAttr}>
-            <ul class="unstyled">
-              { h.split("\n").map { case t => <li> {t} </li> } }
-            </ul>
-          </th>
-        }
+        ""
+      }
+    }
+
+    val newlinesInHeader = headers.exists(_.contains("\n"))
+    def getHeaderContent(header: String): Seq[Node] = {
+      if (newlinesInHeader) {
+        <ul class="unstyled">
+          { header.split("\n").map { case t => <li> {t} </li> } }
+        </ul>
+      } else {
+        Text(header)
+      }
+    }
+
+    val headerRow: Seq[Node] = {
+      headers.view.zipWithIndex.map { x =>
+        <th width={colWidthAttr} class={getClass(x._2)}>{getHeaderContent(x._1)}</th>
+      }
+    }
-    <table class={listingTableClass}>
+    <table class={listingTableClass} id={id.map(Text.apply)}>
{headerRow} {data.map(r => generateDataRow(r))} diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5d88ca403a674..9be65a4a39a09 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -82,7 +82,7 @@ private[spark] abstract class WebUI( } /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + protected def detachHandler(handler: ServletContextHandler) { handlers -= handler serverInfo.foreach { info => info.rootHandler.removeHandler(handler) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 0d158fbe638d3..f62260c6f6e1d 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -22,10 +22,8 @@ import org.apache.spark.scheduler._ import org.apache.spark.ui._ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { - val listener = new EnvironmentListener - + val listener = parent.environmentListener attachPage(new EnvironmentPage(this)) - parent.registerListener(listener) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..e9c755e36f716 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => + + } + +
+

Updated at {UIUtils.formatDate(time)}

+ { + // scalastyle:off +

+ Expand All +

+

+ // scalastyle:on + } +
{dumpRows}
+
+ }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..048fee3ce1ff4 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -41,7 +41,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -75,6 +78,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +137,15 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { + { + if (threadDumpEnabled) { + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 61eb111cd9100..ba97630f025c1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -26,10 +26,15 @@ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - val listener = new ExecutorsListener(parent.storageStatusListener) + val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) - parent.registerListener(listener) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** @@ -49,14 +54,14 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp def storageStatusList = storageStatusListener.storageStatusList override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val eid = formatExecutorId(taskStart.taskInfo.executorId) + val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo if (info != null) { - val eid = formatExecutorId(info.executorId) + val eid = info.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { @@ -85,6 +90,4 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp } } - // This addresses executor ID inconsistencies in the local mode - private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index eaeb861f59e5a..e3223403c17f4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -40,17 +40,32 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ + type JobId = Int + type StageId = Int + type StageAttemptId = Int + // How many stages to remember val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) + // How many jobs to remember + val retailedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) - // Map from stageId to StageInfo - val activeStages = new HashMap[Int, StageInfo] - - // Map from (stageId, attemptId) to StageUIData - val stageIdToData = new HashMap[(Int, Int), StageUIData] + val activeJobs = new HashMap[JobId, JobUIData] + val completedJobs = ListBuffer[JobUIData]() + val failedJobs = ListBuffer[JobUIData]() + val jobIdToData = new HashMap[JobId, JobUIData] + val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() + val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] + val stageIdToInfo = new HashMap[StageId, StageInfo] + + // Number of completed and failed stages, may not actually equal to completedStages.size and + // failedStages.size respectively due to completedStage and failedStages only maintain the latest + // part of the stages, the earlier ones will be removed when there are too many stages for + // memory sake. + var numCompletedStages = 0 + var numFailedStages = 0 // Map from pool name to a hash map (map from stage id to StageInfo). val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() @@ -61,8 +76,32 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { def blockManagerIds = executorIdToBlockManagerId.values.toSeq + override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + val jobGroup = Option(jobStart.properties).map(_.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + val jobData: JobUIData = + new JobUIData(jobStart.jobId, jobStart.stageIds, jobGroup, JobExecutionStatus.RUNNING) + jobIdToData(jobStart.jobId) = jobData + activeJobs(jobStart.jobId) = jobData + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { + logWarning(s"Job completed for unknown job ${jobEnd.jobId}") + new JobUIData(jobId = jobEnd.jobId) + } + jobEnd.jobResult match { + case JobSucceeded => + completedJobs += jobData + jobData.status = JobExecutionStatus.SUCCEEDED + case JobFailed(exception) => + failedJobs += jobData + jobData.status = JobExecutionStatus.FAILED + } + } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo + stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { logWarning("Stage completed for unknown stage " + stage.stageId) new StageUIData @@ -78,9 +117,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage + numCompletedStages += 1 trimIfNecessary(completedStages) } else { failedStages += stage + numFailedStages += 1 trimIfNecessary(failedStages) } } @@ -89,7 +130,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { val 
toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) } + stages.take(toRemove).foreach { s => + stageIdToData.remove((s.stageId, s.attemptId)) + stageIdToInfo.remove(s.stageId) + } stages.trimStart(toRemove) } } @@ -103,6 +147,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) + stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData) stageData.schedulingPool = poolName @@ -277,4 +322,5 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { private object JobProgressListener { val DEFAULT_POOL_NAME = "default" val DEFAULT_RETAINED_STAGES = 1000 + val DEFAULT_RETAINED_JOBS = 1000 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 1e02f1225d344..83a7898071c9b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -26,7 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { - private val live = parent.live private val sc = parent.sc private val listener = parent.listener private def isFairScheduler = parent.isFairScheduler @@ -35,7 +34,9 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") listener.synchronized { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages val now = System.currentTimeMillis val activeStagesTable = @@ -47,17 +48,17 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent) // For now, pool information is only accessible in live UIs - val pools = if (live) sc.getAllPools else Seq[Schedulable]() + val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) val poolTable = new PoolTable(pools, parent) val summary: NodeSeq =
-     {if (live) {
+     {if (sc.isDefined) {
        // Total duration is not meaningful unless the UI is live
        Total Duration:
-       {UIUtils.formatDuration(now - sc.startTime)}
+       {UIUtils.formatDuration(now - sc.get.startTime)}
      }}
@@ -70,26 +71,26 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
        Completed Stages:
-       {completedStages.size}
+       {numCompletedStages}
        Failed Stages:
-       {failedStages.size}
+       {numFailedStages}
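The switch from completedStages.size to numCompletedStages matters because the listener trims its stage buffers once they grow past spark.ui.retainedStages, so the buffer size understates the true totals. A minimal, stand-alone sketch of that bounded-buffer-plus-counter pattern (illustrative only, not part of the patch):

    import scala.collection.mutable.ListBuffer

    // Keep at most `retained` recent entries, but count every entry ever added.
    class BoundedHistory[T](retained: Int) {
      val buffer = ListBuffer[T]()
      var totalAdded = 0

      def add(item: T): Unit = {
        buffer += item
        totalAdded += 1
        if (buffer.size > retained) {
          val toRemove = math.max(retained / 10, 1)
          buffer.trimStart(toRemove) // drop the oldest tenth, as trimIfNecessary does
        }
      }
    }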
    val content = summary ++
-     {if (live && isFairScheduler) {
+     {if (sc.isDefined && isFairScheduler) {
        {pools.size} Fair Scheduler Pools
        ++ poolTable.toNodeSeq
      } else {
        Seq[Node]()
      }} ++
      Active Stages ({activeStages.size})
      ++ activeStagesTable.toNodeSeq ++
-     Completed Stages ({completedStages.size})
-     ++
+     Completed Stages ({numCompletedStages})
+     ++
      completedStagesTable.toNodeSeq ++
-     Failed Stages ({failedStages.size})
-     ++
+     Failed Stages ({numFailedStages})
+     ++
++ failedStagesTable.toNodeSeq UIUtils.headerSparkPage("Spark Stages", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index c16542c9db30f..03ca918e2e8b3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -25,16 +25,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab} /** Web UI showing progress status of all jobs in the given SparkContext. */ private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { - val live = parent.live val sc = parent.sc - val conf = if (live) sc.conf else new SparkConf - val killEnabled = conf.getBoolean("spark.ui.killEnabled", true) - val listener = new JobProgressListener(conf) + val conf = sc.map(_.conf).getOrElse(new SparkConf) + val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) + val listener = parent.jobProgressListener attachPage(new JobProgressPage(this)) attachPage(new StagePage(this)) attachPage(new PoolPage(this)) - parent.registerListener(listener) def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) @@ -43,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "st val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { - sc.cancelStage(stageId) + sc.get.cancelStage(stageId) } // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 7a6c7d1a497ed..770d99eea1c9d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -26,7 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { - private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -42,7 +41,7 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent) // For now, pool information is only accessible in live UIs - val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]() + val pools = sc.map(_.getPoolForName(poolName).get).toSeq val poolTable = new PoolTable(pools, parent) val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2414e4c65237e..7cc03b7d333df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,10 +22,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} -import org.apache.spark.scheduler.AccumulableInfo +import 
org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { @@ -52,12 +53,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables + val hasAccumulators = accumulables.size > 0 val hasInput = stageData.inputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 - // scalastyle:off val summary =
@@ -65,55 +66,105 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
        Total task time across all tasks:
        {UIUtils.formatDuration(stageData.executorRunTime)}
-       {if (hasInput)
+       {if (hasInput) {
          Input: {Utils.bytesToString(stageData.inputBytes)}
-       }
-       {if (hasShuffleRead)
+       }}
+       {if (hasShuffleRead) {
          Shuffle read: {Utils.bytesToString(stageData.shuffleReadBytes)}
-       }
-       {if (hasShuffleWrite)
+       }}
+       {if (hasShuffleWrite) {
          Shuffle write: {Utils.bytesToString(stageData.shuffleWriteBytes)}
-       }
-       {if (hasBytesSpilled)
-         Shuffle spill (memory):
-         {Utils.bytesToString(stageData.memoryBytesSpilled)}
-         Shuffle spill (disk):
-         {Utils.bytesToString(stageData.diskBytesSpilled)}
-       }
+       }}
+       {if (hasBytesSpilled) {
+         Shuffle spill (memory):
+         {Utils.bytesToString(stageData.memoryBytesSpilled)}
+         Shuffle spill (disk):
+         {Utils.bytesToString(stageData.diskBytesSpilled)}
+       }}
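As a rough, self-contained illustration of the conditional rendering above (placeholder values stand in for the stageData fields; none of this is part of the patch), a summary row is emitted only when its metric was actually reported for the stage:

    // Emit a row only when the corresponding metric is present.
    def optionalRow(show: Boolean, label: String, value: String): Seq[String] =
      if (show) Seq(s"$label: $value") else Seq.empty

    val hasInput = true
    val hasShuffleRead = false
    val summaryRows =
      optionalRow(hasInput, "Input", "64.0 MB") ++
      optionalRow(hasShuffleRead, "Shuffle read", "0.0 B")
    // summaryRows contains only "Input: 64.0 MB"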
-     // scalastyle:on
+     val showAdditionalMetrics =
+       Show additional metrics
+ val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo) = val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) - val taskHeaders: Seq[String] = + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", - "Launch Time", "Duration", "GC Time", "Accumulators") ++ - {if (hasInput) Seq("Input") else Nil} ++ - {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ - {if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++ - Seq("Errors") + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("GC Time", TaskDetailsClassNames.GC_TIME), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input", "")) else Nil} ++ + {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ + {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + else Nil} ++ + Seq(("Errors", "")) + + val unzipped = taskHeadersAndCssClasses.unzip val taskTable = UIUtils.listingTable( - taskHeaders, taskRow(hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), tasks) - + unzipped._1, + taskRow(hasAccumulators, hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), + tasks, + headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -122,18 +173,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { None } else { - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { + Distribution(times).get.getQuantiles().map { millis => + + } } - val serializationQuantiles = - +: Distribution(serializationTimes). 
- get.getQuantiles().map(ms => ) val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } - val serviceQuantiles = +: Distribution(serviceTimes).get.getQuantiles() - .map(ms => ) + val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.jvmGCTime.toDouble + } + val gcQuantiles = + +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.resultSerializationTime.toDouble + } + val serializationQuantiles = + +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => if (info.gettingResultTime > 0) { @@ -142,76 +212,75 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { 0.0 } } - val gettingResultQuantiles = +: - Distribution(gettingResultTimes).get.getQuantiles().map { millis => - - } + val gettingResultQuantiles = + +: + getFormattedTimeQuantiles(gettingResultTimes) // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime).toDouble - } else { - (info.finishTime - info.launchTime).toDouble - } - } - totalExecutionTime - metrics.get.executorRunTime + getSchedulerDelay(info, metrics.get).toDouble } val schedulerDelayTitle = + title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: - Distribution(schedulerDelays).get.getQuantiles().map { millis => - - } + getFormattedTimeQuantiles(schedulerDelays) - def getQuantileCols(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => ) val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getQuantileCols(inputSizes) + val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadQuantiles = +: - getQuantileCols(shuffleReadSizes) + getFormattedSizeQuantiles(shuffleReadSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: getQuantileCols(shuffleWriteSizes) + val shuffleWriteQuantiles = +: + getFormattedSizeQuantiles(shuffleWriteSizes) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = +: - getQuantileCols(memoryBytesSpilledSizes) + getFormattedSizeQuantiles(memoryBytesSpilledSizes) val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = +: - getQuantileCols(diskBytesSpilledSizes) + getFormattedSizeQuantiles(diskBytesSpilledSizes) val listings: Seq[Seq[Node]] = Seq( - serializationQuantiles, - serviceQuantiles, - gettingResultQuantiles, - 
schedulerDelayQuantiles, - if (hasInput) inputQuantiles else Nil, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil, - if (hasBytesSpilled) memoryBytesSpilledQuantiles else Nil, - if (hasBytesSpilled) diskBytesSpilledQuantiles else Nil) + {serviceQuantiles}, + {schedulerDelayQuantiles}, + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + if (hasInput) {inputQuantiles} else Nil, + if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") - def quantileRow(data: Seq[Node]): Seq[Node] = {data} - Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + Some(UIUtils.listingTable( + quantileHeaders, identity[Seq[Node]], listings, fixedWidth = true)) } val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) @@ -221,6 +290,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val content = summary ++ + showAdditionalMetrics ++

      Summary Metrics for {numCompleted} Completed Tasks
      ++ {summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
      Aggregated Metrics by Executor
++ executorTable.toNodeSeq ++ @@ -232,6 +302,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } def taskRow( + hasAccumulators: Boolean, hasInput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, @@ -241,8 +312,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = info.gettingResultTime + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") @@ -287,20 +363,26 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { - + - - + {if (hasAccumulators) { + + }} {if (hasInput) { } } + + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime) + } else { + (info.finishTime - info.launchTime) + } + } + totalExecutionTime - metrics.executorRunTime + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala new file mode 100644 index 0000000000000..23d672cabda07 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +/** + * Names of the CSS classes corresponding to each type of task detail. Used to allow users + * to optionally show/hide columns. 
+ */ +private object TaskDetailsClassNames { + val SCHEDULER_DELAY = "scheduler_delay" + val GC_TIME = "gc_time" + val RESULT_SERIALIZATION_TIME = "serialization_time" + val GETTING_RESULT_TIME = "getting_result_time" +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index a336bf7e1ed02..e2813f8eb5ab9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.jobs +import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet @@ -36,6 +37,13 @@ private[jobs] object UIData { var diskBytesSpilled : Long = 0 } + class JobUIData( + var jobId: Int = -1, + var stageIds: Seq[Int] = Seq.empty, + var jobGroup: Option[String] = None, + var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN + ) + class StageUIData { var numActiveTasks: Int = _ var numCompleteTasks: Int = _ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 8a0075ae8daf7..12d23a92878cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -39,7 +39,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { // Worker table val workers = storageStatusList.map((rddId, _)) - val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers) + val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers, + id = Some("rdd-storage-by-worker-table")) // Block table val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList) @@ -49,7 +50,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { .map { case (blockId, status) => (blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown"))) } - val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks) + val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks, + id = Some("rdd-storage-by-block-table")) val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 83489ca0679ee..6ced6052d2b18 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,7 +31,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val rdds = listener.rddInfoList - val content = UIUtils.listingTable(rddHeader, rddRow, rdds) + val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table")) UIUtils.headerSparkPage("Storage", content, parent) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 76097f1c51f8e..a81291d505583 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -26,11 +26,10 @@ import org.apache.spark.storage._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { - val listener = new StorageListener(parent.storageStatusListener) + val listener = parent.storageListener attachPage(new StoragePage(this)) attachPage(new RDDPage(this)) - parent.registerListener(listener) } /** diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index f41c8d0315cb3..10010bdfa1a51 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -159,17 +159,28 @@ private[spark] object AkkaUtils extends Logging { def askWithReply[T]( message: Any, actor: ActorRef, - retryAttempts: Int, + timeout: FiniteDuration): T = { + askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) + } + + /** + * Send a message to the given actor and get its result within a default timeout, or + * throw a SparkException if this fails even after the specified number of retries. 
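A hedged sketch of how the retrying overload described above might be called from code inside Spark's own packages (the actor, message, and timeout below are placeholders, not part of this change):

    import scala.concurrent.duration._
    import akka.actor.ActorRef
    import org.apache.spark.util.AkkaUtils

    // Ask `statusActor` up to three times with a 30-second timeout and a 1000 ms retry
    // interval; a SparkException is thrown if all attempts fail.
    def isReady(statusActor: ActorRef): Boolean =
      AkkaUtils.askWithReply[Boolean](
        "CheckStatus", statusActor, maxAttempts = 3, retryInterval = 1000, timeout = 30.seconds)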
+ */ + def askWithReply[T]( + message: Any, + actor: ActorRef, + maxAttempts: Int, retryInterval: Int, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { - throw new SparkException("Error sending message as driverActor is null " + + throw new SparkException("Error sending message as actor is null " + "[message = " + message + "]") } var attempts = 0 var lastException: Exception = null - while (attempts < retryAttempts) { + while (attempts < maxAttempts) { attempts += 1 try { val future = actor.ask(message)(timeout) @@ -201,4 +212,18 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def makeExecutorRef( + name: String, + conf: SparkConf, + host: String, + port: Int, + actorSystem: ActorSystem): ActorRef = { + val executorActorSystemName = SparkEnv.executorActorSystemName + Utils.checkHost(host, "Expected hostname") + val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name" + val timeout = AkkaUtils.lookupTimeout(conf) + logInfo(s"Connecting to $name: $url") + Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5b2e7d3a7edb9..f7ae1f7f334de 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -272,14 +272,15 @@ private[spark] object JsonProtocol { def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = { val reason = Utils.getFormattedClassName(taskEndReason) - val json = taskEndReason match { + val json: JObject = taskEndReason match { case fetchFailed: FetchFailed => val blockManagerAddress = Option(fetchFailed.bmAddress). 
map(blockManagerIdToJson).getOrElse(JNothing) ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ - ("Reduce ID" -> fetchFailed.reduceId) + ("Reduce ID" -> fetchFailed.reduceId) ~ + ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing) @@ -287,6 +288,8 @@ private[spark] object JsonProtocol { ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ ("Metrics" -> metrics) + case ExecutorLostFailure(executorId) => + ("Executor ID" -> executorId) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -627,7 +630,9 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId) + val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, + message.getOrElse("Unknown reason")) case `exceptionFailure` => val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] @@ -636,7 +641,9 @@ private[spark] object JsonProtocol { new ExceptionFailure(className, description, stackTrace, metrics) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled - case `executorLostFailure` => ExecutorLostFailure + case `executorLostFailure` => + val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) + ExecutorLostFailure(executorId.getOrElse("Unknown")) case `unknownReason` => UnknownReason } } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala new file mode 100644 index 0000000000000..d4e0ad93b966a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Used for shipping per-thread stacktraces from the executors to driver. 
+ */ +private[spark] case class ThreadStackTrace( + threadId: Long, + threadName: String, + threadState: Thread.State, + stackTrace: String) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e1dc49238733c..6ab94af9f3739 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,12 +18,13 @@ package org.apache.spark.util import java.io._ +import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.util.jar.Attributes.Name import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} - -import org.eclipse.jetty.util.MultiException +import java.util.jar.{Manifest => JarManifest} import scala.collection.JavaConversions._ import scala.collection.Map @@ -33,17 +34,17 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.log4j.PropertyConfigurator import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} +import org.eclipse.jetty.util.MultiException import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ -import org.apache.spark.util.SparkUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ @@ -739,11 +740,15 @@ private[spark] object Utils extends Logging { } private def listFilesSafely(file: File): Seq[File] = { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() } - files } /** @@ -984,11 +989,27 @@ private[spark] object Utils extends Logging { } } + /** + * Execute a block of code that returns a value, re-throwing any non-fatal uncaught + * exceptions as IOException. This is used when implementing Externalizable and Serializable's + * read and write methods, since Java's serializer will not report non-IOExceptions properly; + * see SPARK-4080 for more context. + */ + def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => throw e + case NonFatal(t) => throw new IOException(t) + } + } + /** Default filtering function for finding call sites using `getCallSite`. */ private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when // finding the call site of a method. 
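The tryOrIOException helper added in this file is intended to wrap Externalizable read/write bodies so that non-IO failures still surface through Java serialization; a small hypothetical use from inside Spark's packages (the Record class is illustrative only):

    import java.io.{Externalizable, ObjectInput, ObjectOutput}
    import org.apache.spark.util.Utils

    class Record(var value: Int) extends Externalizable {
      def this() = this(0) // no-arg constructor required for deserialization

      override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
        out.writeInt(value)
      }

      override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
        value = in.readInt()
      }
    }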
- val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + val SPARK_CORE_CLASS_REGEX = + """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r val SCALA_CLASS_REGEX = """^scala""".r val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined @@ -1057,8 +1078,8 @@ private[spark] object Utils extends Logging { val stream = new FileInputStream(file) try { - stream.skip(effectiveStart) - stream.read(buff) + ByteStreams.skipFully(stream, effectiveStart) + ByteStreams.readFully(stream, buff) } finally { stream.close() } @@ -1219,6 +1240,8 @@ private[spark] object Utils extends Logging { } // Handles idiosyncracies with hash (add more as required) + // This method should be kept in sync with + // org.apache.spark.network.util.JavaUtils#nonNegativeHash(). def nonNegativeHash(obj: AnyRef): Int = { // Required ? @@ -1252,12 +1275,28 @@ private[spark] object Utils extends Logging { /** * Timing method based on iterations that permit JVM JIT optimization. * @param numIters number of iterations - * @param f function to be executed + * @param f function to be executed. If prepare is not None, the running time of each call to f + * must be an order of magnitude longer than one millisecond for accurate timing. + * @param prepare function to be executed before each call to f. Its running time doesn't count. + * @return the total time across all iterations (not couting preparation time) */ - def timeIt(numIters: Int)(f: => Unit): Long = { - val start = System.currentTimeMillis - times(numIters)(f) - System.currentTimeMillis - start + def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = { + if (prepare.isEmpty) { + val start = System.currentTimeMillis + times(numIters)(f) + System.currentTimeMillis - start + } else { + var i = 0 + var sum = 0L + while (i < numIters) { + prepare.get.apply() + val start = System.currentTimeMillis + f + sum += System.currentTimeMillis - start + i += 1 + } + sum + } } /** @@ -1346,6 +1385,11 @@ private[spark] object Utils extends Logging { */ val isWindows = SystemUtils.IS_OS_WINDOWS + /** + * Whether the underlying operating system is Mac OS X. + */ + val isMac = SystemUtils.IS_OS_MAC_OSX + /** * Pattern for matching a Windows drive, which contains only a single alphabet character. */ @@ -1554,7 +1598,7 @@ private[spark] object Utils extends Logging { } /** Return a nice string representation of the exception, including the stack trace. */ - def exceptionString(e: Exception): String = { + def exceptionString(e: Throwable): String = { if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) } @@ -1568,6 +1612,18 @@ private[spark] object Utils extends Logging { s"$className: $desc\n$st" } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ + def getThreadDump(): Array[ThreadStackTrace] = { + // We need to filter out null values here because dumpAllThreads() may return null array + // elements for threads that are dead / don't exist. 
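The prepare hook added to timeIt keeps per-iteration setup out of the measured time; a sketch of a call site inside Spark's packages (the workload is hypothetical):

    import org.apache.spark.util.Utils

    // Re-copy the unsorted data before each timed run so every iteration sorts
    // fresh input, without counting the copy toward the reported total.
    val data = Array.fill(1000000)(scala.util.Random.nextInt())
    val scratch = new Array[Int](data.length)

    val elapsedMs = Utils.timeIt(10)(
      java.util.Arrays.sort(scratch),
      prepare = Some(() => Array.copy(data, 0, scratch, 0, data.length)))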
+ val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) + threadInfos.sortBy(_.getThreadId).map { case threadInfo => + val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, + threadInfo.getThreadState, stackTrace) + } + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ @@ -1668,6 +1724,62 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } + def invoke( + clazz: Class[_], + obj: AnyRef, + methodName: String, + args: (Class[_], AnyRef)*): AnyRef = { + val (types, values) = args.unzip + val method = clazz.getDeclaredMethod(methodName, types: _*) + method.setAccessible(true) + method.invoke(obj, values.toSeq: _*) + } + + // Limit of bytes for total size of results (default is 1GB) + def getMaxResultSize(conf: SparkConf): Long = { + memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 + } + + /** + * Return the current system LD_LIBRARY_PATH name + */ + def libraryPathEnvName: String = { + if (isWindows) { + "PATH" + } else if (isMac) { + "DYLD_LIBRARY_PATH" + } else { + "LD_LIBRARY_PATH" + } + } + + /** + * Return the prefix of a command that appends the given library paths to the + * system-specific library path environment variable. On Unix, for instance, + * this returns the string LD_LIBRARY_PATH="path1:path2:$LD_LIBRARY_PATH". + */ + def libraryPathEnvPrefix(libraryPaths: Seq[String]): String = { + val libraryPathScriptVar = if (isWindows) { + s"%${libraryPathEnvName}%" + } else { + "$" + libraryPathEnvName + } + val libraryPath = (libraryPaths :+ libraryPathScriptVar).mkString("\"", + File.pathSeparator, "\"") + val ampersand = if (Utils.isWindows) { + " &" + } else { + "" + } + s"$libraryPathEnvName=$libraryPath$ampersand" + } + + lazy val sparkVersion = + SparkContext.jarOfObject(this).map { path => + val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF") + val manifest = new JarManifest(manifestUrl.openStream()) + manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION) + }.getOrElse("Unknown") } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala index ac1528969f0be..4f0bf8384afc9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala @@ -27,33 +27,51 @@ import scala.reflect.ClassTag * Example format: an array of numbers, where each element is also the key. * See [[KVArraySortDataFormat]] for a more exciting format. * - * This trait extends Any to ensure it is universal (and thus compiled to a Java interface). + * Note: Declaring and instantiating multiple subclasses of this class would prevent JIT inlining + * overridden methods and hence decrease the shuffle performance. * * @tparam K Type of the sort key of each element * @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]). */ // TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity. -private[spark] trait SortDataFormat[K, Buffer] extends Any { +private[spark] +abstract class SortDataFormat[K, Buffer] { + + /** + * Creates a new mutable key for reuse. This should be implemented if you want to override + * [[getKey(Buffer, Int, K)]]. 
+ */ + def newKey(): K = null.asInstanceOf[K] + /** Return the sort key for the element at the given index. */ protected def getKey(data: Buffer, pos: Int): K + /** + * Returns the sort key for the element at the given index and reuse the input key if possible. + * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int]]. + * If you want to override this method, you must implement [[newKey()]]. + */ + def getKey(data: Buffer, pos: Int, reuse: K): K = { + getKey(data, pos) + } + /** Swap two elements. */ - protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit + def swap(data: Buffer, pos0: Int, pos1: Int): Unit /** Copy a single element from src(srcPos) to dst(dstPos). */ - protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit + def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit /** * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos. * Overlapping ranges are allowed. */ - protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit + def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit /** * Allocates a Buffer that can hold up to 'length' elements. * All elements of the buffer should be considered invalid until data is explicitly copied in. */ - protected def allocate(length: Int): Buffer + def allocate(length: Int): Buffer } /** @@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any { private[spark] class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] { - override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] + override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K] - override protected def swap(data: Array[T], pos0: Int, pos1: Int) { + override def swap(data: Array[T], pos0: Int, pos1: Int) { val tmpKey = data(2 * pos0) val tmpVal = data(2 * pos0 + 1) data(2 * pos0) = data(2 * pos1) @@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, data(2 * pos1 + 1) = tmpVal } - override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { + override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) { dst(2 * dstPos) = src(2 * srcPos) dst(2 * dstPos + 1) = src(2 * srcPos + 1) } - override protected def copyRange(src: Array[T], srcPos: Int, - dst: Array[T], dstPos: Int, length: Int) { + override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) { System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length) } - override protected def allocate(length: Int): Array[T] = { + override def allocate(length: Int): Array[T] = { new Array[T](2 * length) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala new file mode 100644 index 0000000000000..39f66b8c428c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.Comparator + +/** + * A simple wrapper over the Java implementation [[TimSort]]. + * + * The Java implementation is package private, and hence it cannot be called outside package + * org.apache.spark.util.collection. This is a simple wrapper of it that is available to spark. + */ +private[spark] +class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) { + + private val timSort = new TimSort(s) + + /** + * Sorts the input buffer within range [lo, hi). + */ + def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = { + timSort.sort(a, lo, hi, c) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 32c5fdad75e58..76e7a2760bcd1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -19,8 +19,10 @@ package org.apache.spark.util.random import java.util.Random -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import scala.reflect.ClassTag +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.annotation.DeveloperApi @@ -39,13 +41,47 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable /** take a random sample */ def sample(items: Iterator[T]): Iterator[U] + /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = throw new NotImplementedError("clone() is not implemented.") } +private[spark] +object RandomSampler { + /** Default random number generator used by random samplers. */ + def newDefaultRNG: Random = new XORShiftRandom + + /** + * Default maximum gap-sampling fraction. + * For sampling fractions <= this value, the gap sampling optimization will be applied. + * Above this value, it is assumed that "tradtional" Bernoulli sampling is faster. The + * optimal value for this will depend on the RNG. More expensive RNGs will tend to make + * the optimal value higher. The most reliable way to determine this value for a new RNG + * is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close + * in most cases, as an initial guess. + */ + val defaultMaxGapSamplingFraction = 0.4 + + /** + * Default epsilon for floating point numbers sampled from the RNG. + * The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG. + * To guard against errors from taking log(0), a positive epsilon lower bound is applied. + * A good value for this parameter is at or near the minimum positive floating + * point value returned by "nextDouble()" (or equivalent), for the RNG being used. + */ + val rngEpsilon = 5e-11 + + /** + * Sampling fraction arguments may be results of computation, and subject to floating + * point jitter. 
I check the arguments with this epsilon slop factor to prevent spurious + * warnings for cases such as summing some numbers to get a sampling fraction of 1.000000001 + */ + val roundingEpsilon = 1e-6 +} + /** * :: DeveloperApi :: - * A sampler based on Bernoulli trials. + * A sampler based on Bernoulli trials for partitioning a data sequence. * * @param lb lower bound of the acceptance range * @param ub upper bound of the acceptance range @@ -53,56 +89,262 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable * @tparam T item type */ @DeveloperApi -class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) +class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false) extends RandomSampler[T, T] { - private[random] var rng: Random = new XORShiftRandom + /** epsilon slop to avoid failure from floating point jitter. */ + require( + lb <= (ub + RandomSampler.roundingEpsilon), + s"Lower bound ($lb) must be <= upper bound ($ub)") + require( + lb >= (0.0 - RandomSampler.roundingEpsilon), + s"Lower bound ($lb) must be >= 0.0") + require( + ub <= (1.0 + RandomSampler.roundingEpsilon), + s"Upper bound ($ub) must be <= 1.0") - def this(ratio: Double) = this(0.0d, ratio) + private val rng: Random = new XORShiftRandom override def setSeed(seed: Long) = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { - items.filter { item => - val x = rng.nextDouble() - (x >= lb && x < ub) ^ complement + if (ub - lb <= 0.0) { + if (complement) items else Iterator.empty + } else { + if (complement) { + items.filter { item => { + val x = rng.nextDouble() + (x < lb) || (x >= ub) + }} + } else { + items.filter { item => { + val x = rng.nextDouble() + (x >= lb) && (x < ub) + }} + } } } /** * Return a sampler that is the complement of the range specified of the current sampler. */ - def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) + def cloneComplement(): BernoulliCellSampler[T] = + new BernoulliCellSampler[T](lb, ub, !complement) + + override def clone = new BernoulliCellSampler[T](lb, ub, complement) +} + + +/** + * :: DeveloperApi :: + * A sampler based on Bernoulli trials. + * + * @param fraction the sampling fraction, aka Bernoulli sampling probability + * @tparam T item type + */ +@DeveloperApi +class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { + + /** epsilon slop to avoid failure from floating point jitter */ + require( + fraction >= (0.0 - RandomSampler.roundingEpsilon) + && fraction <= (1.0 + RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be on interval [0, 1]") - override def clone = new BernoulliSampler[T](lb, ub, complement) + private val rng: Random = RandomSampler.newDefaultRNG + + override def setSeed(seed: Long) = rng.setSeed(seed) + + override def sample(items: Iterator[T]): Iterator[T] = { + if (fraction <= 0.0) { + Iterator.empty + } else if (fraction >= 1.0) { + items + } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon) + } else { + items.filter { _ => rng.nextDouble() <= fraction } + } + } + + override def clone = new BernoulliSampler[T](fraction) } + /** * :: DeveloperApi :: - * A sampler based on values drawn from Poisson distribution. + * A sampler for sampling with replacement, based on values drawn from Poisson distribution. 
* - * @param mean Poisson mean + * @param fraction the sampling fraction (with replacement) * @tparam T item type */ @DeveloperApi -class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { + + /** Epsilon slop to avoid failure from floating point jitter. */ + require( + fraction >= (0.0 - RandomSampler.roundingEpsilon), + s"Sampling fraction ($fraction) must be >= 0") - private[random] var rng = new Poisson(mean, new DRand) + // PoissonDistribution throws an exception when fraction <= 0 + // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value. + private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0) + private val rngGap = RandomSampler.newDefaultRNG override def setSeed(seed: Long) { - rng = new Poisson(mean, new DRand(seed.toInt)) + rng.reseedRandomGenerator(seed) + rngGap.setSeed(seed) } override def sample(items: Iterator[T]): Iterator[T] = { - items.flatMap { item => - val count = rng.nextInt() - if (count == 0) { - Iterator.empty - } else { - Iterator.fill(count)(item) - } + if (fraction <= 0.0) { + Iterator.empty + } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else { + items.flatMap { item => { + val count = rng.sample() + if (count == 0) Iterator.empty else Iterator.fill(count)(item) + }} + } + } + + override def clone = new PoissonSampler[T](fraction) +} + + +private[spark] +class GapSamplingIterator[T: ClassTag]( + var data: Iterator[T], + f: Double, + rng: Random = RandomSampler.newDefaultRNG, + epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { + + require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)") + require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") + + /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */ + private val iterDrop: Int => Unit = { + val arrayClass = Array.empty[T].iterator.getClass + val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass + data.getClass match { + case `arrayClass` => ((n: Int) => { data = data.drop(n) }) + case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) + case _ => ((n: Int) => { + var j = 0 + while (j < n && data.hasNext) { + data.next() + j += 1 + } + }) + } + } + + override def hasNext: Boolean = data.hasNext + + override def next(): T = { + val r = data.next() + advance + r + } + + private val lnq = math.log1p(-f) + + /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ + private def advance: Unit = { + val u = math.max(rng.nextDouble(), epsilon) + val k = (math.log(u) / lnq).toInt + iterDrop(k) + } + + /** advance to first sample as part of object construction. */ + advance + // Attempting to invoke this closer to the top with other object initialization + // was causing it to break in strange ways, so I'm invoking it last, which seems to + // work reliably. +} + +private[spark] +class GapSamplingReplacementIterator[T: ClassTag]( + var data: Iterator[T], + f: Double, + rng: Random = RandomSampler.newDefaultRNG, + epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { + + require(f > 0.0, s"Sampling fraction ($f) must be > 0") + require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") + + /** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. 
*/ + private val iterDrop: Int => Unit = { + val arrayClass = Array.empty[T].iterator.getClass + val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass + data.getClass match { + case `arrayClass` => ((n: Int) => { data = data.drop(n) }) + case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) }) + case _ => ((n: Int) => { + var j = 0 + while (j < n && data.hasNext) { + data.next() + j += 1 + } + }) + } + } + + /** current sampling value, and its replication factor, as we are sampling with replacement. */ + private var v: T = _ + private var rep: Int = 0 + + override def hasNext: Boolean = data.hasNext || rep > 0 + + override def next(): T = { + val r = v + rep -= 1 + if (rep <= 0) advance + r + } + + /** + * Skip elements with replication factor zero (i.e. elements that won't be sampled). + * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is + * q is the probabililty of Poisson(0; f) + */ + private def advance: Unit = { + val u = math.max(rng.nextDouble(), epsilon) + val k = (math.log(u) / (-f)).toInt + iterDrop(k) + // set the value and replication factor for the next value + if (data.hasNext) { + v = data.next() + rep = poissonGE1 + } + } + + private val q = math.exp(-f) + + /** + * Sample from Poisson distribution, conditioned such that the sampled value is >= 1. + * This is an adaptation from the algorithm for Generating Poisson distributed random variables: + * http://en.wikipedia.org/wiki/Poisson_distribution + */ + private def poissonGE1: Int = { + // simulate that the standard poisson sampling + // gave us at least one iteration, for a sample of >= 1 + var pp = q + ((1.0 - q) * rng.nextDouble()) + var r = 1 + + // now continue with standard poisson sampling algorithm + pp *= rng.nextDouble() + while (pp > q) { + r += 1 + pp *= rng.nextDouble() } + r } - override def clone = new PoissonSampler[T](mean) + /** advance to first sample as part of object construction. */ + advance + // Attempting to invoke this closer to the top with other object initialization + // was causing it to break in strange ways, so I'm invoking it last, which seems to + // work reliably. 
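Numerically, the advance step in both iterators turns a single uniform draw into a geometric skip length, which is what makes gap sampling cheaper than testing every element; a self-contained sketch of the Bernoulli case (values are illustrative):

    import scala.util.Random

    val f = 0.1                                    // sampling fraction, 0 < f < 1
    val rngEpsilon = 5e-11                         // lower bound to avoid log(0), as above
    val rng = new Random(42L)

    val lnq = math.log1p(-f)                       // log(1 - f)
    val u = math.max(rng.nextDouble(), rngEpsilon)
    val k = (math.log(u) / lnq).toInt              // elements to drop before the next sample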
} diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 8f95d7c6b799b..4fa357edd6f07 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -22,8 +22,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.Logging import org.apache.spark.SparkContext._ @@ -209,7 +208,7 @@ private[spark] object StratifiedSamplingUtils extends Logging { samplingRateByKey = computeThresholdByKey(finalResult, fractions) } (idx: Int, iter: Iterator[(K, V)]) => { - val rng = new RandomDataGenerator + val rng = new RandomDataGenerator() rng.reSeed(seed + idx) // Must use the same invoke pattern on the rng as in getSeqOp for without replacement // in order to generate the same sequence of random numbers when creating the sample @@ -245,9 +244,9 @@ private[spark] object StratifiedSamplingUtils extends Logging { // Must use the same invoke pattern on the rng as in getSeqOp for with replacement // in order to generate the same sequence of random numbers when creating the sample val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound) - val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound) + val copiesWaitlisted = rng.nextPoisson(finalResult(key).waitListBound) val copiesInSample = copiesAccepted + - (0 until copiesWailisted).count(i => rng.nextUniform() < thresholdByKey(key)) + (0 until copiesWaitlisted).count(i => rng.nextUniform() < thresholdByKey(key)) if (copiesInSample > 0) { Iterator.fill(copiesInSample.toInt)(item) } else { @@ -261,10 +260,10 @@ private[spark] object StratifiedSamplingUtils extends Logging { rng.reSeed(seed + idx) iter.flatMap { item => val count = rng.nextPoisson(fractions(item._1)) - if (count > 0) { - Iterator.fill(count)(item) - } else { + if (count == 0) { Iterator.empty + } else { + Iterator.fill(count)(item) } } } @@ -274,15 +273,24 @@ private[spark] object StratifiedSamplingUtils extends Logging { /** A random data generator that generates both uniform values and Poisson values. 
*/ private class RandomDataGenerator { val uniform = new XORShiftRandom() - var poisson = new Poisson(1.0, new DRand) + // commons-math3 doesn't have a method to generate Poisson from an arbitrary mean; + // maintain a cache of Poisson(m) distributions for various m + val poissonCache = mutable.Map[Double, PoissonDistribution]() + var poissonSeed = 0L - def reSeed(seed: Long) { + def reSeed(seed: Long): Unit = { uniform.setSeed(seed) - poisson = new Poisson(1.0, new DRand(seed.toInt)) + poissonSeed = seed + poissonCache.clear() } def nextPoisson(mean: Double): Int = { - poisson.nextInt(mean) + val poisson = poissonCache.getOrElseUpdate(mean, { + val newPoisson = new PoissonDistribution(mean) + newPoisson.reseedRandomGenerator(poissonSeed) + newPoisson + }) + poisson.sample() } def nextUniform(): Double = { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 55b5713706178..467b890fb4bb9 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -96,13 +96,9 @@ private[spark] object XORShiftRandom { xorRand.nextInt() } - val iters = timeIt(numIters)(_) - /* Return results as a map instead of just printing to screen in case the user wants to do something with them */ - Map("javaTime" -> iters {javaRand.nextInt()}, - "xorTime" -> iters {xorRand.nextInt()}) - + Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() }, + "xorTime" -> timeIt(numIters) { xorRand.nextInt() }) } - } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 814e40c4f77cc..59c86eecac5e8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -18,10 +18,13 @@ package org.apache.spark; import java.io.*; +import java.nio.channels.FileChannel; +import java.nio.ByteBuffer; import java.net.URI; import java.util.*; import java.util.concurrent.*; +import org.apache.spark.input.PortableDataStream; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -140,11 +143,10 @@ public void intersection() { public void sample() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD sample20 = rdd.sample(true, 0.2, 11); - // expected 2 but of course result varies randomly a bit - Assert.assertEquals(3, sample20.count()); - JavaRDD sample20NoReplacement = rdd.sample(false, 0.2, 11); - Assert.assertEquals(2, sample20NoReplacement.count()); + JavaRDD sample20 = rdd.sample(true, 0.2, 3); + Assert.assertEquals(2, sample20.count()); + JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 5); + Assert.assertEquals(2, sample20WithoutReplacement.count()); } @Test @@ -864,6 +866,82 @@ public Tuple2 call(Tuple2 pair) { Assert.assertEquals(pairs, readRDD.collect()); } + @Test + public void binaryFiles() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = fos1.getChannel(); + ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); + JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); + List> result = readRDD.collect(); + for 
(Tuple2 res : result) { + Assert.assertArrayEquals(content1, res._2().toArray()); + } + } + + @Test + public void binaryFilesCaching() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = fos1.getChannel(); + ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); + + JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); + readRDD.foreach(new VoidFunction>() { + @Override + public void call(Tuple2 pair) throws Exception { + pair._2().toArray(); // force the file to read + } + }); + + List> result = readRDD.collect(); + for (Tuple2 res : result) { + Assert.assertArrayEquals(content1, res._2().toArray()); + } + } + + @Test + public void binaryRecords() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8"); + int numOfCopies = 10; + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = fos1.getChannel(); + + for (int i = 0; i < numOfCopies; i++) { + ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + channel1.write(bbuf); + } + channel1.close(); + + JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length); + Assert.assertEquals(numOfCopies,readRDD.count()); + List result = readRDD.collect(); + for (byte[] res : result) { + Assert.assertArrayEquals(content1, res); + } + } + @SuppressWarnings("unchecked") @Test public void writeWithNewAPIHadoopFile() { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 81b64c36ddca1..429199f2075c6 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -202,7 +202,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService blockManager.master.getLocations(blockId).foreach { cmId => - val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString) + val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, + blockId.toString) val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala new file mode 100644 index 0000000000000..66cf60d25f6d1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, PrivateMethodTester} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId + +/** + * Test add and remove behavior of ExecutorAllocationManager. + */ +class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { + import ExecutorAllocationManager._ + import ExecutorAllocationManagerSuite._ + + test("verify min/max executors") { + // No min or max + val conf = new SparkConf() + .setMaster("local") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + intercept[SparkException] { new SparkContext(conf) } + SparkEnv.get.stop() // cleanup the created environment + + // Only min + val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") + intercept[SparkException] { new SparkContext(conf1) } + SparkEnv.get.stop() + + // Only max + val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") + intercept[SparkException] { new SparkContext(conf2) } + SparkEnv.get.stop() + + // Both min and max, but min > max + intercept[SparkException] { createSparkContext(2, 1) } + SparkEnv.get.stop() + + // Both min and max, and min == max + val sc1 = createSparkContext(1, 1) + assert(sc1.executorAllocationManager.isDefined) + sc1.stop() + + // Both min and max, and min < max + val sc2 = createSparkContext(1, 2) + assert(sc2.executorAllocationManager.isDefined) + sc2.stop() + } + + test("starting state") { + sc = createSparkContext() + val manager = sc.executorAllocationManager.get + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + assert(executorIds(manager).isEmpty) + assert(addTime(manager) === ExecutorAllocationManager.NOT_SET) + assert(removeTimes(manager).isEmpty) + } + + test("add executors") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + + // Keep adding until the limit is reached + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 4) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 8) + assert(addExecutors(manager) === 3) // reached the limit of 10 + assert(numExecutorsPending(manager) === 10) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 10) + assert(numExecutorsToAdd(manager) === 1) + + // Register previously requested executors + onExecutorAdded(manager, "first") + assert(numExecutorsPending(manager) === 9) + onExecutorAdded(manager, "second") + onExecutorAdded(manager, "third") + onExecutorAdded(manager, "fourth") + assert(numExecutorsPending(manager) === 6) + onExecutorAdded(manager, "first") // duplicates should not count + 
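The assert sequence above encodes an exponential ramp-up: each fully granted request doubles numExecutorsToAdd, and the increment resets to 1 once pending plus running executors hit the configured maximum of 10. A self-contained sketch of that arithmetic, illustrative only and not the manager's actual implementation:

object RampUpSketch {
  def main(args: Array[String]): Unit = {
    val max = 10        // mirrors the upper bound used in this test
    var pending = 0     // executors requested but not yet registered
    var toAdd = 1       // size of the next request
    for (_ <- 1 to 4) {
      val granted = math.min(toAdd, max - pending)    // never exceed the upper bound
      pending += granted
      toAdd = if (granted == toAdd) toAdd * 2 else 1  // double on a full grant, else reset
      println(s"granted=$granted pending=$pending nextRequest=$toAdd")
    }
    // prints granted/pending/nextRequest of 1/1/2, 2/3/4, 4/7/8, 3/10/1 -- matching the asserts above
  }
}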
onExecutorAdded(manager, "second") + assert(numExecutorsPending(manager) === 6) + + // Try adding again + // This should still fail because the number pending + running is still at the limit + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 6) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 6) + assert(numExecutorsToAdd(manager) === 1) + } + + test("remove executors") { + sc = createSparkContext(5, 10) + val manager = sc.executorAllocationManager.get + (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } + + // Keep removing until the limit is reached + assert(executorsPendingToRemove(manager).isEmpty) + assert(removeExecutor(manager, "1")) + assert(executorsPendingToRemove(manager).size === 1) + assert(executorsPendingToRemove(manager).contains("1")) + assert(removeExecutor(manager, "2")) + assert(removeExecutor(manager, "3")) + assert(executorsPendingToRemove(manager).size === 3) + assert(executorsPendingToRemove(manager).contains("2")) + assert(executorsPendingToRemove(manager).contains("3")) + assert(!removeExecutor(manager, "100")) // remove non-existent executors + assert(!removeExecutor(manager, "101")) + assert(executorsPendingToRemove(manager).size === 3) + assert(removeExecutor(manager, "4")) + assert(removeExecutor(manager, "5")) + assert(!removeExecutor(manager, "6")) // reached the limit of 5 + assert(executorsPendingToRemove(manager).size === 5) + assert(executorsPendingToRemove(manager).contains("4")) + assert(executorsPendingToRemove(manager).contains("5")) + assert(!executorsPendingToRemove(manager).contains("6")) + + // Kill executors previously requested to remove + onExecutorRemoved(manager, "1") + assert(executorsPendingToRemove(manager).size === 4) + assert(!executorsPendingToRemove(manager).contains("1")) + onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 2) + assert(!executorsPendingToRemove(manager).contains("2")) + assert(!executorsPendingToRemove(manager).contains("3")) + onExecutorRemoved(manager, "2") // duplicates should not count + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 2) + onExecutorRemoved(manager, "4") + onExecutorRemoved(manager, "5") + assert(executorsPendingToRemove(manager).isEmpty) + + // Try removing again + // This should still fail because the number pending + running is still at the limit + assert(!removeExecutor(manager, "7")) + assert(executorsPendingToRemove(manager).isEmpty) + assert(!removeExecutor(manager, "8")) + assert(executorsPendingToRemove(manager).isEmpty) + } + + test ("interleaving add and remove") { + sc = createSparkContext(5, 10) + val manager = sc.executorAllocationManager.get + + // Add a few executors + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 4) + onExecutorAdded(manager, "1") + onExecutorAdded(manager, "2") + onExecutorAdded(manager, "3") + onExecutorAdded(manager, "4") + onExecutorAdded(manager, "5") + onExecutorAdded(manager, "6") + onExecutorAdded(manager, "7") + assert(executorIds(manager).size === 7) + + // Remove until limit + assert(removeExecutor(manager, "1")) + assert(removeExecutor(manager, "2")) + assert(!removeExecutor(manager, "3")) // lower limit reached + assert(!removeExecutor(manager, "4")) + onExecutorRemoved(manager, "1") + onExecutorRemoved(manager, "2") + assert(executorIds(manager).size === 5) + + // 
Add until limit + assert(addExecutors(manager) === 5) // upper limit reached + assert(addExecutors(manager) === 0) + assert(!removeExecutor(manager, "3")) // still at lower limit + assert(!removeExecutor(manager, "4")) + onExecutorAdded(manager, "8") + onExecutorAdded(manager, "9") + onExecutorAdded(manager, "10") + onExecutorAdded(manager, "11") + onExecutorAdded(manager, "12") + assert(executorIds(manager).size === 10) + + // Remove succeeds again, now that we are no longer at the lower limit + assert(removeExecutor(manager, "3")) + assert(removeExecutor(manager, "4")) + assert(removeExecutor(manager, "5")) + assert(removeExecutor(manager, "6")) + assert(executorIds(manager).size === 10) + assert(addExecutors(manager) === 0) // still at upper limit + onExecutorRemoved(manager, "3") + onExecutorRemoved(manager, "4") + assert(executorIds(manager).size === 8) + + // Add succeeds again, now that we are no longer at the upper limit + // Number of executors added restarts at 1 + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 1) // upper limit reached again + assert(addExecutors(manager) === 0) + assert(executorIds(manager).size === 8) + onExecutorRemoved(manager, "5") + onExecutorRemoved(manager, "6") + onExecutorAdded(manager, "13") + onExecutorAdded(manager, "14") + assert(executorIds(manager).size === 8) + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 1) // upper limit reached again + assert(addExecutors(manager) === 0) + onExecutorAdded(manager, "15") + onExecutorAdded(manager, "16") + assert(executorIds(manager).size === 10) + } + + test("starting/canceling add timer") { + sc = createSparkContext(2, 10) + val clock = new TestClock(8888L) + val manager = sc.executorAllocationManager.get + manager.setClock(clock) + + // Starting add timer is idempotent + assert(addTime(manager) === NOT_SET) + onSchedulerBacklogged(manager) + val firstAddTime = addTime(manager) + assert(firstAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) + clock.tick(100L) + onSchedulerBacklogged(manager) + assert(addTime(manager) === firstAddTime) // timer is already started + clock.tick(200L) + onSchedulerBacklogged(manager) + assert(addTime(manager) === firstAddTime) + onSchedulerQueueEmpty(manager) + + // Restart add timer + clock.tick(1000L) + assert(addTime(manager) === NOT_SET) + onSchedulerBacklogged(manager) + val secondAddTime = addTime(manager) + assert(secondAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000) + clock.tick(100L) + onSchedulerBacklogged(manager) + assert(addTime(manager) === secondAddTime) // timer is already started + assert(addTime(manager) !== firstAddTime) + assert(firstAddTime !== secondAddTime) + } + + test("starting/canceling remove timers") { + sc = createSparkContext(2, 10) + val clock = new TestClock(14444L) + val manager = sc.executorAllocationManager.get + manager.setClock(clock) + + // Starting remove timer is idempotent for each executor + assert(removeTimes(manager).isEmpty) + onExecutorIdle(manager, "1") + assert(removeTimes(manager).size === 1) + assert(removeTimes(manager).contains("1")) + val firstRemoveTime = removeTimes(manager)("1") + assert(firstRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000) + clock.tick(100L) + onExecutorIdle(manager, "1") + assert(removeTimes(manager)("1") === firstRemoveTime) // timer is already started + clock.tick(200L) + onExecutorIdle(manager, "1") + assert(removeTimes(manager)("1") === firstRemoveTime) + clock.tick(300L) + onExecutorIdle(manager, "2") + 
assert(removeTimes(manager)("2") !== firstRemoveTime) // different executor + assert(removeTimes(manager)("2") === clock.getTimeMillis + executorIdleTimeout * 1000) + clock.tick(400L) + onExecutorIdle(manager, "3") + assert(removeTimes(manager)("3") !== firstRemoveTime) + assert(removeTimes(manager)("3") === clock.getTimeMillis + executorIdleTimeout * 1000) + assert(removeTimes(manager).size === 3) + assert(removeTimes(manager).contains("2")) + assert(removeTimes(manager).contains("3")) + + // Restart remove timer + clock.tick(1000L) + onExecutorBusy(manager, "1") + assert(removeTimes(manager).size === 2) + onExecutorIdle(manager, "1") + assert(removeTimes(manager).size === 3) + assert(removeTimes(manager).contains("1")) + val secondRemoveTime = removeTimes(manager)("1") + assert(secondRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000) + assert(removeTimes(manager)("1") === secondRemoveTime) // timer is already started + assert(removeTimes(manager)("1") !== firstRemoveTime) + assert(firstRemoveTime !== secondRemoveTime) + } + + test("mock polling loop with no events") { + sc = createSparkContext(1, 20) + val manager = sc.executorAllocationManager.get + val clock = new TestClock(2020L) + manager.setClock(clock) + + // No events - we should not be adding or removing + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + schedule(manager) + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + clock.tick(100L) + schedule(manager) + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + clock.tick(1000L) + schedule(manager) + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + clock.tick(10000L) + schedule(manager) + assert(numExecutorsPending(manager) === 0) + assert(executorsPendingToRemove(manager).isEmpty) + } + + test("mock polling loop add behavior") { + sc = createSparkContext(1, 20) + val clock = new TestClock(2020L) + val manager = sc.executorAllocationManager.get + manager.setClock(clock) + + // Scheduler queue backlogged + onSchedulerBacklogged(manager) + clock.tick(schedulerBacklogTimeout * 1000 / 2) + schedule(manager) + assert(numExecutorsPending(manager) === 0) // timer not exceeded yet + clock.tick(schedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 1) // first timer exceeded + clock.tick(sustainedSchedulerBacklogTimeout * 1000 / 2) + schedule(manager) + assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded + + // Scheduler queue drained + onSchedulerQueueEmpty(manager) + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 7) // timer is canceled + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 7) + + // Scheduler queue backlogged again + onSchedulerBacklogged(manager) + clock.tick(schedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 7 + 1) // timer restarted + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) 
=== 7 + 1 + 2) + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) + clock.tick(sustainedSchedulerBacklogTimeout * 1000) + schedule(manager) + assert(numExecutorsPending(manager) === 20) // limit reached + } + + test("mock polling loop remove behavior") { + sc = createSparkContext(1, 20) + val clock = new TestClock(2020L) + val manager = sc.executorAllocationManager.get + manager.setClock(clock) + + // Remove idle executors on timeout + onExecutorAdded(manager, "executor-1") + onExecutorAdded(manager, "executor-2") + onExecutorAdded(manager, "executor-3") + assert(removeTimes(manager).size === 3) + assert(executorsPendingToRemove(manager).isEmpty) + clock.tick(executorIdleTimeout * 1000 / 2) + schedule(manager) + assert(removeTimes(manager).size === 3) // idle threshold not reached yet + assert(executorsPendingToRemove(manager).isEmpty) + clock.tick(executorIdleTimeout * 1000) + schedule(manager) + assert(removeTimes(manager).isEmpty) // idle threshold exceeded + assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining) + + // Mark a subset as busy - only idle executors should be removed + onExecutorAdded(manager, "executor-4") + onExecutorAdded(manager, "executor-5") + onExecutorAdded(manager, "executor-6") + onExecutorAdded(manager, "executor-7") + assert(removeTimes(manager).size === 5) // 5 active executors + assert(executorsPendingToRemove(manager).size === 2) // 2 pending to be removed + onExecutorBusy(manager, "executor-4") + onExecutorBusy(manager, "executor-5") + onExecutorBusy(manager, "executor-6") // 3 busy and 2 idle (of the 5 active ones) + schedule(manager) + assert(removeTimes(manager).size === 2) // remove only idle executors + assert(!removeTimes(manager).contains("executor-4")) + assert(!removeTimes(manager).contains("executor-5")) + assert(!removeTimes(manager).contains("executor-6")) + assert(executorsPendingToRemove(manager).size === 2) + clock.tick(executorIdleTimeout * 1000) + schedule(manager) + assert(removeTimes(manager).isEmpty) // idle executors are removed + assert(executorsPendingToRemove(manager).size === 4) + assert(!executorsPendingToRemove(manager).contains("executor-4")) + assert(!executorsPendingToRemove(manager).contains("executor-5")) + assert(!executorsPendingToRemove(manager).contains("executor-6")) + + // Busy executors are now idle and should be removed + onExecutorIdle(manager, "executor-4") + onExecutorIdle(manager, "executor-5") + onExecutorIdle(manager, "executor-6") + schedule(manager) + assert(removeTimes(manager).size === 3) // 0 busy and 3 idle + assert(removeTimes(manager).contains("executor-4")) + assert(removeTimes(manager).contains("executor-5")) + assert(removeTimes(manager).contains("executor-6")) + assert(executorsPendingToRemove(manager).size === 4) + clock.tick(executorIdleTimeout * 1000) + schedule(manager) + assert(removeTimes(manager).isEmpty) + assert(executorsPendingToRemove(manager).size === 6) // limit reached (1 executor remaining) + } + + test("listeners trigger add executors correctly") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(addTime(manager) === NOT_SET) + + // Starting a stage should start the add timer + val numTasks = 10 + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, numTasks))) + assert(addTime(manager) !== NOT_SET) + + // Starting a subset of the tasks should not cancel the add timer + val taskInfos = (0 to numTasks - 1).map 
{ i => createTaskInfo(i, i, "executor-1") } + taskInfos.tail.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + assert(addTime(manager) !== NOT_SET) + + // Starting all remaining tasks should cancel the add timer + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfos.head)) + assert(addTime(manager) === NOT_SET) + + // Start two different stages + // The add timer should be canceled only if all tasks in both stages start running + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, numTasks))) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, numTasks))) + assert(addTime(manager) !== NOT_SET) + taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, info)) } + assert(addTime(manager) !== NOT_SET) + taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, info)) } + assert(addTime(manager) === NOT_SET) + } + + test("listeners trigger remove executors correctly") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(removeTimes(manager).isEmpty) + + // Added executors should start the remove timers for each executor + (1 to 5).map("executor-" + _).foreach { id => onExecutorAdded(manager, id) } + assert(removeTimes(manager).size === 5) + + // Starting a task cancel the remove timer for that executor + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) + assert(removeTimes(manager).size === 3) + assert(!removeTimes(manager).contains("executor-1")) + assert(!removeTimes(manager).contains("executor-2")) + + // Finishing all tasks running on an executor should start the remove timer for that executor + sc.listenerBus.postToAll(SparkListenerTaskEnd( + 0, 0, "task-type", Success, createTaskInfo(0, 0, "executor-1"), new TaskMetrics)) + sc.listenerBus.postToAll(SparkListenerTaskEnd( + 0, 0, "task-type", Success, createTaskInfo(2, 2, "executor-2"), new TaskMetrics)) + assert(removeTimes(manager).size === 4) + assert(!removeTimes(manager).contains("executor-1")) // executor-1 has not finished yet + assert(removeTimes(manager).contains("executor-2")) + sc.listenerBus.postToAll(SparkListenerTaskEnd( + 0, 0, "task-type", Success, createTaskInfo(1, 1, "executor-1"), new TaskMetrics)) + assert(removeTimes(manager).size === 5) + assert(removeTimes(manager).contains("executor-1")) // executor-1 has now finished + } + + test("listeners trigger add and remove executor callbacks correctly") { + sc = createSparkContext(2, 10) + val manager = sc.executorAllocationManager.get + assert(executorIds(manager).isEmpty) + assert(removeTimes(manager).isEmpty) + + // New executors have registered + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-1", "host1", 1), 100L)) + assert(executorIds(manager).size === 1) + assert(executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 1) + assert(removeTimes(manager).contains("executor-1")) + sc.listenerBus.postToAll(SparkListenerBlockManagerAdded( + 0L, BlockManagerId("executor-2", "host2", 1), 100L)) + assert(executorIds(manager).size === 2) + assert(executorIds(manager).contains("executor-2")) + assert(removeTimes(manager).size === 2) + assert(removeTimes(manager).contains("executor-2")) + + // Existing executors have 
disconnected + sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( + 0L, BlockManagerId("executor-1", "host1", 1))) + assert(executorIds(manager).size === 1) + assert(!executorIds(manager).contains("executor-1")) + assert(removeTimes(manager).size === 1) + assert(!removeTimes(manager).contains("executor-1")) + + // Unknown executor has disconnected + sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved( + 0L, BlockManagerId("executor-3", "host3", 1))) + assert(executorIds(manager).size === 1) + assert(removeTimes(manager).size === 1) + } + +} + +/** + * Helper methods for testing ExecutorAllocationManager. + * This includes methods to access private methods and fields in ExecutorAllocationManager. + */ +private object ExecutorAllocationManagerSuite extends PrivateMethodTester { + private val schedulerBacklogTimeout = 1L + private val sustainedSchedulerBacklogTimeout = 2L + private val executorIdleTimeout = 3L + + private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", schedulerBacklogTimeout.toString) + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", + sustainedSchedulerBacklogTimeout.toString) + .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString) + .set("spark.dynamicAllocation.testing", "true") + new SparkContext(conf) + } + + private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { + new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details") + } + + private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { + new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative = false) + } + + /* ------------------------------------------------------- * + | Helper methods for accessing private methods and fields | + * ------------------------------------------------------- */ + + private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) + private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _executorsPendingToRemove = + PrivateMethod[collection.Set[String]]('executorsPendingToRemove) + private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds) + private val _addTime = PrivateMethod[Long]('addTime) + private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) + private val _schedule = PrivateMethod[Unit]('schedule) + private val _addExecutors = PrivateMethod[Int]('addExecutors) + private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor) + private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) + private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) + private val _onSchedulerBacklogged = PrivateMethod[Unit]('onSchedulerBacklogged) + private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty) + private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle) + private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) + + private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _numExecutorsToAdd() + } + + private def numExecutorsPending(manager: 
ExecutorAllocationManager): Int = { + manager invokePrivate _numExecutorsPending() + } + + private def executorsPendingToRemove( + manager: ExecutorAllocationManager): collection.Set[String] = { + manager invokePrivate _executorsPendingToRemove() + } + + private def executorIds(manager: ExecutorAllocationManager): collection.Set[String] = { + manager invokePrivate _executorIds() + } + + private def addTime(manager: ExecutorAllocationManager): Long = { + manager invokePrivate _addTime() + } + + private def removeTimes(manager: ExecutorAllocationManager): collection.Map[String, Long] = { + manager invokePrivate _removeTimes() + } + + private def schedule(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _schedule() + } + + private def addExecutors(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _addExecutors() + } + + private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { + manager invokePrivate _removeExecutor(id) + } + + private def onExecutorAdded(manager: ExecutorAllocationManager, id: String): Unit = { + manager invokePrivate _onExecutorAdded(id) + } + + private def onExecutorRemoved(manager: ExecutorAllocationManager, id: String): Unit = { + manager invokePrivate _onExecutorRemoved(id) + } + + private def onSchedulerBacklogged(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _onSchedulerBacklogged() + } + + private def onSchedulerQueueEmpty(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _onSchedulerQueueEmpty() + } + + private def onExecutorIdle(manager: ExecutorAllocationManager, id: String): Unit = { + manager invokePrivate _onExecutorIdle(id) + } + + private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = { + manager invokePrivate _onExecutorBusy(id) + } +} diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala new file mode 100644 index 0000000000000..792b9cd8b6ff2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkContext._ +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient} + +/** + * This suite creates an external shuffle server and routes all shuffle fetches through it. 
+ * Note that failures in this suite may arise due to changes in Spark that invalidate expectations + * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * we hash files into folders. + */ +class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { + var server: TransportServer = _ + var rpcHandler: ExternalShuffleBlockHandler = _ + + override def beforeAll() { + val transportConf = SparkTransportConf.fromSparkConf(conf) + rpcHandler = new ExternalShuffleBlockHandler() + val transportContext = new TransportContext(transportConf, rpcHandler) + server = transportContext.createServer() + + conf.set("spark.shuffle.manager", "sort") + conf.set("spark.shuffle.service.enabled", "true") + conf.set("spark.shuffle.service.port", server.getPort.toString) + } + + override def afterAll() { + server.close() + } + + // This test ensures that the external shuffle service is actually in use for the other tests. + test("using external shuffle service") { + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc.env.blockManager.externalShuffleServiceEnabled should equal(true) + sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) + + val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + + rdd.count() + rdd.count() + + // Invalidate the registered executors, disallowing access to their shuffle blocks. + rpcHandler.clearRegisteredExecutors() + + // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry" + // being set. + val e = intercept[SparkException] { + rdd.count() + } + e.getMessage should include ("Fetch failure will not retry stage due to testing config") + } +} diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a8867020e457d..379c2a6ea4b55 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io._ import java.util.jar.{JarEntry, JarOutputStream} +import com.google.common.io.ByteStreams import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -58,12 +59,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { jar.putNextEntry(jarEntry) val in = new FileInputStream(textFile) - val buffer = new Array[Byte](10240) - var nRead = 0 - while (nRead <= 0) { - nRead = in.read(buffer, 0, buffer.length) - jar.write(buffer, 0, nRead) - } + ByteStreams.copy(in, jar) in.close() jar.close() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index a2b74c4419d46..5e24196101fbc 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.storage.StorageLevel + import scala.io.Source import org.apache.hadoop.io._ @@ -224,6 +227,187 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("binary file input as byte array") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = 
Array[Byte](1, 2, 3, 4, 5, 6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName) + val (infile: String, indata: PortableDataStream) = inRdd.collect.head + + // Make sure the name and array match + assert(infile.contains(outFileName)) // a prefix may get added + assert(indata.toArray === testOutput) + } + + test("portabledatastream caching tests") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName).cache() + inRdd.foreach{ + curData: (String, PortableDataStream) => + curData._2.toArray() // force the file to read + } + val mappedRdd = inRdd.map { + curData: (String, PortableDataStream) => + (curData._2.getPath(), curData._2) + } + val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head + + // Try reading the output back as an object file + + assert(indata.toArray === testOutput) + } + + test("portabledatastream persist disk storage") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY) + inRdd.foreach{ + curData: (String, PortableDataStream) => + curData._2.toArray() // force the file to read + } + val mappedRdd = inRdd.map { + curData: (String, PortableDataStream) => + (curData._2.getPath(), curData._2) + } + val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head + + // Try reading the output back as an object file + + assert(indata.toArray === testOutput) + } + + test("portabledatastream flatmap tests") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val numOfCopies = 3 + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName) + val mappedRdd = inRdd.map { + curData: (String, PortableDataStream) => + (curData._2.getPath(), curData._2) + } + val copyRdd = mappedRdd.flatMap { + curData: (String, PortableDataStream) => + for(i <- 1 to numOfCopies) yield (i, curData._2) + } + + val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() + + // Try reading the output back as an object file + assert(copyArr.length == numOfCopies) + copyArr.foreach{ + cEntry: (Int, PortableDataStream) => + assert(cEntry._2.toArray === testOutput) + } + + } + + test("fixed record length binary file as byte array") { + // a fixed length 
of 6 bytes + + sc = new SparkContext("local", "test") + + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val testOutputCopies = 10 + + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + for(i <- 1 to testOutputCopies) { + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + channel.write(bbuf) + } + channel.close() + file.close() + + val inRdd = sc.binaryRecords(outFileName, testOutput.length) + // make sure there are enough elements + assert(inRdd.count == testOutputCopies) + + // now just compare the first one + val indata: Array[Byte] = inRdd.collect.head + assert(indata === testOutput) + } + + test ("negative binary record length should raise an exception") { + // a fixed length of 6 bytes + sc = new SparkContext("local", "test") + + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val testOutputCopies = 10 + + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + for(i <- 1 to testOutputCopies) { + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + channel.write(bbuf) + } + channel.close() + file.close() + + val inRdd = sc.binaryRecords(outFileName, -1) + + intercept[SparkException] { + inRdd.count + } + } + test("file caching") { sc = new SparkContext("local", "test") val out = new FileWriter(tempDir + "/input") diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala index 2acc02a54fa3d..19180e88ebe0a 100644 --- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -24,10 +24,6 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with hash-based shuffle. 
override def beforeAll() { - System.setProperty("spark.shuffle.manager", "hash") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.manager") + conf.set("spark.shuffle.manager", "hash") } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index cbc0bd178d894..d27880f4bc32f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { +class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf test("master start and stop") { @@ -37,6 +37,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.stop() + actorSystem.shutdown() } test("master register shuffle and fetch") { @@ -56,6 +57,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() + actorSystem.shutdown() } test("master register and unregister shuffle") { @@ -74,6 +76,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).isEmpty) + + tracker.stop() + actorSystem.shutdown() } test("master register shuffle and unregister map output and fetch") { @@ -97,6 +102,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + + tracker.stop() + actorSystem.shutdown() } test("remote fetch") { @@ -136,6 +144,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // failure should be cached intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + masterTracker.stop() + slaveTracker.stop() + actorSystem.shutdown() + slaveSystem.shutdown() } test("remote fetch below akka frame size") { @@ -154,6 +167,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) masterActor.receive(GetMapOutputStatuses(10)) + +// masterTracker.stop() // this throws an exception + actorSystem.shutdown() } test("remote fetch exceeds akka frame size") { @@ -176,5 +192,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + +// masterTracker.stop() // this throws an exception + actorSystem.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d7b2d2e1e330f..d78c99c2e1e06 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,10 +24,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { - System.setProperty("spark.shuffle.use.netty", "true") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.use.netty") + conf.set("spark.shuffle.blockTransferService", "netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 2bdd84ce69ab8..cda942e15a704 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -30,10 +30,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex val conf = new SparkConf(loadDefaults = false) + // Ensure that the DAGScheduler doesn't retry stages whose fetches fail, so that we accurately + // test that the shuffle works (rather than retrying until all blocks are local to one Executor). 
+ conf.set("spark.test.noStageRetry", "true") + test("groupByKey without compression") { try { System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test", conf) val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) val groups = pairs.groupByKey(4).collect() assert(groups.size === 2) @@ -47,7 +51,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -73,7 +77,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +93,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) // 10 partitions from 4 keys val NUM_BLOCKS = 10 @@ -116,7 +120,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) // 10 partitions from 4 keys val NUM_BLOCKS = 10 @@ -141,7 +145,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +158,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. 
// Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +172,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +199,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -209,11 +213,8 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .setAppName("test") - .setMaster("local-cluster[2,1,512]") - sc = new SparkContext(conf) + val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -226,10 +227,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - val conf = new SparkConf() - .setAppName("test") - .setMaster("local-cluster[2,1,512]") - sc = new SparkContext(conf) + sc = new SparkContext("local-cluster[2,1,512]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 639e56c488db4..63358172ea1f4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -24,10 +24,6 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. 
override def beforeAll() { - System.setProperty("spark.shuffle.manager", "sort") - } - - override def afterAll() { - System.clearProperty("spark.shuffle.manager") + conf.set("spark.shuffle.manager", "sort") } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 495a0d48633a4..0390a2e4f1dbb 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -17,22 +17,23 @@ package org.apache.spark -import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester} +import org.scalatest.{FunSuite, PrivateMethodTester} -import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend class SparkContextSchedulerCreationSuite - extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach { + extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging { def createTaskScheduler(master: String): TaskSchedulerImpl = { // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. - val sc = new SparkContext("local", "test") - val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) - val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) + sc = new SparkContext("local", "test") + val createTaskSchedulerMethod = + PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler) + val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master) sched.asInstanceOf[TaskSchedulerImpl] } diff --git a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala b/core/src/test/scala/org/apache/spark/StatusAPISuite.scala new file mode 100644 index 0000000000000..4468fba8c1dff --- /dev/null +++ b/core/src/test/scala/org/apache/spark/StatusAPISuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + +import org.scalatest.{Matchers, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.JobExecutionStatus._ +import org.apache.spark.SparkContext._ + +class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { + + test("basic status API usage") { + val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync() + val jobId: Int = eventually(timeout(10 seconds)) { + val jobIds = jobFuture.jobIds + jobIds.size should be(1) + jobIds.head + } + val jobInfo = eventually(timeout(10 seconds)) { + sc.getJobInfo(jobId).get + } + jobInfo.status() should not be FAILED + val stageIds = jobInfo.stageIds() + stageIds.size should be(2) + + val firstStageInfo = eventually(timeout(10 seconds)) { + sc.getStageInfo(stageIds(0)).get + } + firstStageInfo.stageId() should be(stageIds(0)) + firstStageInfo.currentAttemptId() should be(0) + firstStageInfo.numTasks() should be(2) + eventually(timeout(10 seconds)) { + val updatedFirstStageInfo = sc.getStageInfo(stageIds(0)).get + updatedFirstStageInfo.numCompletedTasks() should be(2) + updatedFirstStageInfo.numActiveTasks() should be(0) + updatedFirstStageInfo.numFailedTasks() should be(0) + } + } + + test("getJobIdsForGroup()") { + sc.setJobGroup("my-job-group", "description") + sc.getJobIdsForGroup("my-job-group") should be (Seq.empty) + val firstJobFuture = sc.parallelize(1 to 1000).countAsync() + val firstJobId = eventually(timeout(10 seconds)) { + firstJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) + } + val secondJobFuture = sc.parallelize(1 to 1000).countAsync() + val secondJobId = eventually(timeout(10 seconds)) { + secondJobFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index e096c8c3e9b46..b0a70f012f1f3 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -19,13 +19,30 @@ package org.apache.spark.broadcast import scala.util.Random -import org.scalatest.FunSuite +import org.scalatest.{Assertions, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv} import org.apache.spark.io.SnappyCompressionCodec +import org.apache.spark.rdd.RDD import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +// Dummy class that creates a broadcast variable but doesn't use it +class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { + @transient val list = List(1, 2, 3, 4) + val broadcast = rdd.context.broadcast(list) + val bid = broadcast.id + + def doSomething() = { + rdd.map { x => + val bm = SparkEnv.get.blockManager + // Check if broadcast block was fetched + val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined + (x, isFound) + }.collect().toSet + } +} + class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") @@ -105,6 +122,17 @@ class 
BroadcastSuite extends FunSuite with LocalSparkContext { } } + test("Test Lazy Broadcast variables with TorrentBroadcast") { + val numSlaves = 2 + val conf = torrentConf.clone + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + val rdd = sc.parallelize(1 to numSlaves) + + val results = new DummyBroadcastClass(rdd).doSomething() + + assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) + } + test("Unpersisting HttpBroadcast on executors only in local mode") { testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) } @@ -136,6 +164,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") { testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true) } + + test("Using broadcast after destroy prints callsite") { + sc = new SparkContext("local", "test") + testPackage.runCallSiteTest(sc) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). @@ -311,3 +345,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { conf } } + +package object testPackage extends Assertions { + + def runCallSiteTest(sc: SparkContext) { + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + val broadcast = sc.broadcast(rdd) + broadcast.destroy() + val thrown = intercept[SparkException] { broadcast.value } + assert(thrown.getMessage.contains("BroadcastSuite.scala")) + } + +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala similarity index 51% rename from core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala rename to core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala index 9740ee64d1f2d..7915ee75d8778 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ b/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala @@ -15,30 +15,23 @@ * limitations under the License. */ -package org.apache.spark.network.netty.client +package org.apache.spark.deploy -/** - * A simple iterator that lazily initializes the underlying iterator. - * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. 
- */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { +import org.apache.spark.deploy.worker.CommandUtils +import org.apache.spark.util.Utils - lazy val proxy = createIterator +import org.scalatest.{FunSuite, Matchers} - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } +class CommandUtilsSuite extends FunSuite with Matchers { - override def next(): Any = proxy.next() - - def close(): Unit = Unit + test("set libraryPath correctly") { + val appId = "12345-worker321-9876" + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq()) + val builder = CommandUtils.buildProcessBuilder(cmd, 512, sparkHome, t => t) + val libraryPath = Utils.libraryPathEnvName + val env = builder.environment + env.keySet should contain(libraryPath) + assert(env.get(libraryPath).startsWith("libraryPathToB")) + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1cdf50d5c08c7..d8cd0ff2c9026 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -292,7 +292,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { runSparkSubmit(args) } - test("spark submit includes jars passed in through --jar") { + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) @@ -306,6 +306,110 @@ class SparkSubmitSuite extends FunSuite with Matchers { runSparkSubmit(args) } + test("resolves command line argument paths correctly") { + val jars = "/jar1,/jar2" // --jars + val files = "hdfs:/file1,file2" // --files + val archives = "file:/archive1,archive2" // --archives + val pyFiles = "py-file1,py-file2" // --py-files + + // Test jars and files + val clArgs = Seq( + "--master", "local", + "--class", "org.SomeClass", + "--jars", jars, + "--files", files, + "thejar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + appArgs.jars should be (Utils.resolveURIs(jars)) + appArgs.files should be (Utils.resolveURIs(files)) + sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) + sysProps("spark.files") should be (Utils.resolveURIs(files)) + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn-client", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + appArgs2.files should be (Utils.resolveURIs(files)) + appArgs2.archives should be (Utils.resolveURIs(archives)) + sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) + sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + + // Test python files + val clArgs3 = Seq( + "--master", "local", + "--py-files", pyFiles, + "mister.py" + ) + val appArgs3 = new SparkSubmitArguments(clArgs3) + val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) + sysProps3("spark.submit.pyFiles") should be ( + 
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + } + + test("resolves config paths correctly") { + val jars = "/jar1,/jar2" // spark.jars + val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files + val archives = "file:/archive1,archive2" // spark.yarn.dist.archives + val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles + + // Test jars and files + val f1 = File.createTempFile("test-submit-jars-files", "") + val writer1 = new PrintWriter(f1) + writer1.println("spark.jars " + jars) + writer1.println("spark.files " + files) + writer1.close() + val clArgs = Seq( + "--master", "local", + "--class", "org.SomeClass", + "--properties-file", f1.getPath, + "thejar.jar" + ) + val appArgs = new SparkSubmitArguments(clArgs) + val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + sysProps("spark.files") should be(Utils.resolveURIs(files)) + + // Test files and archives (Yarn) + val f2 = File.createTempFile("test-submit-files-archives", "") + val writer2 = new PrintWriter(f2) + writer2.println("spark.yarn.dist.files " + files) + writer2.println("spark.yarn.dist.archives " + archives) + writer2.close() + val clArgs2 = Seq( + "--master", "yarn-client", + "--class", "org.SomeClass", + "--properties-file", f2.getPath, + "thejar.jar" + ) + val appArgs2 = new SparkSubmitArguments(clArgs2) + val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + + // Test python files + val f3 = File.createTempFile("test-submit-python-files", "") + val writer3 = new PrintWriter(f3) + writer3.println("spark.submit.pyFiles " + pyFiles) + writer3.close() + val clArgs3 = Seq( + "--master", "local", + "--properties-file", f3.getPath, + "mister.py" + ) + val appArgs3 = new SparkSubmitArguments(clArgs3) + val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + sysProps3("spark.submit.pyFiles") should be( + PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + } + test("SPARK_CONF_DIR overrides spark-defaults.conf") { forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 5e2592e8d2e8d..196217062991e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File +import scala.collection.JavaConversions._ + import org.scalatest.FunSuite import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} @@ -32,6 +34,7 @@ class ExecutorRunnerTest extends FunSuite { Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) - assert(er.getCommandSeq.last === appId) + val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables) + assert(builder.command().last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala 
b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala new file mode 100644 index 0000000000000..48c386ba04311 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics + +import org.scalatest.FunSuite + +import org.apache.spark.SharedSparkContext +import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{FileWriter, PrintWriter, File} + +class InputMetricsSuite extends FunSuite with SharedSparkContext { + test("input metrics when reading text file with single split") { + val file = new File(getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(file)) + pw.println("some stuff") + pw.println("some other stuff") + pw.println("yet more stuff") + pw.println("too much stuff") + pw.close() + file.deleteOnExit() + + val taskBytesRead = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + } + }) + sc.textFile("file://" + file.getAbsolutePath, 2).count() + + // Wait for task end events to come in + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesRead.length == 2) + assert(taskBytesRead.sum >= file.length()) + } + + test("input metrics when reading text file with multiple splits") { + val file = new File(getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(file)) + for (i <- 0 until 10000) { + pw.println("some stuff") + } + pw.close() + file.deleteOnExit() + + val taskBytesRead = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead + } + }) + sc.textFile("file://" + file.getAbsolutePath, 2).count() + + // Wait for task end events to come in + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesRead.length == 2) + assert(taskBytesRead.sum >= file.length()) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 3925f0ccbdbf0..bbdc9568a6ddb 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -121,7 +121,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod } val appId = "testId" - val executorId = "executor.1" + val executorId = "1" conf.set("spark.app.id", appId) conf.set("spark.executor.id", executorId) @@ -138,7 +138,7 @@ class MetricsSystemSuite extends 
FunSuite with BeforeAndAfter with PrivateMethod override val metricRegistry = new MetricRegistry() } - val executorId = "executor.1" + val executorId = "1" conf.set("spark.executor.id", executorId) val instanceName = "executor" diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala deleted file mode 100644 index 02d0ffc86f58f..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer -import java.util.{Collections, HashSet} -import java.util.concurrent.{TimeUnit, Semaphore} - -import scala.collection.JavaConversions._ - -import io.netty.buffer.{ByteBufUtil, Unpooled} - -import org.scalatest.{BeforeAndAfterAll, FunSuite} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} -import org.apache.spark.network.netty.server.BlockServer -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * Test suite that makes sure the server and the client implementations share the same protocol. 
- */ -class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { - - val bufSize = 100000 - var buf: ByteBuffer = _ - var testFile: File = _ - var server: BlockServer = _ - var clientFactory: BlockFetchingClientFactory = _ - - val bufferBlockId = "buffer_block" - val fileBlockId = "file_block" - - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - - override def beforeAll() = { - buf = ByteBuffer.allocate(bufSize) - for (i <- 1 to bufSize) { - buf.put(i.toByte) - } - buf.flip() - - testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - server = new BlockServer(new SparkConf, new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - if (blockId == bufferBlockId) { - Right(buf) - } else if (blockId == fileBlockId) { - Left(new FileSegment(testFile, 10, testFile.length - 25)) - } else { - throw new Exception("Unknown block id " + blockId) - } - } - }) - - clientFactory = new BlockFetchingClientFactory(new SparkConf) - } - - override def afterAll() = { - server.stop() - clientFactory.stop() - } - - /** A ByteBuf for buffer_block */ - lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) - - /** A ByteBuf for file_block */ - lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = - { - val client = clientFactory.createClient(server.hostName, server.port) - val sem = new Semaphore(0) - val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) - val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) - - client.fetchBlocks( - blockIds, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - errorBlockIds.add(blockId) - sem.release() - } - - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { - receivedBlockIds.add(blockId) - data.retain() - receivedBuffers.add(data) - sem.release() - } - } - ) - if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { - fail("Timeout getting response from the server") - } - client.close() - (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) - } - - test("fetch a ByteBuffer block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a FileSegment block via zero-copy send") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) - assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.underlying) === Set(fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) - assert(blockIds.isEmpty) - assert(buffers.isEmpty) - assert(failBlockIds === Set("random-block")) - } - - test("fetch both ByteBuffer block and FileSegment block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) - assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, 
fileBlockReference)) - assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) - } - - test("fetch both ByteBuffer block and a non-existent block") { - val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) - assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) - assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala deleted file mode 100644 index 903ab09ae4322..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.{PrivateMethodTester, FunSuite} - - -class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val totalLength = 4 + blockId.length + blockData.length - - var parsedBlockId: String = "" - var parsedBlockData: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? 
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { - parsedBlockId = bid - val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) - refCntBuf.byteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(blockId.length) - buf.put(blockId.getBytes) - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - val totalLength = 4 + blockId.length + errorMsg.length - - var parsedBlockId: String = "" - var parsedErrorMsg: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, new BlockClientListener { - override def onFetchFailure(bid: String, msg: String) ={ - parsedBlockId = bid - parsedErrorMsg = msg - } - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? - }) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(-blockId.length) - buf.put(blockId.getBytes) - buf.put(errorMsg.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedErrorMsg === errorMsg) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala deleted file mode 100644 index 3ee281cb1350b..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockHeaderEncoderSuite extends FunSuite { - - test("encode normal block data") { - val blockId = "test_block" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, None)) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + 17) - assert(out.readInt() === blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - assert(out.readableBytes() === 0) - - channel.close() - } - - test("encode error message") { - val blockId = "error_block" - val errorMsg = "error encountered" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + errorMsg.length) - assert(out.readInt() === -blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - - val errorMsgBytes = new Array[Byte](errorMsg.length) - out.readBytes(errorMsgBytes) - assert(new String(errorMsgBytes) === errorMsg) - assert(out.readableBytes() === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala deleted file mode 100644 index 3239c710f1639..0000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer - -import io.netty.buffer.{Unpooled, ByteBuf} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.storage.{BlockDataProvider, FileSegment} - - -class BlockServerHandlerSuite extends FunSuite { - - test("ByteBuffer block") { - val expectedBlockId = "test_bytebuffer_block" - val buf = ByteBuffer.allocate(10000) - for (i <- 1 to 10000) { - buf.put(i.toByte) - } - buf.flip() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[ByteBuf] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === buf.remaining) - assert(out1.error === None) - - assert(out2.equals(Unpooled.wrappedBuffer(buf))) - - channel.close() - } - - test("FileSegment block via zero-copy") { - val expectedBlockId = "test_file_block" - - // Create random file data - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - val testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - Left(new FileSegment(testFile, 15, testFile.length - 25)) - } - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === testFile.length - 25) - assert(out1.error === None) - - assert(out2.count === testFile.length - 25) - assert(out2.position === 15) - } - - test("pipeline exception propagation") { - val blockServerHandler = new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? 
- }) - val exceptionHandler = new SimpleChannelInboundHandler[String]() { - override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { - throw new Exception("this is an error") - } - } - - val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) - assert(channel.isOpen) - channel.writeInbound("a message to trigger the error") - assert(!channel.isOpen) - } -} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a2e4f712db55b..819f95634bcdc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -431,7 +431,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -461,7 +461,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), null, @@ -472,7 +472,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(CompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), null, @@ -624,7 +624,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -655,7 +655,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F (Success, makeMapStatus("hostB", 1)))) // pretend stage 0 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
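The DAGSchedulerSuite hunks above track FetchFailed gaining a fifth, human-readable message field (passed as "ignored" at every call site here). As a minimal, hypothetical sketch of how a suite could centralize that construction, assuming the five-argument FetchFailed shown above and the suite's existing makeBlockManagerId and complete helpers (the fetchFailed wrapper name is an illustration, not part of this patch):

    // Hypothetical convenience wrapper for these tests only; assumes
    // FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) as used above.
    private def fetchFailed(host: String, shuffleId: Int, mapId: Int = 0, reduceId: Int = 0) =
      FetchFailed(makeBlockManagerId(host), shuffleId, mapId, reduceId, "ignored")

    // Usage, mirroring the hunks above:
    // complete(taskSets(2), Seq((fetchFailed("hostA", shuffleDepTwo.shuffleId), null)))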
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index c4e7a4bb7d385..5768a3a733f00 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule // Only remove the result once, since we'd like to test the case where the task eventually // succeeds. serializer.get().deserialize[TaskResult[_]](serializedData) match { - case IndirectTaskResult(blockId) => + case IndirectTaskResult(blockId, size) => sparkEnv.blockManager.master.removeBlock(blockId) case directResult: DirectTaskResult[_] => taskSetManager.abort("Internal error: expect only indirect results") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index c0b07649eb6dd..1809b5396d53e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("abort the job if total size of results is too large") { + val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") + sc = new SparkContext("local", "test", conf) + + def genBytes(size: Int) = { (x: Int) => + val bytes = Array.ofDim[Byte](size) + scala.util.Random.nextBytes(bytes) + bytes + } + + // multiple 1k result + val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect() + assert(10 === r.size ) + + // single 10M result + val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()} + assert(thrown.getMessage().contains("bigger than maxResultSize")) + + // multiple 1M results + val thrown2 = intercept[SparkException] { + sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect() + } + assert(thrown2.getMessage().contains("bigger than maxResultSize")) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 64ac6d2d920d2..a70f67af2e62e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -201,12 +201,17 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { assert(control.sum === result) } - // TODO: this still doesn't work - ignore("kryo with fold") { + test("kryo with fold") { val control = 1 :: 2 :: Nil + // zeroValue must not be a ClassWithoutNoArgConstructor instance because it will be + // serialized by spark.closure.serializer but spark.closure.serializer only supports + // the default Java serializer. 
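      // With null as the zero value, no ClassWithoutNoArgConstructor instance has to pass
      // through the closure serializer; the t1 == null check below simply treats the missing
      // zero element as contributing 0 to the sum.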
val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(10 + control.sum === result) + .fold(null)((t1, t2) => { + val t1x = if (t1 == null) 0 else t1.x + new ClassWithoutNoArgConstructor(t1x + t2.x) + }).x + assert(control.sum === result) } test("kryo with nonexistent custom registrator should fail") { diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala new file mode 100644 index 0000000000000..0ade1bab18d7e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{EOFException, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + + +/** + * A serializer implementation that always return a single element in a deserialization stream. + */ +class TestSerializer extends Serializer { + override def newInstance() = new TestSerializerInstance +} + + +class TestSerializerInstance extends SerializerInstance { + override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + + override def serializeStream(s: OutputStream): SerializationStream = ??? + + override def deserializeStream(s: InputStream) = new TestDeserializationStream + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? 
+} + + +class TestDeserializationStream extends DeserializationStream { + + private var count = 0 + + override def readObject[T: ClassTag](): T = { + count += 1 + if (count == 2) { + throw new EOFException + } + new Object().asInstanceOf[T] + } + + override def close(): Unit = {} +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index ba47fe5e25b9b..6790388f96603 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockManager import org.apache.spark.storage.{ShuffleBlockId, FileSegment} @@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) - assert(expected.offset === segment.offset) - assert(expected.length === segment.length) + assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) + assert(expected.offset === segment.getOffset) + assert(expected.length === segment.getLength) } test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 1f1d53a1ee3b0..c6d7105592096 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -27,7 +27,7 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -57,7 +57,9 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd // Implicitly convert strings to BlockIds for test clarity. 
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { + private def makeBlockManager( + maxMem: Long, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) @@ -108,7 +110,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd storeIds.filterNot { _ == stores(2).blockManagerId }) // Add driver store and test whether it is filtered out - val driverStore = makeBlockManager(1000, "") + val driverStore = makeBlockManager(1000, SparkContext.DRIVER_IDENTIFIER) assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver)) assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver)) assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9d96202a3e7ac..715b740b857b2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -37,7 +37,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus @@ -69,7 +69,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) - private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { + private def makeBlockManager( + maxMem: Long, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) @@ -790,8 +792,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - mapOutputTracker, shuffleManager, transfer) + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) // The put should fail since a1 is not serializable. 
class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a8c049d749015..1eaabb93adbed 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,48 +17,75 @@ package org.apache.spark.storage -import org.apache.spark.{TaskContextImpl, TaskContext} -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} +import java.util.concurrent.Semaphore + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global -import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer - import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, TaskContextImpl} +import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.serializer.TestSerializer class ShuffleBlockFetcherIteratorSuite extends FunSuite { + // Some of the tests are quite tricky because we are testing the cleanup behavior + // in the presence of faults. - test("handle local read failures in BlockManager") { + /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ + private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) + } + } } + }) + transfer + } + + private val conf = new SparkConf + + test("successful 3 local reads + 2 remote reads") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the blocks + val localBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + localBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - 
doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) + ) + + val transfer = createMockTransfer(remoteBlocks) - val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( @@ -66,118 +93,145 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call - // getLocalShuffleFromDisk. - verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when - // fetching successfully - assert(iterator.next()._2.isDefined, - "1st element should be defined but is not actually defined") - verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "2nd element should be defined but is not actually defined") - verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - intercept[Exception] { - iterator.next() + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getBlockData(any()) + + for (i <- 0 until 5) { + assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") + val (blockId, subIterator) = iterator.next() + assert(subIterator.isSuccess, + s"iterator should have 5 elements defined but actually has $i elements") + + // Make sure we release the buffer once the iterator is exhausted. 
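      // Each ManagedBuffer in these maps is a Mockito mock, so the release contract is
      // asserted by call count below: zero release() calls before the sub-iterator is
      // drained, exactly one afterwards.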
+ val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + verify(mockBuf, times(0)).release() + subIterator.get.foreach(_ => Unit) // exhaust the iterator + verify(mockBuf, times(1)).release() } - verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) + + // 3 local blocks, and 2 remote blocks + // (but from the same block manager so one call to fetchBlocks) + verify(blockManager, times(3)).getBlockData(any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) } - test("handle local read successes") { - val transfer = mock(classOf[BlockTransferService]) + test("release current unexhausted buffer in case the task completes early") { val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) - val optItr = mock(classOf[Option[Iterator[Any]]]) + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) - // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } + } + }) - val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContextImpl(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk. 
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 5th element is not actually defined") - - verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + // Exhaust the first block, and then it should be released. + iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() + + // Get the 2nd block but do not exhaust the iterator + val subIter = iterator.next()._2.get + + // Complete the task; then the 2nd block buffer should be exhausted + verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() + taskContext.markTaskCompleted() + verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() + + // The 3rd block should not be retained because the iterator is already in zombie state + sem.release() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release() } - test("handle remote fetch failures in BlockTransferService") { + test("fail all blocks if any of the remote request fails") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { - val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - listener.onBlockFetchFailure(new Exception("blah")) + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + future { + // Return the first block, and then fail. 
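          // The semaphore is released only after both failure callbacks have fired, so the
          // test thread's sem.acquire() further down cannot inspect the iterator before the
          // failures have been delivered.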
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) + sem.release() + } } }) - val blockManager = mock(classOf[BlockManager]) - - when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) - - val blId1 = ShuffleBlockId(0, 0, 0) - val blId2 = ShuffleBlockId(0, 1, 0) - val bmId = BlockManagerId("test-server", "test-server", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L)))) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContextImpl(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - iterator.foreach { case (_, iterOption) => - assert(!iterOption.isDefined) - } + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be defined, and the last two are not defined (due to failure) + assert(iterator.next()._2.isSuccess) + assert(iterator.next()._2.isFailure) + assert(iterator.next()._2.isFailure) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala new file mode 100644 index 0000000000000..bacf6a16fc233 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import org.apache.spark.api.java.StorageLevels +import org.apache.spark.{SparkException, SparkConf, SparkContext} +import org.openqa.selenium.WebDriver +import org.openqa.selenium.htmlunit.HtmlUnitDriver +import org.scalatest._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.selenium.WebBrowser +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.LocalSparkContext._ + +/** + * Selenium tests for the Spark Web UI. These tests are not run by default + * because they're slow. + */ +@DoNotDiscover +class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { + implicit val webDriver: WebDriver = new HtmlUnitDriver + + /** + * Create a test SparkContext with the SparkUI enabled. + * It is safe to `get` the SparkUI directly from the SparkContext returned here. 
+ */ + private def newSparkContext(): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val sc = new SparkContext(conf) + assert(sc.ui.isDefined) + sc + } + + test("effects of unpersist() / persist() should be reflected") { + // Regression test for SPARK-2527 + withSpark(newSparkContext()) { sc => + val ui = sc.ui.get + val rdd = sc.parallelize(Seq(1, 2, 3)) + rdd.persist(StorageLevels.DISK_ONLY).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (ui.appUIAddress.stripSuffix("/") + "/storage") + val tableRowText = findAll(cssSelector("#storage-by-rdd-table td")).map(_.text).toSeq + tableRowText should contain (StorageLevels.DISK_ONLY.description) + } + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (ui.appUIAddress.stripSuffix("/") + "/storage/rdd/?id=0") + val tableRowText = findAll(cssSelector("#rdd-storage-by-block-table td")).map(_.text).toSeq + tableRowText should contain (StorageLevels.DISK_ONLY.description) + } + + rdd.unpersist() + rdd.persist(StorageLevels.MEMORY_ONLY).count() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (ui.appUIAddress.stripSuffix("/") + "/storage") + val tableRowText = findAll(cssSelector("#storage-by-rdd-table td")).map(_.text).toSeq + tableRowText should contain (StorageLevels.MEMORY_ONLY.description) + } + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to (ui.appUIAddress.stripSuffix("/") + "/storage/rdd/?id=0") + val tableRowText = findAll(cssSelector("#rdd-storage-by-block-table td")).map(_.text).toSeq + tableRowText should contain (StorageLevels.MEMORY_ONLY.description) + } + } + } + + test("failed stages should not appear to be active") { + withSpark(newSparkContext()) { sc => + // Regression test for SPARK-3021 + intercept[SparkException] { + sc.parallelize(1 to 10).map { x => throw new Exception()}.collect() + } + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to sc.ui.get.appUIAddress + find(id("active")).get.text should be("Active Stages (0)") + find(id("failed")).get.text should be("Failed Stages (1)") + } + + // Regression test for SPARK-2105 + class NotSerializable + val unserializableObject = new NotSerializable + intercept[SparkException] { + sc.parallelize(1 to 10).map { x => unserializableObject}.collect() + } + eventually(timeout(5 seconds), interval(50 milliseconds)) { + go to sc.ui.get.appUIAddress + find(id("active")).get.text should be("Active Stages (0)") + // The failure occurs before the stage becomes active, hence we should still show only one + // failed stage, not two: + find(id("failed")).get.text should be("Failed Stages (1)") + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 3370dd4156c3f..2efbae689771a 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -115,11 +115,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // Go through all the failure cases to make sure we are counting them as failures. 
val taskFailedReasons = Seq( Resubmitted, - new FetchFailed(null, 0, 0, 0), + new FetchFailed(null, 0, 0, 0, "ignored"), new ExceptionFailure("Exception", "description", null, None), TaskResultLost, TaskKilled, - ExecutorLostFailure, + ExecutorLostFailure("0"), UnknownReason) var failCount = 0 for (reason <- taskFailedReasons) { diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index d2bee448d4d3b..4dc5b6103db74 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.util import java.io._ -import java.nio.charset.Charset import scala.collection.mutable.HashSet import scala.reflect._ import org.scalatest.{BeforeAndAfter, FunSuite} +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{Logging, SparkConf} @@ -44,11 +44,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { test("basic file appender") { val testString = (1 to 1000).mkString(", ") - val inputStream = new ByteArrayInputStream(testString.getBytes(Charset.forName("UTF-8"))) + val inputStream = new ByteArrayInputStream(testString.getBytes(UTF_8)) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, Charset.forName("UTF-8")) === testString) + assert(Files.toString(testFile, UTF_8) === testString) } test("rolling file appender - time-based rolling") { @@ -96,7 +96,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val allGeneratedFiles = new HashSet[String]() val items = (1 to 10).map { _.toString * 10000 } for (i <- 0 until items.size) { - testOutputStream.write(items(i).getBytes(Charset.forName("UTF-8"))) + testOutputStream.write(items(i).getBytes(UTF_8)) testOutputStream.flush() allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName).map(_.toString) @@ -199,7 +199,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") for (i <- 0 until textToAppend.size) { - outputStream.write(textToAppend(i).getBytes(Charset.forName("UTF-8"))) + outputStream.write(textToAppend(i).getBytes(UTF_8)) outputStream.flush() Thread.sleep(sleepTimeBetweenTexts) } @@ -214,7 +214,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) val allText = generatedFiles.map { file => - Files.toString(file, Charset.forName("UTF-8")) + Files.toString(file, UTF_8) }.mkString("") assert(allText === expectedText) generatedFiles diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f1f88c5fd3634..aec1e409db95c 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -107,7 +107,8 @@ class JsonProtocolSuite extends FunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19) + val fetchFailed = FetchFailed(BlockManagerId("With or", "without 
you", 15), 17, 18, 19, + "Some exception") val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None) testTaskEndReason(Success) testTaskEndReason(Resubmitted) @@ -115,7 +116,7 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) - testTaskEndReason(ExecutorLostFailure) + testTaskEndReason(ExecutorLostFailure("100")) testTaskEndReason(UnknownReason) // BlockId @@ -176,6 +177,17 @@ class JsonProtocolSuite extends FunSuite { deserializedBmRemoved) } + test("FetchFailed backwards compatibility") { + // FetchFailed in Spark 1.1.0 does not have an "Message" property. + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "ignored") + val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) + .removeField({ _._1 == "Message" }) + val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + "Unknown reason") + assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + test("SparkListenerApplicationStart backwards compatibility") { // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") @@ -184,6 +196,15 @@ class JsonProtocolSuite extends FunSuite { assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) } + test("ExecutorLostFailure backward compatibility") { + // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property. + val executorLostFailure = ExecutorLostFailure("100") + val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure) + .removeField({ _._1 == "Executor ID" }) + val expectedExecutorLostFailure = ExecutorLostFailure("Unknown") + assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -396,6 +417,7 @@ class JsonProtocolSuite extends FunSuite { assert(r1.mapId === r2.mapId) assert(r1.reduceId === r2.reduceId) assertEquals(r1.bmAddress, r2.bmAddress) + assert(r1.message === r2.message) case (r1: ExceptionFailure, r2: ExceptionFailure) => assert(r1.className === r2.className) assert(r1.description === r2.description) @@ -403,7 +425,8 @@ class JsonProtocolSuite extends FunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => - case (ExecutorLostFailure, ExecutorLostFailure) => + case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) => + assert(execId1 === execId2) case (UnknownReason, UnknownReason) => case _ => fail("Task end reasons don't match in types!") } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index ea7ef0524d1e1..8ffe3e2b139c3 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,7 @@ import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStr import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite @@ -118,7 +118,7 @@ class UtilsSuite extends FunSuite { tmpDir2.deleteOnExit() val f1Path = 
tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8)) f1.close() // Read first few bytes @@ -146,9 +146,9 @@ class UtilsSuite extends FunSuite { val tmpDir = Utils.createTempDir() tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) - Files.write("0123456789", files(0), Charsets.UTF_8) - Files.write("abcdefghij", files(1), Charsets.UTF_8) - Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8) + Files.write("0123456789", files(0), UTF_8) + Files.write("abcdefghij", files(1), UTF_8) + Files.write("ABCDEFGHIJ", files(2), UTF_8) // Read first few bytes in the 1st file assert(Utils.offsetBytes(files, 0, 5) === "01234") @@ -217,9 +217,14 @@ class UtilsSuite extends FunSuite { test("resolveURI") { def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = { - assume(before.split(",").length == 1) - assert(Utils.resolveURI(before, testWindows) === new URI(after)) - assert(Utils.resolveURI(after, testWindows) === new URI(after)) + // This should test only single paths + assume(before.split(",").length === 1) + // Repeated invocations of resolveURI should yield the same result + def resolve(uri: String): String = Utils.resolveURI(uri, testWindows).toString + assert(resolve(after) === after) + assert(resolve(resolve(after)) === after) + assert(resolve(resolve(resolve(after))) === after) + // Also test resolveURIs with single paths assert(new URI(Utils.resolveURIs(before, testWindows)) === new URI(after)) assert(new URI(Utils.resolveURIs(after, testWindows)) === new URI(after)) } @@ -235,16 +240,27 @@ class UtilsSuite extends FunSuite { assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt", testWindows = true) intercept[IllegalArgumentException] { Utils.resolveURI("file:foo") } intercept[IllegalArgumentException] { Utils.resolveURI("file:foo:baby") } + } - // Test resolving comma-delimited paths - assert(Utils.resolveURIs("jar1,jar2") === s"file:$cwd/jar1,file:$cwd/jar2") - assert(Utils.resolveURIs("file:/jar1,file:/jar2") === "file:/jar1,file:/jar2") - assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3") === - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") - assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3,jar4#jar5") === + test("resolveURIs with multiple paths") { + def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = { + assume(before.split(",").length > 1) + assert(Utils.resolveURIs(before, testWindows) === after) + assert(Utils.resolveURIs(after, testWindows) === after) + // Repeated invocations of resolveURIs should yield the same result + def resolve(uri: String): String = Utils.resolveURIs(uri, testWindows) + assert(resolve(after) === after) + assert(resolve(resolve(after)) === after) + assert(resolve(resolve(resolve(after))) === after) + } + val cwd = System.getProperty("user.dir") + assertResolves("jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2") + assertResolves("file:/jar1,file:/jar2", "file:/jar1,file:/jar2") + assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3") + assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5") - assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", testWindows = true) === - s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi") + assertResolves("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", + 
s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi", testWindows = true) } test("nonLocalPaths") { @@ -339,7 +355,7 @@ class UtilsSuite extends FunSuite { try { System.setProperty("spark.test.fileNameLoadB", "2") Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8) + "spark.test.fileNameLoadB 1\n", outFile, UTF_8) val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -351,4 +367,15 @@ class UtilsSuite extends FunSuite { outFile.delete() } } + + test("timeIt with prepare") { + var cnt = 0 + val prepare = () => { + cnt += 1 + Thread.sleep(1000) + } + val time = Utils.timeIt(2)({}, Some(prepare)) + require(cnt === 2, "prepare should be called twice") + require(time < 500, "preparation time should not count") + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala index 6fe1079c2719a..0cb1ed7397655 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import java.lang.{Float => JFloat} +import java.lang.{Float => JFloat, Integer => JInteger} import java.util.{Arrays, Comparator} import org.scalatest.FunSuite @@ -30,11 +30,15 @@ class SorterSuite extends FunSuite { val rand = new XORShiftRandom(123) val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() } val data1 = data0.clone() + val data2 = data0.clone() Arrays.sort(data0) new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int) + new Sorter(new KeyReuseIntArraySortDataFormat) + .sort(data2, 0, data2.length, Ordering[IntWrapper]) - data0.zip(data1).foreach { case (x, y) => assert(x === y) } + assert(data0.view === data1.view) + assert(data0.view === data2.view) } test("KVArraySorter") { @@ -61,10 +65,33 @@ class SorterSuite extends FunSuite { } } + /** Runs an experiment several times. */ + def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = { + if (skip) { + println(s"Skipped experiment $name.") + return + } + + val firstTry = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) + System.gc() + + var i = 0 + var next10: Long = 0 + while (i < 10) { + val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare)) + next10 += time + println(s"$name: Took $time ms") + i += 1 + } + + println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") + } + /** * This provides a simple benchmark for comparing the Sorter with Java internal sorting. * Ideally these would be executed one at a time, each in their own JVM, so their listing - * here is mainly to have the code. + * here is mainly to have the code. Running multiple tests within the same JVM session would + * prevent JIT inlining overridden methods and hence hurt the performance. * * The goal of this code is to sort an array of key-value pairs, where the array physically * has the keys and values alternating. The basic Java sorts work only on the keys, so the @@ -72,96 +99,167 @@ class SorterSuite extends FunSuite { * those, while the Sorter approach can work directly on the input data format. * * Note that the Java implementation varies tremendously between Java 6 and Java 7, when - * the Java sort changed from merge sort to Timsort. + * the Java sort changed from merge sort to TimSort. 
*/ - ignore("Sorter benchmark") { - - /** Runs an experiment several times. */ - def runExperiment(name: String)(f: => Unit): Unit = { - val firstTry = org.apache.spark.util.Utils.timeIt(1)(f) - System.gc() - - var i = 0 - var next10: Long = 0 - while (i < 10) { - val time = org.apache.spark.util.Utils.timeIt(1)(f) - next10 += time - println(s"$name: Took $time ms") - i += 1 - } - - println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)") - } - + ignore("Sorter benchmark for key-value pairs") { val numElements = 25000000 // 25 mil val rand = new XORShiftRandom(123) - val keys = Array.tabulate[JFloat](numElements) { i => - new JFloat(rand.nextFloat()) + // Test our key-value pairs where each element is a Tuple2[Float, Integer]. + + val kvTuples = Array.tabulate(numElements) { i => + (new JFloat(rand.nextFloat()), new JInteger(i)) } - // Test our key-value pairs where each element is a Tuple2[Float, Integer) - val kvTupleArray = Array.tabulate[AnyRef](numElements) { i => - (keys(i / 2): Float, i / 2: Int) + val kvTupleArray = new Array[AnyRef](numElements) + val prepareKvTupleArray = () => { + System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements) } - runExperiment("Tuple-sort using Arrays.sort()") { + runExperiment("Tuple-sort using Arrays.sort()")({ Arrays.sort(kvTupleArray, new Comparator[AnyRef] { override def compare(x: AnyRef, y: AnyRef): Int = - Ordering.Float.compare(x.asInstanceOf[(Float, _)]._1, y.asInstanceOf[(Float, _)]._1) + x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1) }) - } + }, prepareKvTupleArray) // Test our Sorter where each element alternates between Float and Integer, non-primitive - val keyValueArray = Array.tabulate[AnyRef](numElements * 2) { i => - if (i % 2 == 0) keys(i / 2) else new Integer(i / 2) + + val keyValues = { + val data = new Array[AnyRef](numElements * 2) + var i = 0 + while (i < numElements) { + data(2 * i) = kvTuples(i)._1 + data(2 * i + 1) = kvTuples(i)._2 + i += 1 + } + data } + + val keyValueArray = new Array[AnyRef](numElements * 2) + val prepareKeyValueArray = () => { + System.arraycopy(keyValues, 0, keyValueArray, 0, numElements * 2) + } + val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef]) - runExperiment("KV-sort using Sorter") { - sorter.sort(keyValueArray, 0, keys.length, new Comparator[JFloat] { - override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) + runExperiment("KV-sort using Sorter")({ + sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] { + override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y) }) + }, prepareKeyValueArray) + } + + /** + * Tests for sorting with primitive keys with/without key reuse. Java's Arrays.sort is used as + * reference, which is expected to be faster but it can only sort a single array. Sorter can be + * used to sort parallel arrays. + * + * Ideally these would be executed one at a time, each in their own JVM, so their listing + * here is mainly to have the code. Running multiple tests within the same JVM session would + * prevent JIT inlining overridden methods and hence hurt the performance. 
+ */ + ignore("Sorter benchmark for primitive int array") { + val numElements = 25000000 // 25 mil + val rand = new XORShiftRandom(123) + + val ints = Array.fill(numElements)(rand.nextInt()) + val intObjects = { + val data = new Array[JInteger](numElements) + var i = 0 + while (i < numElements) { + data(i) = new JInteger(ints(i)) + i += 1 + } + data } - // Test non-primitive sort on float array - runExperiment("Java Arrays.sort()") { - Arrays.sort(keys, new Comparator[JFloat] { - override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y) - }) + val intObjectArray = new Array[JInteger](numElements) + val prepareIntObjectArray = () => { + System.arraycopy(intObjects, 0, intObjectArray, 0, numElements) } - // Test primitive sort on float array - val primitiveKeys = Array.tabulate[Float](numElements) { i => rand.nextFloat() } - runExperiment("Java Arrays.sort() on primitive keys") { - Arrays.sort(primitiveKeys) + runExperiment("Java Arrays.sort() on non-primitive int array")({ + Arrays.sort(intObjectArray, new Comparator[JInteger] { + override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y) + }) + }, prepareIntObjectArray) + + val intPrimitiveArray = new Array[Int](numElements) + val prepareIntPrimitiveArray = () => { + System.arraycopy(ints, 0, intPrimitiveArray, 0, numElements) } - } -} + runExperiment("Java Arrays.sort() on primitive int array")({ + Arrays.sort(intPrimitiveArray) + }, prepareIntPrimitiveArray) -/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */ -class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] { - override protected def getKey(data: Array[Int], pos: Int): Int = { - data(pos) + val sorterWithoutKeyReuse = new Sorter(new IntArraySortDataFormat) + runExperiment("Sorter without key reuse on primitive int array")({ + sorterWithoutKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[Int]) + }, prepareIntPrimitiveArray) + + val sorterWithKeyReuse = new Sorter(new KeyReuseIntArraySortDataFormat) + runExperiment("Sorter with key reuse on primitive int array")({ + sorterWithKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[IntWrapper]) + }, prepareIntPrimitiveArray) } +} - override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = { +abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array[Int]] { + + override def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = { val tmp = data(pos0) data(pos0) = data(pos1) data(pos1) = tmp } - override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) { + override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) { dst(dstPos) = src(srcPos) } /** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */ - override protected def copyRange(src: Array[Int], srcPos: Int, - dst: Array[Int], dstPos: Int, length: Int) { + override def copyRange(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int, length: Int) { System.arraycopy(src, srcPos, dst, dstPos, length) } /** Allocates a new structure that can hold up to 'length' elements. */ - override protected def allocate(length: Int): Array[Int] = { + override def allocate(length: Int): Array[Int] = { new Array[Int](length) } } + +/** Format to sort a simple Array[Int]. Could be easily generified and specialized. 
*/ +class IntArraySortDataFormat extends AbstractIntArraySortDataFormat[Int] { + + override protected def getKey(data: Array[Int], pos: Int): Int = { + data(pos) + } +} + +/** Wrapper of Int for key reuse. */ +class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { + + override def compare(that: IntWrapper): Int = { + Ordering.Int.compare(key, that.key) + } +} + +/** SortDataFormat for Array[Int] with reused keys. */ +class KeyReuseIntArraySortDataFormat extends AbstractIntArraySortDataFormat[IntWrapper] { + + override def newKey(): IntWrapper = { + new IntWrapper() + } + + override def getKey(data: Array[Int], pos: Int, reuse: IntWrapper): IntWrapper = { + if (reuse == null) { + new IntWrapper(data(pos)) + } else { + reuse.key = data(pos) + reuse + } + } + + override protected def getKey(data: Array[Int], pos: Int): IntWrapper = { + getKey(data, pos, null) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 36877476e708e..20944b62473c5 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -18,96 +18,523 @@ package org.apache.spark.util.random import java.util.Random +import scala.collection.mutable.ArrayBuffer +import org.apache.commons.math3.distribution.PoissonDistribution -import cern.jet.random.Poisson -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.scalatest.mock.EasyMockSugar - -class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { - - val a = List(1, 2, 3, 4, 5, 6, 7, 8, 9) - - var random: Random = _ - var poisson: Poisson = _ - - before { - random = mock[Random] - poisson = mock[Poisson] - } - - test("BernoulliSamplerWithRange") { - expecting { - for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { - random.nextDouble().andReturn(x) - } - } - whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55) - sampler.rng = random - assert(sampler.sample(a.iterator).toList == List(3, 4, 5)) - } - } - - test("BernoulliSamplerWithRangeInverse") { - expecting { - for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { - random.nextDouble().andReturn(x) - } - } - whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55, true) - sampler.rng = random - assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9)) - } - } - - test("BernoulliSamplerWithRatio") { - expecting { - for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { - random.nextDouble().andReturn(x) - } - } - whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.35) - sampler.rng = random - assert(sampler.sample(a.iterator).toList == List(1, 2, 3)) - } - } - - test("BernoulliSamplerWithComplement") { - expecting { - for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) { - random.nextDouble().andReturn(x) - } - } - whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.25, 0.55, true) - sampler.rng = random - assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9)) - } - } - - test("BernoulliSamplerSetSeed") { - expecting { - random.setSeed(10L) - } - whenExecuting(random) { - val sampler = new BernoulliSampler[Int](0.2) - sampler.rng = random - sampler.setSeed(10L) - } - } - - test("PoissonSampler") { - expecting { - for(x <- Seq(0, 1, 2, 0, 1, 1, 0, 0, 0)) { - poisson.nextInt().andReturn(x) - } - } - whenExecuting(poisson) { - val 
sampler = new PoissonSampler[Int](0.2) - sampler.rng = poisson - assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6)) - } +import org.scalatest.{FunSuite, Matchers} + +class RandomSamplerSuite extends FunSuite with Matchers { + /** + * My statistical testing methodology is to run a Kolmogorov-Smirnov (KS) test + * between the random samplers and simple reference samplers (known to work correctly). + * The sampling gap sizes between chosen samples should show up as having the same + * distributions between test and reference, if things are working properly. That is, + * the KS test will fail to strongly reject the null hypothesis that the distributions of + * sampling gaps are the same. + * There are no actual KS tests implemented for scala (that I can find) - and so what I + * have done here is pre-compute "D" - the KS statistic - that corresponds to a "weak" + * p-value for a particular sample size. I can then test that my measured KS stats + * are less than D. Computing D-values is easy, and implemented below. + * + * I used the scipy 'kstwobign' distribution to pre-compute my D value: + * + * def ksdval(q=0.1, n=1000): + * en = np.sqrt(float(n) / 2.0) + * return stats.kstwobign.isf(float(q)) / (en + 0.12 + 0.11 / en) + * + * When comparing KS stats I take the median of a small number of independent test runs + * to compensate for the issue that any sampled statistic will show "false positive" with + * some probability. Even when two distributions are the same, they will register as + * different 10% of the time at a p-value of 0.1 + */ + + // This D value is the precomputed KS statistic for p-value 0.1, sample size 1000: + val sampleSize = 1000 + val D = 0.0544280747619 + + // I'm not a big fan of fixing seeds, but unit testing based on running statistical tests + // will always fail with some nonzero probability, so I'll fix the seed to prevent these + // tests from generating random failure noise in CI testing, etc. + val rngSeed: Random = RandomSampler.newDefaultRNG + rngSeed.setSeed(235711) + + // Reference implementation of sampling without replacement (bernoulli) + def sample[T](data: Iterator[T], f: Double): Iterator[T] = { + val rng: Random = RandomSampler.newDefaultRNG + rng.setSeed(rngSeed.nextLong) + data.filter(_ => (rng.nextDouble <= f)) + } + + // Reference implementation of sampling with replacement + def sampleWR[T](data: Iterator[T], f: Double): Iterator[T] = { + val rng = new PoissonDistribution(f) + rng.reseedRandomGenerator(rngSeed.nextLong) + data.flatMap { v => { + val rep = rng.sample() + if (rep == 0) Iterator.empty else Iterator.fill(rep)(v) + }} + } + + // Returns iterator over gap lengths between samples. + // This function assumes input data is integers sampled from the sequence of + // increasing integers: {0, 1, 2, ...}. 
This works because that is how I generate them, + // and the samplers preserve their input order + def gaps(data: Iterator[Int]): Iterator[Int] = { + data.sliding(2).withPartial(false).map { x => x(1) - x(0) } + } + + // Returns the cumulative distribution from a histogram + def cumulativeDist(hist: Array[Int]): Array[Double] = { + val n = hist.sum.toDouble + assert(n > 0.0) + hist.scanLeft(0)(_ + _).drop(1).map { _.toDouble / n } + } + + // Returns aligned cumulative distributions from two arrays of data + def cumulants(d1: Array[Int], d2: Array[Int], + ss: Int = sampleSize): (Array[Double], Array[Double]) = { + assert(math.min(d1.length, d2.length) > 0) + assert(math.min(d1.min, d2.min) >= 0) + val m = 1 + math.max(d1.max, d2.max) + val h1 = Array.fill[Int](m)(0) + val h2 = Array.fill[Int](m)(0) + for (v <- d1) { h1(v) += 1 } + for (v <- d2) { h2(v) += 1 } + assert(h1.sum == h2.sum) + assert(h1.sum == ss) + (cumulativeDist(h1), cumulativeDist(h2)) + } + + // Computes the Kolmogorov-Smirnov 'D' statistic from two cumulative distributions + def KSD(cdf1: Array[Double], cdf2: Array[Double]): Double = { + assert(cdf1.length == cdf2.length) + val n = cdf1.length + assert(n > 0) + assert(cdf1(n-1) == 1.0) + assert(cdf2(n-1) == 1.0) + cdf1.zip(cdf2).map { x => Math.abs(x._1 - x._2) }.max + } + + // Returns the median KS 'D' statistic between two samples, over (m) sampling trials + def medianKSD(data1: => Iterator[Int], data2: => Iterator[Int], m: Int = 5): Double = { + val t = Array.fill[Double](m) { + val (c1, c2) = cumulants(data1.take(sampleSize).toArray, + data2.take(sampleSize).toArray) + KSD(c1, c2) + }.sorted + // return the median KS statistic + t(m / 2) + } + + test("utilities") { + val s1 = Array(0, 1, 1, 0, 2) + val s2 = Array(1, 0, 3, 2, 1) + val (c1, c2) = cumulants(s1, s2, ss = 5) + c1 should be (Array(0.4, 0.8, 1.0, 1.0)) + c2 should be (Array(0.2, 0.6, 0.8, 1.0)) + KSD(c1, c2) should be (0.2 +- 0.000001) + KSD(c2, c1) should be (KSD(c1, c2)) + gaps(List(0, 1, 1, 2, 4, 11).iterator).toArray should be (Array(1, 0, 1, 2, 7)) + } + + test("sanity check medianKSD against references") { + var d: Double = 0.0 + + // should be statistically same, i.e. fail to reject null hypothesis strongly + d = medianKSD(gaps(sample(Iterator.from(0), 0.5)), gaps(sample(Iterator.from(0), 0.5))) + d should be < D + + // should be statistically different - null hypothesis will have high D value, + // corresponding to low p-value that rejects the null hypothesis + d = medianKSD(gaps(sample(Iterator.from(0), 0.4)), gaps(sample(Iterator.from(0), 0.5))) + d should be > D + + // same! + d = medianKSD(gaps(sampleWR(Iterator.from(0), 0.5)), gaps(sampleWR(Iterator.from(0), 0.5))) + d should be < D + + // different! 
+ d = medianKSD(gaps(sampleWR(Iterator.from(0), 0.5)), gaps(sampleWR(Iterator.from(0), 0.6))) + d should be > D + } + + test("bernoulli sampling") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.5))) + d should be < D + + sampler = new BernoulliSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.7))) + d should be < D + + sampler = new BernoulliSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.6))) + d should be > D + } + + test("bernoulli sampling with gap sampling optimization") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.01))) + d should be < D + + sampler = new BernoulliSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.4))) + d should be > D + } + + test("bernoulli boundary cases") { + val data = (1 to 100).toArray + + var sampler = new BernoulliSampler[Int](0.0) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0) + sampler.sample(data.iterator).toArray should be (data) + + sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 2.0)) + sampler.sample(data.iterator).toArray should be (data) + } + + test("bernoulli data types") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + var sampler = new BernoulliSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + + // Array iterator (indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toArray.iterator)), + gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + // ArrayBuffer iterator (indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).to[ArrayBuffer].iterator)), + gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + // List 
iterator (non-indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toList.iterator)), + gaps(sample(Iterator.from(0), 0.1))) + d should be < D + } + + test("bernoulli clone") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d = 0.0 + var sampler = new BernoulliSampler[Int](0.1).clone + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliSampler[Int](0.9).clone + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + } + + test("bernoulli set seed") { + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + var sampler1 = new BernoulliSampler[Int](0.2) + var sampler2 = new BernoulliSampler[Int](0.2) + + // distributions should be identical if seeds are set same + sampler1.setSeed(73) + sampler2.setSeed(73) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be (0.0) + + // should be different for different seeds + sampler1.setSeed(73) + sampler2.setSeed(37) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be > 0.0 + d should be < D + + sampler1 = new BernoulliSampler[Int](0.8) + sampler2 = new BernoulliSampler[Int](0.8) + + // distributions should be identical if seeds are set same + sampler1.setSeed(73) + sampler2.setSeed(73) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be (0.0) + + // should be different for different seeds + sampler1.setSeed(73) + sampler2.setSeed(37) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be > 0.0 + d should be < D + } + + test("replacement sampling") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + var sampler = new PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.5))) + d should be < D + + sampler = new PoissonSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.7))) + d should be < D + + sampler = new PoissonSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.6))) + d should be > D + } + + test("replacement sampling with gap sampling") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + var sampler = new PoissonSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.01))) + d should be < D + + sampler = new PoissonSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = 
medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.4))) + d should be > D + } + + test("replacement boundary cases") { + val data = (1 to 100).toArray + + var sampler = new PoissonSampler[Int](0.0) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + // sampling with replacement has no upper bound on sampling fraction + sampler = new PoissonSampler[Int](2.0) + sampler.sample(data.iterator).length should be > (data.length) + } + + test("replacement data types") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + var sampler = new PoissonSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + + // Array iterator (indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toArray.iterator)), + gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + // ArrayBuffer iterator (indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).to[ArrayBuffer].iterator)), + gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + // List iterator (non-indexable type) + d = medianKSD( + gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toList.iterator)), + gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + } + + test("replacement clone") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d = 0.0 + var sampler = new PoissonSampler[Int](0.1).clone + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + sampler = new PoissonSampler[Int](0.9).clone + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.9))) + d should be < D + } + + test("replacement set seed") { + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + var sampler1 = new PoissonSampler[Int](0.2) + var sampler2 = new PoissonSampler[Int](0.2) + + // distributions should be identical if seeds are set same + sampler1.setSeed(73) + sampler2.setSeed(73) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be (0.0) + + // should be different for different seeds + sampler1.setSeed(73) + sampler2.setSeed(37) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be > 0.0 + d should be < D + + sampler1 = new PoissonSampler[Int](0.8) + sampler2 = new PoissonSampler[Int](0.8) + + // distributions should be identical if seeds are set same + sampler1.setSeed(73) + sampler2.setSeed(73) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be (0.0) + + // should be 
different for different seeds + sampler1.setSeed(73) + sampler2.setSeed(37) + d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0)))) + d should be > 0.0 + d should be < D + } + + test("bernoulli partitioning sampling") { + var d: Double = 0.0 + + var sampler = new BernoulliCellSampler[Int](0.1, 0.2) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliCellSampler[Int](0.1, 0.2, true) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + } + + test("bernoulli partitioning boundary cases") { + val data = (1 to 100).toArray + val d = RandomSampler.roundingEpsilon / 2.0 + + var sampler = new BernoulliCellSampler[Int](0.0, 0.0) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](1.0, 1.0) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.0, 1.0) + sampler.sample(data.iterator).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d) + sampler.sample(data.iterator).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + } + + test("bernoulli partitioning data") { + val seed = rngSeed.nextLong + val data = (1 to 100).toArray + + var sampler = new BernoulliCellSampler[Int](0.4, 0.6) + sampler.setSeed(seed) + val s1 = sampler.sample(data.iterator).toArray + s1.length should be > 0 + + sampler = new BernoulliCellSampler[Int](0.4, 0.6, true) + sampler.setSeed(seed) + val s2 = sampler.sample(data.iterator).toArray + s2.length should be > 0 + + (s1 ++ s2).sorted should be (data) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5) + sampler.sample(data.iterator).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5, true) + sampler.sample(data.iterator).toArray should be (data) + } + + test("bernoulli partitioning clone") { + val seed = rngSeed.nextLong + val data = (1 to 100).toArray + val base = new BernoulliCellSampler[Int](0.35, 0.65) + + var sampler = base.clone + sampler.setSeed(seed) + val s1 = sampler.sample(data.iterator).toArray + s1.length should be > 0 + + sampler = base.cloneComplement + sampler.setSeed(seed) + val s2 = sampler.sample(data.iterator).toArray + s2.length should be > 0 + + (s1 ++ s2).sorted should be (data) } } diff --git a/dev/run-tests b/dev/run-tests index f55497ae2bfbd..de607e4344453 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -142,17 +142,24 @@ CURRENT_BLOCK=$BLOCK_BUILD # We always build with Hive because the PySpark Spark SQL tests need it. BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0" - echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS" # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc #+ to quit or retry. This echo is there to make it not block. - # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a + # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a #+ single argument! 
# QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + # First build with 0.12 to ensure patches do not break the hive 12 build + echo "[info] Compile with hive 0.12" echo -e "q\n" \ - | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \ + | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + + # Then build with default version(0.13.1) because tests are based on this version + echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive" + echo -e "q\n" \ + | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } @@ -173,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi diff --git a/dev/scalastyle b/dev/scalastyle index c3b356bcb3c06..ed1b6b730af6e 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -25,7 +25,7 @@ echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn- echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ >> scalastyle.txt -ERRORS=$(cat scalastyle.txt | grep -e "\") +ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') rm scalastyle.txt if test ! -z "$ERRORS"; then diff --git a/docs/_config.yml b/docs/_config.yml index f4bf242ac191b..cdea02fcffbc5 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -11,10 +11,10 @@ kramdown: include: - _static -# These allow the documentation to be updated with nerw releases +# These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.0.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.0.0 +SPARK_VERSION: 1.2.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.2.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.18.1 diff --git a/docs/building-spark.md b/docs/building-spark.md index 11fd56c145c01..238ddae15545e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -67,11 +67,13 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH, and other Hadoop versions with YARN
[Displaced, tag-stripped residue of the Spark UI diffs (ExecutorsPage.scala and StagePage.scala). The recoverable content: the executors page gains a "Thread Dump" column, and the stage page gains per-task columns for GC Time, Result Serialization Time, Getting Result Time, Scheduler Delay, accumulators, Input, Shuffle Read (Remote), Shuffle Write, and Shuffle spill (memory/disk); the surrounding <td>/<th> markup was lost.]
 <tr><th>YARN version</th><th>Profile required</th></tr>
-<tr><td>0.23.x to 2.1.x</td><td>yarn-alpha</td></tr>
+<tr><td>0.23.x to 2.1.x</td><td>yarn-alpha (Deprecated.)</td></tr>
 <tr><td>2.2.x and later</td><td>yarn</td></tr>
+Note: Support for YARN-alpha API's will be removed in Spark 1.3 (see SPARK-3445). + Examples: {% highlight bash %} @@ -97,14 +99,11 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} - - # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` profile to your existing build options. By default Spark will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using -the `-Phive-0.12.0` profile. NOTE: currently the JDBC server is only -supported for Hive 0.12.0. +the `-Phive-0.12.0` profile. {% highlight bash %} # Apache Hadoop 2.4.X with Hive 13 support mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package @@ -119,8 +118,8 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-0.12.0 clean package - mvn -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package + mvn -Pyarn -Phadoop-2.3 -Phive test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -183,16 +182,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-0.12.0 core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test # Speeding up Compilation with Zinc diff --git a/docs/configuration.md b/docs/configuration.md index 66738d3ca754e..685101ea5c9c9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -111,6 +111,18 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.maxResultSize + 1g + + Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). + Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size + is above this limit. + Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory + and memory overhead of objects in JVM). Setting a proper limit can protect the driver from + out-of-memory errors. + + spark.serializer org.apache.spark.serializer.
JavaSerializer @@ -359,6 +371,16 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. + + spark.shuffle.blockTransferService + netty + + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2. + + #### Spark UI @@ -375,7 +397,16 @@ Apart from these, the following properties are also available, and may be useful spark.ui.retainedStages 1000 - How many stages the Spark UI remembers before garbage collecting. + How many stages the Spark UI and status APIs remember before garbage + collecting. + + + + spark.ui.retainedJobs + 1000 + + How many jobs the Spark UI and status APIs remember before garbage + collecting. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 7978e934fb36b..c696ae9c8e8c8 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. -## Examples +### Examples
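As a quick illustration of the configuration properties introduced above (`spark.driver.maxResultSize`, `spark.shuffle.blockTransferService`, `spark.ui.retainedJobs`), here is a minimal sketch of setting them programmatically; the property names come from the tables above, while the app name and values are illustrative only:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative values only; the documented defaults are 1g, netty, and 1000.
val conf = new SparkConf()
  .setAppName("ConfigSketch")
  .set("spark.driver.maxResultSize", "2g")             // cap on serialized results collected to the driver
  .set("spark.shuffle.blockTransferService", "netty")  // or "nio"
  .set("spark.ui.retainedJobs", "500")                 // jobs kept by the UI and status APIs
  .set("spark.ui.retainedStages", "500")               // stages kept by the UI and status APIs
val sc = new SparkContext(conf)
{% endhighlight %}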
@@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap section of the Spark Quick Start guide. Be sure to also include *spark-mllib* to your build file as a dependency. + +## Streaming clustering + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +all points to their nearest cluster, compute new cluster centers, then update each cluster using: + +`\begin{equation} + c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} +\end{equation}` +`\begin{equation} + n_{t+1} = n_t + m_t +\end{equation}` + +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. + +The decay can be specified using a `halfLife` parameter, which determines the +correct decay factor `a` such that, for data acquired +at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. +The unit of time can be specified either as `batches` or `points` and the update rule +will be adjusted accordingly. + +### Examples + +This example shows how to estimate clusters on streaming data. + +
+ +
+ +First we import the necessary classes. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans + +{% endhighlight %} + +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created; see the +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight scala %} + +val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) + +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find. + +{% highlight scala %} + +val numDimensions = 3 +val numClusters = 2 +val model = new StreamingKMeans() + .setK(numClusters) + .setDecayFactor(1.0) + .setRandomCenters(numDimensions, 0.0) + +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +As you add new text files with data, the cluster centers will update. Each training +point should be formatted as `[x1, x2, x3]`, and each test data point +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` +you will see predictions. With new data, the cluster centers will change! +
+ +
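To make the streaming k-means update rule above concrete, here is a minimal, self-contained sketch of the per-cluster update with a decay factor `a`; it mirrors the equations for `c_{t+1}` and `n_{t+1}` but is illustrative only and not MLlib's implementation:

{% highlight scala %}
// c_{t+1} = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t),   n_{t+1} = n_t * a + m_t
def updateCluster(
    cT: Array[Double],   // previous cluster center c_t
    nT: Double,          // points assigned to the cluster so far, n_t
    xT: Array[Double],   // centroid of the current batch, x_t
    mT: Double,          // points assigned in the current batch, m_t
    a: Double            // decay factor
  ): (Array[Double], Double) = {
  val newWeight = nT * a + mT
  val newCenter = cT.zip(xT).map { case (c, x) => (c * nT * a + x * mT) / newWeight }
  (newCenter, newWeight)
}

// With a = 1 all history is kept; with a = 0 only the newest batch matters.
val (c1, n1) = updateCluster(Array(0.0, 0.0), 10.0, Array(1.0, 1.0), 5.0, 1.0)
{% endhighlight %}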
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 11622414494e4..197bc77d506c6 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -95,8 +95,49 @@ tf.cache() val idf = new IDF(minDocFreq = 2).fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} +
+
+ +TF and IDF are implemented in [HashingTF](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF) +and [IDF](api/python/pyspark.mllib.html#pyspark.mllib.feature.IDF). +`HashingTF` takes an RDD of lists as input. +Each record could be an iterable of strings or other types. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.feature import HashingTF + +sc = SparkContext() + +# Load documents (one per line). +documents = sc.textFile("...").map(lambda line: line.split(" ")) + +hashingTF = HashingTF() +tf = hashingTF.transform(documents) +{% endhighlight %} + +While applying `HashingTF` only needs a single pass over the data, applying `IDF` needs two passes: +first to compute the IDF vector and second to scale the term frequencies by IDF. + +{% highlight python %} +from pyspark.mllib.feature import IDF + +# ... continue from the previous example +tf.cache() +idf = IDF().fit(tf) +tfidf = idf.transform(tf) +{% endhighlight %} +MLlib's IDF implementation provides an option for ignoring terms that occur in fewer than a +minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature +can be used by passing the `minDocFreq` value to the IDF constructor. +{% highlight python %} +# ... continue from the previous example +tf.cache() +idf = IDF(minDocFreq=2).fit(tf) +tfidf = idf.transform(tf) +{% endhighlight %}
@@ -162,6 +203,23 @@ for((synonym, cosineSimilarity) <- synonyms) { } {% endhighlight %} +
+{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +sc = SparkContext(appName='Word2Vec') +inp = sc.textFile("text8_lines").map(lambda row: row.split(" ")) + +word2vec = Word2Vec() +model = word2vec.fit(inp) + +synonyms = model.findSynonyms('china', 40) + +for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) +{% endhighlight %} +
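The Python `Word2Vec` example above looks up synonyms word by word; a rough Scala sketch of applying a fitted model to a whole RDD of words (analogous to the RDD-level `transform` added to the Python API in this patch) might look like the following. It assumes an existing `SparkContext` named `sc`, reuses the `text8_lines` file from the example above, and simply drops out-of-vocabulary words:

{% highlight scala %}
import scala.util.Try
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val sentences = sc.textFile("text8_lines").map(_.split(" ").toSeq)
val model: Word2VecModel = new Word2Vec().fit(sentences)

// Word2VecModel.transform throws for words outside the vocabulary, so guard with Try.
val words: RDD[String] = sentences.flatMap(identity)
val vectors: RDD[Vector] = words.flatMap(w => Try(model.transform(w)).toOption)
{% endhighlight %}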
## StandardScaler @@ -223,6 +281,29 @@ val data1 = data.map(x => (x.label, scaler1.transform(x.features))) val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray)))) {% endhighlight %} + +
+{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import StandardScaler + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +label = data.map(lambda x: x.label) +features = data.map(lambda x: x.features) + +scaler1 = StandardScaler().fit(features) +scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) + +# data1 will be unit variance. +data1 = label.zip(scaler1.transform(features)) + +# Without converting the features into dense vectors, transformation with zero mean will raise +# an exception on sparse vectors. +# data2 will be unit variance and zero mean. +data2 = label.zip(scaler2.transform(features.map(lambda x: Vectors.dense(x.toArray())))) +{% endhighlight %} +
## Normalizer @@ -267,4 +348,25 @@ val data1 = data.map(x => (x.label, normalizer1.transform(x.features))) val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) {% endhighlight %} + +
+{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import Normalizer + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +labels = data.map(lambda x: x.label) +features = data.map(lambda x: x.features) + +normalizer1 = Normalizer() +normalizer2 = Normalizer(p=float("inf")) + +# Each sample in data1 will be normalized using $L^2$ norm. +data1 = labels.zip(normalizer1.transform(features)) + +# Each sample in data2 will be normalized using $L^\infty$ norm. +data2 = labels.zip(normalizer2.transform(features)) +{% endhighlight %} +
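As a closing illustration for this feature-extraction guide, here is a minimal sketch of chaining the transformers documented above (`HashingTF`, `IDF`, and `Normalizer`) into a single TF-IDF pipeline. It assumes an existing `SparkContext` named `sc`; the input path is a placeholder and the corpus is simply whitespace-tokenized:

{% highlight scala %}
import org.apache.spark.mllib.feature.{HashingTF, IDF, Normalizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Load documents (one per line) and tokenize on whitespace.
val docs: RDD[Seq[String]] = sc.textFile("data/docs.txt").map(_.split(" ").toSeq)

val tf: RDD[Vector] = new HashingTF().transform(docs)
tf.cache()  // IDF makes two passes over the term-frequency vectors

val tfidf: RDD[Vector] = new IDF(minDocFreq = 2).fit(tf).transform(tf)

// L^2-normalize each TF-IDF vector so dot products become cosine similarities.
val normalized: RDD[Vector] = new Normalizer().transform(tfidf)
{% endhighlight %}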
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 368c3d0008b07..e399fecbbc78c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently @@ -1215,7 +1223,7 @@ import org.apache.spark.sql._ DecimalType - scala.math.sql.BigDecimal + scala.math.BigDecimal DecimalType diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0d6b82b4944f3..50f88f735650e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -41,8 +41,9 @@ DEFAULT_SPARK_VERSION = "1.1.0" +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -583,7 +584,13 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) diff --git a/examples/pom.xml b/examples/pom.xml index eb49a0e5af22d..bc3291803c324 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -156,6 +156,10 @@ algebird-core_${scala.binary.version} 0.1.11
+ + org.apache.commons + commons-math3 + org.scalatest scalatest_${scala.binary.version} @@ -268,6 +272,10 @@ com.google.common.base.Optional** + + org.apache.commons.math3 + org.spark-project.commons.math3 + diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java new file mode 100644 index 0000000000000..430e96ab14d9d --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkJobInfo; +import org.apache.spark.SparkStageInfo; +import org.apache.spark.api.java.JavaFutureAction; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import java.util.Arrays; +import java.util.List; + +/** + * Example of using Spark's status APIs from Java. + */ +public final class JavaStatusAPIDemo { + + public static final String APP_NAME = "JavaStatusAPIDemo"; + + public static final class IdentityWithDelay implements Function { + @Override + public T call(T x) throws Exception { + Thread.sleep(2 * 1000); // 2 seconds + return x; + } + } + + public static void main(String[] args) throws Exception { + SparkConf sparkConf = new SparkConf().setAppName(APP_NAME); + final JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // Example of implementing a progress reporter for a simple job. + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( + new IdentityWithDelay()); + JavaFutureAction> jobFuture = rdd.collectAsync(); + while (!jobFuture.isDone()) { + Thread.sleep(1000); // 1 second + List jobIds = jobFuture.jobIds(); + if (jobIds.isEmpty()) { + continue; + } + int currentJobId = jobIds.get(jobIds.size() - 1); + SparkJobInfo jobInfo = sc.getJobInfo(currentJobId); + SparkStageInfo stageInfo = sc.getStageInfo(jobInfo.stageIds()[0]); + System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + + " active, " + stageInfo.numCompletedTasks() + " complete"); + } + + System.out.println("Job results are: " + jobFuture.get()); + sc.stop(); + } +} diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." + dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py new file mode 100644 index 0000000000000..99fef4276a369 --- /dev/null +++ b/examples/src/main/python/mllib/word2vec.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# This example uses text8 file from http://mattmahoney.net/dc/text8.zip +# The file was downloadded, unziped and split into multiple lines using +# +# wget http://mattmahoney.net/dc/text8.zip +# unzip text8.zip +# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines +# This was done so that the example can be run in local mode + + +import sys + +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +USAGE = ("bin/spark-submit --driver-memory 4g " + "examples/src/main/python/mllib/word2vec.py text8_lines") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print USAGE + sys.exit("Argument for file not provided") + file_path = sys.argv[1] + sc = SparkContext(appName='Word2Vec') + inp = sc.textFile(file_path).map(lambda row: row.split(" ")) + + word2vec = Word2Vec() + model = word2vec.fit(inp) + + synonyms = model.findSynonyms('china', 40) + + for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) + sc.stop() diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index 40faff0ccc7db..f7ffb5379681e 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -21,7 +21,7 @@ is the directory that Spark Streaming will use to find and read new text files. To run this on your local machine on directory `localdir`, run this example - $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir Then create a text file in `localdir` and the words in the file will get counted. """ diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 1f576319b3ca8..3d5259463003d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -17,11 +17,7 @@ package org.apache.spark.examples -import scala.math.sqrt - -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import cern.jet.math._ +import org.apache.commons.math3.linear._ /** * Alternating least squares matrix factorization. 
@@ -30,84 +26,70 @@ import cern.jet.math._ * please refer to org.apache.spark.mllib.recommendation.ALS */ object LocalALS { + // Parameters set through command line arguments var M = 0 // Number of movies var U = 0 // Number of users var F = 0 // Number of features var ITERATIONS = 0 - val LAMBDA = 0.01 // Regularization coefficient - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - algebra.mult(mh, algebra.transpose(uh)) + def generateR(): RealMatrix = { + val mh = randomMatrix(M, F) + val uh = randomMatrix(U, F) + mh.multiply(uh.transpose()) } - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) + def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = { + val r = new Array2DRowRealMatrix(M, U) for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) + r.setEntry(i, j, ms(i).dotProduct(us(j))) } - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - sqrt(sumSqs / (M * U)) + val diffs = r.subtract(targetR) + var sumSqs = 0.0 + for (i <- 0 until M; j <- 0 until U) { + val diff = diffs.getEntry(i, j) + sumSqs += diff * diff + } + math.sqrt(sumSqs / (M.toDouble * U.toDouble)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + def updateMovie(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = { + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new ArrayRealVector(F) // For each user that rated the movie for (j <- 0 until U) { val u = us(j) // Add u * u^t to XtX - blas.dger(1, u, u, XtX) + XtX = XtX.add(u.outerProduct(u)) // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) + Xty = Xty.add(u.mapMultiply(R.getEntry(i, j))) } - // Add regularization coefs to diagonal terms + // Add regularization coefficients to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + XtX.addToEntry(d, d, LAMBDA * U) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + def updateUser(j: Int, u: RealVector, ms: Array[RealVector], R: RealMatrix) : RealVector = { + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new ArrayRealVector(F) // For each movie that the user rated for (i <- 0 until M) { val m = ms(i) // Add m * m^t to XtX - blas.dger(1, m, m, XtX) + XtX = XtX.add(m.outerProduct(m)) // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) + Xty = Xty.add(m.mapMultiply(R.getEntry(i, j))) } - // Add regularization coefs to diagonal terms + // Add regularization coefficients to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + XtX.addToEntry(d, d, LAMBDA * M) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = 
ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } def showWarning() { @@ -135,21 +117,28 @@ object LocalALS { showWarning() - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") val R = generateR() // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) + var ms = Array.fill(M)(randomVector(F)) + var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") + println(s"Iteration $iter:") ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray println("RMSE = " + rmse(R, ms, us)) println() } } + + private def randomVector(n: Int): RealVector = + new ArrayRealVector(Array.fill(n)(math.random)) + + private def randomMatrix(rows: Int, cols: Int): RealMatrix = + new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) + } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index fde8ffeedf8b4..6c0ac8013ce34 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -17,11 +17,7 @@ package org.apache.spark.examples -import scala.math.sqrt - -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import cern.jet.math._ +import org.apache.commons.math3.linear._ import org.apache.spark._ @@ -32,62 +28,53 @@ import org.apache.spark._ * please refer to org.apache.spark.mllib.recommendation.ALS */ object SparkALS { + // Parameters set through command line arguments var M = 0 // Number of movies var U = 0 // Number of users var F = 0 // Number of features var ITERATIONS = 0 - val LAMBDA = 0.01 // Regularization coefficient - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - algebra.mult(mh, algebra.transpose(uh)) + def generateR(): RealMatrix = { + val mh = randomMatrix(M, F) + val uh = randomMatrix(U, F) + mh.multiply(uh.transpose()) } - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) + def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = { + val r = new Array2DRowRealMatrix(M, U) for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) + r.setEntry(i, j, ms(i).dotProduct(us(j))) } - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - sqrt(sumSqs / (M * U)) + val diffs = r.subtract(targetR) + var sumSqs = 0.0 + for (i <- 0 until M; j <- 0 until U) { + val diff = diffs.getEntry(i, j) + sumSqs += diff * diff + } + math.sqrt(sumSqs / (M.toDouble * U.toDouble)) } - def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { + def update(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = { val U = us.size - val F = us(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + val F = us(0).getDimension + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new 
ArrayRealVector(F) // For each user that rated the movie for (j <- 0 until U) { val u = us(j) // Add u * u^t to XtX - blas.dger(1, u, u, XtX) + XtX = XtX.add(u.outerProduct(u)) // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) + Xty = Xty.add(u.mapMultiply(R.getEntry(i, j))) } // Add regularization coefs to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + XtX.addToEntry(d, d, LAMBDA * U) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } def showWarning() { @@ -118,7 +105,7 @@ object SparkALS { showWarning() - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") val sparkConf = new SparkConf().setAppName("SparkALS") val sc = new SparkContext(sparkConf) @@ -126,21 +113,21 @@ object SparkALS { val R = generateR() // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) + var ms = Array.fill(M)(randomVector(F)) + var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") + println(s"Iteration $iter:") ms = sc.parallelize(0 until M, slices) .map(i => update(i, msb.value(i), usb.value, Rc.value)) .collect() msb = sc.broadcast(ms) // Re-broadcast ms because it was updated us = sc.parallelize(0 until U, slices) - .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) + .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose())) .collect() usb = sc.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) @@ -149,4 +136,11 @@ object SparkALS { sc.stop() } + + private def randomVector(n: Int): RealVector = + new ArrayRealVector(Array.fill(n)(math.random)) + + private def randomMatrix(rows: Int, cols: Int): RealMatrix = + new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) + } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index d70d93608a57c..828cffb01ca1e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -77,7 +77,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -110,7 +110,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -131,7 +131,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) val graph = 
GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) // TriangleCount requires the graph to be partitioned diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. 
+ val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. + val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0890e6263e165..49751a30491d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -62,7 +62,10 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) extends AbstractParams[Params] + fracTest: Double = 0.2, + useNodeIdCache: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -102,6 +105,21 @@ object DecisionTreeRunner { .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. 
default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("useNodeIdCache") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.useNodeIdCache}") + .action((x, c) => c.copy(useNodeIdCache = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") @@ -236,7 +254,10 @@ object DecisionTreeRunner { maxBins = params.maxBins, numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, - minInfoGain = params.minInfoGain) + minInfoGain = params.minInfoGain, + useNodeIdCache = params.useNodeIdCache, + checkpointDir = params.checkpointDir, + checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) @@ -317,7 +338,7 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = tree.predict(y.features) - y.label err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala new file mode 100644 index 0000000000000..33e5760aed997 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Estimate clusters on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the training text files must be vector data in the form + * `[x1,x2,x3,...,xn]` + * Where n is the number of dimensions. 
+ * + * The rows of the test text files must be labeled data in the form + * `(y,[x1,x2,x3,...,xn])` + * Where y is some identifier. n must be the same for train and test. + * + * Usage: StreamingKmeans + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2 + * + * As you add text files to `trainingDir` the clusters will continuously update. + * Anytime you add text files to `testDir`, you'll see predicted labels using the current model. + * + */ +object StreamingKMeans { + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println( + "Usage: StreamingKMeans " + + " ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingKMeans() + .setK(args(3).toInt) + .setDecayFactor(1.0) + .setRandomCenters(args(4).toInt, 0.0) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 32a19787a28e1..475026e8eb140 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -145,11 +145,16 @@ class FlumePollingStreamSuite extends TestSuiteBase { outputStream.register() ssc.start() - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - sink.stop() - channel.stop() + try { + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + } finally { + sink.stop() + sink2.stop() + channel.stop() + channel2.stop() + } } def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 5bcb96b136ed7..5267560b3e5ce 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -82,12 +82,17 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( this } - /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ override def cache(): this.type = { partitionsRDD.persist(targetStorageLevel) this } + /** The number of edges in the RDD. 
*/ + override def count(): Long = { + partitionsRDD.map(_._2.size.toLong).reduce(_ + _) + } + private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index f4c79365b16da..4933aecba1286 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -48,7 +48,8 @@ object GraphLoader extends Logging { * @param path the path to the file (e.g., /home/data/file or hdfs://file) * @param canonicalOrientation whether to orient edges in the positive * direction - * @param minEdgePartitions the number of partitions for the edge RDD + * @param numEdgePartitions the number of partitions for the edge RDD + * Setting this value to -1 will use the default parallelism. * @param edgeStorageLevel the desired storage level for the edge partitions * @param vertexStorageLevel the desired storage level for the vertex partitions */ @@ -56,7 +57,7 @@ object GraphLoader extends Logging { sc: SparkContext, path: String, canonicalOrientation: Boolean = false, - minEdgePartitions: Int = 1, + numEdgePartitions: Int = -1, edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY, vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) : Graph[Int, Int] = @@ -64,7 +65,12 @@ object GraphLoader extends Logging { val startTime = System.currentTimeMillis // Parse the edge data table directly into edge partitions - val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions) + val lines = + if (numEdgePartitions > 0) { + sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions) + } else { + sc.textFile(path) + } val edges = lines.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[Int, Int] iter.foreach { line => diff --git a/mllib/pom.xml b/mllib/pom.xml index 696e9396f627c..87a7ddaba97f2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server @@ -71,6 +76,10 @@ + + org.apache.commons + commons-math3 + org.scalatest scalatest_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b478c21537c2a..65b98a8ceea55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.existentials @@ -31,8 +31,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.feature.Word2Vec -import org.apache.spark.mllib.feature.Word2VecModel +import org.apache.spark.mllib.feature._ import org.apache.spark.mllib.optimization._ 
import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -73,15 +72,11 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], - initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] + initialWeights: Vector): JList[Object] = { // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. learner.disableUncachedWarning() val model = learner.run(data.rdd, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.dumps(model.weights)) - ret.add(model.intercept: java.lang.Double) - ret + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava } /** @@ -92,10 +87,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) lrAlg.optimizer @@ -114,7 +109,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lrAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -126,7 +121,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.optimizer .setNumIterations(numIterations) @@ -136,7 +131,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lassoAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -148,7 +143,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.optimizer .setNumIterations(numIterations) @@ -158,7 +153,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( ridgeAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -170,9 +165,9 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) SVMAlg.optimizer @@ -191,7 +186,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( SVMAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -202,10 +197,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) LogRegAlg.optimizer @@ -224,7 +219,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( LogRegAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -232,13 +227,10 @@ class PythonMLLibAPI extends 
Serializable { */ def trainNaiveBayes( data: JavaRDD[LabeledPoint], - lambda: Double): java.util.List[java.lang.Object] = { + lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(Vectors.dense(model.labels)) - ret.add(Vectors.dense(model.pi)) - ret.add(model.theta) - ret + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + map(_.asInstanceOf[Object]).asJava } /** @@ -260,6 +252,21 @@ class PythonMLLibAPI extends Serializable { return kMeansAlg.run(data.rdd) } + /** + * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + */ + private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + + def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + + } + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -272,7 +279,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - ALS.train(ratings.rdd, rank, iterations, lambda, blocks) + new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) } /** @@ -288,7 +295,45 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) + new MatrixFactorizationModelWrapper( + ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) + } + + /** + * Java stub for Normalizer.transform() + */ + def normalizeVector(p: Double, vector: Vector): Vector = { + new Normalizer(p).transform(vector) + } + + /** + * Java stub for Normalizer.transform() + */ + def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { + new Normalizer(p).transform(rdd) + } + + /** + * Java stub for IDF.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitStandardScaler( + withMean: Boolean, + withStd: Boolean, + data: JavaRDD[Vector]): StandardScalerModel = { + new StandardScaler(withMean, withStd).fit(data.rdd) + } + + /** + * Java stub for IDF.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. 
+ */ + def fitIDF(minDocFreq: Int, dataset: JavaRDD[Vector]): IDFModel = { + new IDF(minDocFreq).fit(dataset) } /** @@ -328,19 +373,25 @@ class PythonMLLibAPI extends Serializable { model.transform(word) } - def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { val vec = transform(word) findSynonyms(vec, num) } - def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + def findSynonyms(vector: Vector, num: Int): JList[Object] = { val result = model.findSynonyms(vector, num) val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(words) - ret.add(similarity) - ret + List(words, similarity).map(_.asInstanceOf[Object]).asJava } } @@ -350,13 +401,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. * @param data Training data - * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + * @param categoricalFeaturesInfo Categorical features info, as Java map */ def trainDecisionTreeModel( data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + categoricalFeaturesInfo: JMap[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int, @@ -372,7 +423,7 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) @@ -544,7 +595,7 @@ private[spark] object SerDe extends Serializable { if (objects.length == 0 || objects.length > 3) { out.write(Opcodes.MARK) } - objects.foreach(pickler.save(_)) + objects.foreach(pickler.save) val code = objects.length match { case 1 => Opcodes.TUPLE1 case 2 => Opcodes.TUPLE2 @@ -674,7 +725,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } @@ -685,7 +736,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala new file mode 100644 index 0000000000000..6189dce9b27da --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * :: DeveloperApi :: + * StreamingKMeansModel extends MLlib's KMeansModel for streaming + * algorithms, so it can keep track of a continuously updated weight + * associated with each cluster, and also update the model by + * doing a single iteration of the standard k-means algorithm. + * + * The update algorithm uses the "mini-batch" KMeans rule, + * generalized to incorporate forgetfulness (i.e. decay). + * The update rule (for each cluster) is: + * + * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] + * n_t+1 = n_t * a + m_t + * + * Where c_t is the previously estimated centroid for that cluster, + * n_t is the number of points assigned to it thus far, x_t is the centroid + * estimated on the current batch, and m_t is the number of points assigned + * to that centroid in the current batch. + * + * The decay factor 'a' scales the contribution of the clusters as estimated thus far, + * by applying a as a discount weighting on the current point when evaluating + * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids + * are determined entirely by recent data. Lower values correspond to + * more forgetting. + * + * Decay can optionally be specified by a half life and associated + * time unit. The time unit can either be a batch of data or a single + * data point. Considering data arrived at time t, the half life h is defined + * such that at time t + h the discount applied to the data from t is 0.5. + * The definition remains the same whether the time unit is given + * as batches or points. + * + */ +@DeveloperApi +class StreamingKMeansModel( + override val clusterCenters: Array[Vector], + val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { + + /** Perform a k-means update on a batch of data.
*/ + def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { + + // find nearest cluster to each point + val closest = data.map(point => (this.predict(point), (point, 1L))) + + // get sums and counts for updating each cluster + val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => { + BLAS.axpy(1.0, p2._1, p1._1) + (p1._1, p1._2 + p2._2) + } + val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest + .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) + .collect() + + val discount = timeUnit match { + case StreamingKMeans.BATCHES => decayFactor + case StreamingKMeans.POINTS => + val numNewPoints = pointStats.view.map { case (_, (_, n)) => + n + }.sum + math.pow(decayFactor, numNewPoints) + } + + // apply discount to weights + BLAS.scal(discount, Vectors.dense(clusterWeights)) + + // implement update rule + pointStats.foreach { case (label, (sum, count)) => + val centroid = clusterCenters(label) + + val updatedWeight = clusterWeights(label) + count + val lambda = count / math.max(updatedWeight, 1e-16) + + clusterWeights(label) = updatedWeight + BLAS.scal(1.0 - lambda, centroid) + BLAS.axpy(lambda / count, sum, centroid) + + // display the updated cluster centers + val display = clusterCenters(label).size match { + case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...") + case _ => centroid.toArray.mkString("[", ",", "]") + } + + logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display") + } + + // Check whether the smallest cluster is dying. If so, split the largest cluster. + val weightsWithIndex = clusterWeights.view.zipWithIndex + val (maxWeight, largest) = weightsWithIndex.maxBy(_._1) + val (minWeight, smallest) = weightsWithIndex.minBy(_._1) + if (minWeight < 1e-8 * maxWeight) { + logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.") + val weight = (maxWeight + minWeight) / 2.0 + clusterWeights(largest) = weight + clusterWeights(smallest) = weight + val largestClusterCenter = clusterCenters(largest) + val smallestClusterCenter = clusterCenters(smallest) + var j = 0 + while (j < dim) { + val x = largestClusterCenter(j) + val p = 1e-14 * math.max(math.abs(x), 1.0) + largestClusterCenter.toBreeze(j) = x + p + smallestClusterCenter.toBreeze(j) = x - p + j += 1 + } + } + + this + } +} + +/** + * :: DeveloperApi :: + * StreamingKMeans provides methods for configuring a + * streaming k-means analysis, training the model on streaming data, + * and using the model to make predictions on streaming data. + * See KMeansModel for details on algorithm and update rules. + * + * Use a builder pattern to construct a streaming k-means analysis + * in an application, like: + * + * val model = new StreamingKMeans() + * .setDecayFactor(0.5) + * .setK(3) + * .setRandomCenters(5, 100.0) + * .trainOn(DStream) + */ +@DeveloperApi +class StreamingKMeans( + var k: Int, + var decayFactor: Double, + var timeUnit: String) extends Logging { + + def this() = this(2, 1.0, StreamingKMeans.BATCHES) + + protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) + + /** Set the number of clusters. */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Set the decay factor directly (for forgetful algorithms). */ + def setDecayFactor(a: Double): this.type = { + this.decayFactor = a + this + } + + /** Set the half life and time unit ("batches" or "points") for forgetful algorithms.
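To make the update above concrete, the following plain-Scala sketch (not part of the patch) runs one cluster's update for timeUnit = "batches" with made-up numbers; `sum` and `count` stand in for the aggregated per-cluster statistics, and the decay factor is derived from a half life exactly as in setHalfLife below:

```scala
// Illustrative sketch only: the per-cluster arithmetic of StreamingKMeansModel.update.
object StreamingKMeansUpdateSketch {
  def main(args: Array[String]): Unit = {
    val halfLife = 5.0                                    // hypothetical half life, in batches
    val decayFactor = math.exp(math.log(0.5) / halfLife)  // ~0.87
    var weight = 100.0                                    // n_t, cluster weight so far
    var center = Array(1.0, 2.0)                          // c_t
    val count = 20.0                                      // m_t, points assigned in this batch
    val sum = Array(60.0, 80.0)                           // element-wise sum of those points

    weight *= decayFactor                                 // discount the old evidence
    val updatedWeight = weight + count
    val lambda = count / math.max(updatedWeight, 1e-16)
    center = center.zip(sum).map { case (c, s) => (1.0 - lambda) * c + (lambda / count) * s }
    weight = updatedWeight
    println(s"centroid = ${center.mkString("[", ",", "]")}, weight = $weight")
  }
}
```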
*/ + def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { + throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) + } + this.decayFactor = math.exp(math.log(0.5) / halfLife) + logInfo("Setting decay factor to: %g ".format (this.decayFactor)) + this.timeUnit = timeUnit + this + } + + /** Specify initial centers directly. */ + def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + model = new StreamingKMeansModel(centers, weights) + this + } + + /** + * Initialize random centers, requiring only the number of dimensions. + * + * @param dim Number of dimensions + * @param weight Weight for each center + * @param seed Random seed + */ + def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + val random = new XORShiftRandom(seed) + val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) + val weights = Array.fill(k)(weight) + model = new StreamingKMeansModel(centers, weights) + this + } + + /** Return the latest model. */ + def latestModel(): StreamingKMeansModel = { + model + } + + /** + * Update the clustering model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * checks whether the cluster centers have been initialized, + * and updates the model using each batch of data from the stream. + * + * @param data DStream containing vector data + */ + def trainOn(data: DStream[Vector]) { + assertInitialized() + data.foreachRDD { (rdd, time) => + model = model.update(rdd, decayFactor, timeUnit) + } + } + + /** + * Use the clustering model to make predictions on batches of data from a DStream. + * + * @param data DStream containing vector data + * @return DStream containing predictions + */ + def predictOn(data: DStream[Vector]): DStream[Int] = { + assertInitialized() + data.map(model.predict) + } + + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * + * @param data DStream containing (key, feature vector) pairs + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { + assertInitialized() + data.mapValues(model.predict) + } + + /** Check whether cluster centers have been initialized. */ + private[this] def assertInitialized(): Unit = { + if (model.clusterCenters == null) { + throw new IllegalStateException( + "Initial cluster centers must be set before starting predictions") + } + } +} + +private[clustering] object StreamingKMeans { + final val BATCHES = "batches" + final val POINTS = "points" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala new file mode 100644 index 0000000000000..ea10bde5fa252 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ + +/** + * Evaluator for multilabel classification. + * @param predictionAndLabels an RDD of (predictions, labels) pairs, + * both are non-null Arrays, each with unique elements. + */ +class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + + private lazy val numDocs: Long = predictionAndLabels.count() + + private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) => + labels}.distinct().count() + + /** + * Returns subset accuracy + * (for equal sets of labels) + */ + lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => + predictions.deep == labels.deep + }.count().toDouble / numDocs + + /** + * Returns accuracy + */ + lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / + (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs + + + /** + * Returns Hamming-loss + */ + lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => + labels.size + predictions.size - 2 * labels.intersect(predictions).size + }.sum / (numDocs * numLabels) + + /** + * Returns document-based precision averaged by the number of documents + */ + lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => + if (predictions.size > 0) { + predictions.intersect(labels).size.toDouble / predictions.size + } else { + 0 + } + }.sum / numDocs + + /** + * Returns document-based recall averaged by the number of documents + */ + lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / labels.size + }.sum / numDocs + + /** + * Returns document-based f1-measure averaged by the number of documents + */ + lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => + 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) + }.sum / numDocs + + private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.intersect(labels) + }.countByValue() + + private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.diff(labels) + }.countByValue() + + private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) => + labels.diff(predictions) + }.countByValue() + + /** + * Returns precision for a given label (category) + * @param label the label. + */ + def precision(label: Double) = { + val tp = tpPerClass(label) + val fp = fpPerClass.getOrElse(label, 0L) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } + + /** + * Returns recall for a given label (category) + * @param label the label. + */ + def recall(label: Double) = { + val tp = tpPerClass(label) + val fn = fnPerClass.getOrElse(label, 0L) + if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + } + + /** + * Returns f1-measure for a given label (category) + * @param label the label. 
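A tiny worked example (illustrative numbers, plain Scala, no Spark) of the document-averaged precision and recall defined above:

```scala
// Two documents: (predicted labels, actual labels), as consumed by MultilabelMetrics above.
object MultilabelMetricsSketch {
  def main(args: Array[String]): Unit = {
    val pairs = Seq(
      (Array(0.0, 1.0), Array(0.0, 2.0)),   // one of two predictions is correct
      (Array(1.0), Array(1.0)))             // exact match
    val precision = pairs.map { case (pred, lab) =>
      if (pred.nonEmpty) pred.intersect(lab).size.toDouble / pred.size else 0.0
    }.sum / pairs.size                      // (0.5 + 1.0) / 2 = 0.75
    val recall = pairs.map { case (pred, lab) =>
      lab.intersect(pred).size.toDouble / lab.size
    }.sum / pairs.size                      // (0.5 + 1.0) / 2 = 0.75
    println(s"precision = $precision, recall = $recall")
  }
}
```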
+ */ + def f1Measure(label: Double) = { + val p = precision(label) + val r = recall(label) + if((p + r) == 0) 0 else 2 * p * r / (p + r) + } + + private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } + private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp } + private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn } + + /** + * Returns micro-averaged label-based precision + * (equals to micro-averaged document-based precision) + */ + lazy val microPrecision = { + val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} + sumTp.toDouble / (sumTp + sumFp) + } + + /** + * Returns micro-averaged label-based recall + * (equals to micro-averaged document-based recall) + */ + lazy val microRecall = { + val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} + sumTp.toDouble / (sumTp + sumFn) + } + + /** + * Returns micro-averaged label-based f1-measure + * (equals to micro-averaged document-based f1-measure) + */ + lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + + /** + * Returns the sequence of labels in ascending order + */ + lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala new file mode 100644 index 0000000000000..693117d820580 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} + +/** + * :: Experimental :: + * Evaluator for regression. + * + * @param predictionAndObservations an RDD of (prediction, observation) pairs. + */ +@Experimental +class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { + + /** + * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. + */ + private lazy val summary: MultivariateStatisticalSummary = { + val summary: MultivariateStatisticalSummary = predictionAndObservations.map { + case (prediction, observation) => Vectors.dense(observation, observation - prediction) + }.aggregate(new MultivariateOnlineSummarizer())( + (summary, v) => summary.add(v), + (sum1, sum2) => sum1.merge(sum2) + ) + summary + } + + /** + * Returns the explained variance regression score. 
+ * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + def explainedVariance: Double = { + 1 - summary.variance(1) / summary.variance(0) + } + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + def meanAbsoluteError: Double = { + summary.normL1(1) / summary.count + } + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + def meanSquaredError: Double = { + val rmse = summary.normL2(1) / math.sqrt(summary.count) + rmse * rmse + } + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + def rootMeanSquaredError: Double = { + summary.normL2(1) / math.sqrt(summary.count) + } + + /** + * Returns R^2^, the coefficient of determination. + * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + def r2: Double = { + 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 4734251127bb4..dfad25d57c947 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * - * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using * sum(abs(vector).^p^)^(1/p)^ as norm. * * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 415a845332d45..7358c1c84f79c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -48,4 +49,14 @@ trait VectorTransformer extends Serializable { data.map(x => this.transform(x)) } + /** + * Applies transformation on an JavaRDD[Vector]. + * + * @param data JavaRDD[Vector] to be transformed. + * @return transformed JavaRDD[Vector]. 
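The RegressionMetrics formulas above reduce to familiar definitions once the summarizer is unpacked: normL1(1) is the sum of |observation - prediction| and normL2(1) is the square root of the summed squared errors. A hand check with made-up numbers:

```scala
// Hand computation matching meanAbsoluteError / meanSquaredError / rootMeanSquaredError above.
object RegressionMetricsSketch {
  def main(args: Array[String]): Unit = {
    val predObs = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0))      // (prediction, observation)
    val errors = predObs.map { case (p, o) => o - p }
    val n = predObs.size.toDouble
    val mae = errors.map(e => math.abs(e)).sum / n              // summary.normL1(1) / count
    val rmse = math.sqrt(errors.map(e => e * e).sum / n)        // summary.normL2(1) / sqrt(count)
    val mse = rmse * rmse
    println(f"MAE = $mae%.4f, MSE = $mse%.4f, RMSE = $rmse%.4f")
  }
}
```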
+ */ + def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(data.rdd) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d321994c2a651..f5f7ad613d4c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -432,7 +432,7 @@ class Word2VecModel private[mllib] ( throw new IllegalStateException(s"$word not in vocabulary") } } - + /** * Find synonyms of a word * @param word a word @@ -443,7 +443,7 @@ class Word2VecModel private[mllib] ( val vector = transform(word) findSynonyms(vector,num) } - + /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..ac217edc619ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -74,6 +78,65 @@ sealed trait Vector extends Serializable { } } +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". 
+ StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] +} + /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports @@ -191,6 +254,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ec2d481dccc22..10a515af88802 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -152,7 +152,7 @@ class RowMatrix( * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined * automatically based on the cost: - * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. * This requires a single pass with O(n^2^) storage on each executor and on the driver, and * O(n^2^ k) time on the driver. @@ -169,7 +169,8 @@ class RowMatrix( * @note The conditions that decide which method to use internally and the default parameters are * subject to change. * - * @param k number of leading singular values to keep (0 < k <= n). It might return less than k if + * @param k number of leading singular values to keep (0 < k <= n). 
+ * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values * converged before the maximum number of Arnoldi update iterations is reached (in case * that matrix A is ill-conditioned). @@ -192,7 +193,7 @@ class RowMatrix( /** * The actual SVD implementation, visible for testing. * - * @param k number of leading singular values to keep (0 < k <= n) + * @param k number of leading singular values to keep (0 < k <= n) * @param computeU whether to compute U * @param rCond the reciprocal condition number * @param maxIter max number of iterations (if ARPACK is used) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index e4b436b023794..fef062e02b6ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -79,7 +79,7 @@ private[mllib] object NNLS { // stopping condition def stop(step: Double, ndir: Double, nx: Double): Boolean = { ((step.isNaN) // NaN - || (step < 1e-6) // too small or negative + || (step < 1e-7) // too small or negative || (step > 1e40) // too small; almost certainly numerical problems || (ndir < 1e-12 * nx) // gradient relatively too small || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 28179fbc450c0..51f9b8657c640 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -89,12 +88,13 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { @DeveloperApi class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { - private var rng = new Poisson(mean, new DRand) + private var rng = new PoissonDistribution(mean) - override def nextValue(): Double = rng.nextDouble() + override def nextValue(): Double = rng.sample() override def setSeed(seed: Long) { - rng = new Poisson(mean, new DRand(seed.toInt)) + rng = new PoissonDistribution(mean) + rng.reseedRandomGenerator(seed) } override def copy(): PoissonGenerator = new PoissonGenerator(mean) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 3025d4837cab4..fab7c4405c65d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} /** * :: DeveloperApi :: @@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch 
when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample.toBreeze.activeIterator.foreach { - case (_, 0.0) => // Skip explicit zero elements. - case (i, value) => + @inline def update(i: Int, value: Double) = { + if (value != 0.0) { if (currMax(i) < value) { currMax(i) = value } @@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currL1(i) += math.abs(value) nnz(i) += 1.0 + } + } + + sample match { + case dv: DenseVector => { + var j = 0 + while (j < dv.size) { + update(j, dv.values(j)) + j += 1 + } + } + case sv: SparseVector => + var j = 0 + while (j < sv.indices.size) { + update(sv.indices(j), sv.values(j)) + j += 1 + } + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } totalCnt += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 0089419c2c5d4..ea82d39b72c03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat.test import breeze.linalg.{DenseMatrix => BDM} -import cern.jet.stat.Probability.chiSquareComplemented +import org.apache.commons.math3.distribution.ChiSquaredDistribution import org.apache.spark.{SparkException, Logging} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} @@ -33,7 +33,7 @@ import scala.collection.mutable * on an input of type `Matrix` in which independence between columns is assessed. * We also provide a method for computing the chi-squared statistic between each feature and the * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = - * number of features in the inpuy RDD. + * number of features in the input RDD. * * Supported methods for goodness of fit: `pearson` (default) * Supported methods for independence: `pearson` (default) @@ -139,7 +139,7 @@ private[stat] object ChiSqTest extends Logging { } /* - * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. + * Pearson's goodness of fit test on the input observed and expected counts/relative frequencies. * Uniform distribution is assumed when `expected` is not passed in. */ def chiSquared(observed: Vector, @@ -188,12 +188,12 @@ private[stat] object ChiSqTest extends Logging { } } val df = size - 1 - val pValue = chiSquareComplemented(df, statistic) + val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic) new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) } /* - * Pearon's independence test on the input contingency matrix. + * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { @@ -238,7 +238,13 @@ private[stat] object ChiSqTest extends Logging { j += 1 } val df = (numCols - 1) * (numRows - 1) - val pValue = chiSquareComplemented(df, statistic) - new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + if (df == 0) { + // 1 column or 1 row. Constant distribution is independent of anything. + // pValue = 1.0 and statistic = 0.0 in this case. 
+ new ChiSqTestResult(1.0, 0, 0.0, methodName, NullHypothesis.independence.toString) + } else { + val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6737a2f4176c2..78acc17f901c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) val rfModel = rf.train(input) - rfModel.trees(0) + rfModel.weakHypotheses(0) } } @@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging { * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], @@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], nodeQueue: mutable.Queue[(Int, Node)], - timer: TimeTracker = new TimeTracker): Unit = { + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging { logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: RandomForest.NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } + } + } /** * Performs a sequential aggregation over a partition. 
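The ChiSqTest change above replaces colt's chiSquareComplemented with an upper-tail probability from commons-math3. A standalone sketch of that computation, with an illustrative statistic and degrees of freedom:

```scala
import org.apache.commons.math3.distribution.ChiSquaredDistribution

// Upper-tail p-value as now computed in ChiSqTest (illustrative inputs only).
object ChiSquaredPValueSketch {
  def main(args: Array[String]): Unit = {
    val df = 3.0
    val statistic = 7.815   // roughly the 95th percentile for 3 degrees of freedom
    val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
    println(f"p-value = $pValue%.4f")   // ~0.05
  }
}
```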
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) - val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) - // If the example does not reach a node in this group, then nodeIndex = null. - if (nodeInfo != null) { - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) - } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) - } - } + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } + agg } @@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging { // Finally, only best Splits for nodes are collected to driver to construct decision tree. val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val nodeToBestSplits = + + val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { input.mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator @@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging { // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, // which can be combined with other partition using `reduceByKey` nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - }.reduceByKey((a, b) => a.merge(b)) + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)) .map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => Some(nodeToFeatures(nodeIndex)) @@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging { timer.stop("chooseSplits") + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + 
Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. nodesForGroup.foreach { case (treeIndex, nodesForTree) => nodesForTree.foreach { node => @@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging { node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + // enqueue left child and right child if they are not leaves if (!leftChildIsLeaf) { nodeQueue.enqueue((treeIndex, node.leftNode.get)) @@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging { } } + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins) + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala new file mode 100644 index 0000000000000..1a847201ce157 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy} +import org.apache.spark.Logging +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum + +/** + * :: Experimental :: + * A class that implements gradient boosting for regression and binary classification problems. + * @param boostingStrategy Parameters for the gradient boosting algorithm + */ +@Experimental +class GradientBoosting ( + private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { + + /** + * Method to train a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
+ * @return WeightedEnsembleModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + algo match { + case Regression => GradientBoosting.boost(input, boostingStrategy) + case Classification => + val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoosting.boost(remappedInput, boostingStrategy) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + +} + + +object GradientBoosting extends Logging { + + /** + * Method to train a gradient boosting model. + * + * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] + * is recommended to clearly specify regression. + * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] + * is recommended to clearly specify regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting classification model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting regression model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return WeightedEnsembleModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + val algo = boostingStrategy.algo + require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.") + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting binary classification model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param numEstimators Number of estimators used in boosting stages. In other words, + * number of boosting iterations performed. + * @param loss Loss function used for minimization during gradient boosting. + * @param learningRate Learning rate for shrinking the contribution of each estimator. 
The + * learning rate should be between in the interval (0, 1] + * @param subsamplingRate Fraction of the training data used for learning the decision tree. + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is + * supported.) + * @return WeightedEnsembleModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numEstimators: Int, + loss: String, + learningRate: Double, + subsamplingRate: Double, + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + weakLearnerParams: Strategy): WeightedEnsembleModel = { + val lossType = Losses.fromString(loss) + val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType, + learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, + weakLearnerParams) + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Method to train a gradient boosting regression model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param numEstimators Number of estimators used in boosting stages. In other words, + * number of boosting iterations performed. + * @param loss Loss function used for minimization during gradient boosting. + * @param learningRate Learning rate for shrinking the contribution of each estimator. The + * learning rate should be between in the interval (0, 1] + * @param subsamplingRate Fraction of the training data used for learning the decision tree. + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is + * supported.) 
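For reference, the categoricalFeaturesInfo encoding described above is just a map from feature index to arity; a hypothetical example:

```scala
// Feature 0 is binary, feature 3 takes the values {0, 1, 2, 3, 4}; all other features
// are treated as continuous.
val categoricalFeaturesInfo: Map[Int, Int] = Map(0 -> 2, 3 -> 5)
```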
+ * @return WeightedEnsembleModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + numEstimators: Int, + loss: String, + learningRate: Double, + subsamplingRate: Double, + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + weakLearnerParams: Strategy): WeightedEnsembleModel = { + val lossType = Losses.fromString(loss) + val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType, + learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo, + weakLearnerParams) + new GradientBoosting(boostingStrategy).train(input) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] + */ + def trainClassifier( + input: RDD[LabeledPoint], + numEstimators: Int, + loss: String, + learningRate: Double, + subsamplingRate: Double, + numClassesForClassification: Int, + categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer], + weakLearnerParams: Strategy): WeightedEnsembleModel = { + trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate, + numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + weakLearnerParams) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] + */ + def trainRegressor( + input: RDD[LabeledPoint], + numEstimators: Int, + loss: String, + learningRate: Double, + subsamplingRate: Double, + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + weakLearnerParams: Strategy): WeightedEnsembleModel = { + trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate, + numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + weakLearnerParams) + } + + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param boostingStrategy boosting parameters + * @return + */ + private def boost( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { + + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + // Initialize gradient boosting parameters + val numEstimators = boostingStrategy.numEstimators + val baseLearners = new Array[DecisionTreeModel](numEstimators) + val baseLearnerWeights = new Array[Double](numEstimators) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + val strategy = boostingStrategy.weakLearnerParams + + // Cache input + input.persist(StorageLevel.MEMORY_AND_DISK) + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + var data = input + + // 1. 
Initialize tree + timer.start("building tree 0") + val firstTreeModel = new DecisionTree(strategy).train(data) + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = 1.0 + val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, + Sum) + logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + // psuedo-residual for second iteration + data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), + point.features)) + + var m = 1 + while (m < numEstimators) { + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val model = new DecisionTree(strategy).train(data) + timer.stop(s"building tree $m") + // Create partial model + baseLearners(m) = model + baseLearnerWeights(m) = learningRate + // Note: A model of type regression is used since we require raw prediction + val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), + baseLearnerWeights.slice(0, m + 1), Regression, Sum) + logDebug("error of gbt = " + loss.computeError(partialModel, input)) + // Update data with pseudo-residuals + data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), + point.features)) + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + + // 3. Output classifier + new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index ebbd8e0257209..9683916d9b3f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -26,8 +26,9 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -59,7 +60,7 @@ import org.apache.spark.util.Utils * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. - * @param seed Random seed for bootstrapping and choosing feature subsets. + * @param seed Random seed for bootstrapping and choosing feature subsets. 
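Stripped of Spark and decision trees, the stage-wise loop in boost() above amounts to the following conceptual sketch for squared-error loss, where `fit` is any base learner (a mean predictor here, standing in for a tree) and each new stage is added with weight learningRate:

```scala
// Conceptual, self-contained sketch of gradient boosting with pseudo-residuals.
object BoostingLoopSketch {
  def main(args: Array[String]): Unit = {
    val xs = Array(1.0, 2.0, 3.0, 4.0)
    val ys = Array(1.5, 1.9, 3.2, 4.1)
    // Trivial base learner: always predicts the mean of its targets.
    def fit(targets: Array[Double]): Double => Double = {
      val mean = targets.sum / targets.length
      _ => mean
    }
    val learningRate = 0.1
    val numEstimators = 10
    var ensemble: Double => Double = fit(ys)                 // stage 0, weight 1.0
    (1 until numEstimators).foreach { _ =>
      // For squared error, the pseudo-residuals are y - F(x).
      val residuals = xs.indices.map(i => ys(i) - ensemble(xs(i))).toArray
      val h = fit(residuals)
      val prev = ensemble
      ensemble = x => prev(x) + learningRate * h(x)          // add with weight = learningRate
    }
    println(xs.map(ensemble).mkString(", "))
  }
}
```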
*/ @Experimental private class RandomForest ( @@ -78,9 +79,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): RandomForestModel = { + def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { val timer = new TimeTracker() @@ -111,11 +112,20 @@ private class RandomForest ( // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val baggedInput = if (numTrees > 1) { - BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) - } else { - BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) - }.persist(StorageLevel.MEMORY_AND_DISK) + + val (subsample, withReplacement) = { + // TODO: Have a stricter check for RF in the strategy + val isRandomForest = numTrees > 1 + if (isRandomForest) { + (1.0, true) + } else { + (strategy.subsamplingRate, false) + } + } + + val baggedInput + = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -150,6 +160,19 @@ private class RandomForest ( * in lower levels). */ + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointDir = strategy.checkpointDir, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + // FIFO queue of nodes to train: (treeIndex, node) val nodeQueue = new mutable.Queue[(Int, Node)]() @@ -172,7 +195,7 @@ private class RandomForest ( // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache) timer.stop("findBestSplits") } @@ -183,8 +206,14 @@ private class RandomForest ( logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + nodeIdCache.get.deleteAllCheckpoints() + } + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - RandomForestModel.build(trees) + val treeWeights = Array.fill[Double](numTrees)(1.0) + new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) } } @@ -205,14 +234,14 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. 
- * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) @@ -243,7 +272,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -254,7 +283,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): RandomForestModel = { + seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -273,7 +302,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -293,14 +322,14 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) @@ -330,7 +359,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. 
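A hedged usage sketch of the retyped entry point: based on the signatures in this patch, RandomForest.trainClassifier now returns a WeightedEnsembleModel (combined with the Average strategy) rather than a RandomForestModel. Assumes an existing RDD[LabeledPoint] named trainingData.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
    import org.apache.spark.rdd.RDD

    def trainForest(trainingData: RDD[LabeledPoint]): WeightedEnsembleModel = {
      val model = RandomForest.trainClassifier(
        trainingData,
        2,                 // numClassesForClassification
        Map[Int, Int](),   // categoricalFeaturesInfo: all features continuous
        10,                // numTrees
        "sqrt",            // featureSubsetStrategy
        "gini",            // impurity
        4,                 // maxDepth
        32)                // maxBins (seed left at its default)
      println(s"trained ${model.numWeakHypotheses} trees, ${model.totalNumNodes} nodes")
      model
    }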
- * @return RandomForestModel that can be used for prediction + * @return WeightedEnsembleModel that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -340,7 +369,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): RandomForestModel = { + seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) @@ -358,7 +387,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): RandomForestModel = { + seed: Int): WeightedEnsembleModel = { trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala new file mode 100644 index 0000000000000..501d9ff9ea9b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import scala.beans.BeanProperty + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.{Gini, Variance} +import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} + +/** + * :: Experimental :: + * Stores all the configuration options for the boosting algorithms + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param numEstimators Number of estimators used in boosting stages. In other words, + * number of boosting iterations performed. + * @param loss Loss function used for minimization during gradient boosting. + * @param learningRate Learning rate for shrinking the contribution of each estimator. The + * learning rate should be between in the interval (0, 1] + * @param subsamplingRate Fraction of the training data used for learning the decision tree. + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. 
For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are + * supported. + */ +@Experimental +case class BoostingStrategy( + // Required boosting parameters + algo: Algo, + @BeanProperty var numEstimators: Int, + @BeanProperty var loss: Loss, + // Optional boosting parameters + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var subsamplingRate: Double = 1.0, + @BeanProperty var numClassesForClassification: Int = 2, + @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @BeanProperty var weakLearnerParams: Strategy) extends Serializable { + + require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " + + s"$learningRate.") + require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " + + s"$learningRate.") + + // Ensure values for weak learner are the same as what is provided to the boosting algorithm. + weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo + weakLearnerParams.numClassesForClassification = numClassesForClassification + weakLearnerParams.subsamplingRate = subsamplingRate + +} + +@Experimental +object BoostingStrategy { + + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStrategy = defaultWeakLearnerParams(algo) + algo match { + case Classification => + new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy) + case Regression => + new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the boosting.") + } + } + + /** + * Returns default configuration for the weak learner (decision tree) algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for weak learner + */ + def defaultWeakLearnerParams(algo: Algo): Strategy = { + // Note: Regression tree used even for classification for GBT. + new Strategy(Regression, Variance, 3) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala new file mode 100644 index 0000000000000..82889dc00cdad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: Experimental :: + * Enum to select ensemble combining strategy for base learners + */ +@DeveloperApi +object EnsembleCombiningStrategy extends Enumeration { + type EnsembleCombiningStrategy = Value + val Sum, Average = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index caaccbfb8ad16..d09295c507d67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.configuration +import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * for choosing how to split on features at each node. * More bins give higher granularity. * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: - * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, @@ -58,19 +59,31 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will + * maintain a separate RDD of node Id cache for each row. + * @param checkpointDir If the node Id cache is used, it will help to checkpoint + * the node Id cache periodically. This is the checkpoint directory + * to be used for the node Id cache. + * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. + * E.g. 10 means that the cache will get checkpointed every 10 updates. 
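A hedged configuration sketch: with the @BeanProperty members introduced in this patch, both BoostingStrategy and the weak-learner Strategy can be adjusted after construction via generated setters, starting from defaultParams. The setter names below are the ones @BeanProperty generates; treat the snippet as an assumption-laden illustration rather than a documented API surface.

    import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

    val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression)
    boostingStrategy.setNumEstimators(50)     // number of boosting iterations
    boostingStrategy.setLearningRate(0.05)    // must lie in (0, 1]

    // The weak learner is a regression tree even for classification (see defaultWeakLearnerParams).
    val treeStrategy = boostingStrategy.weakLearnerParams
    treeStrategy.setMaxDepth(4)
    treeStrategy.setSubsamplingRate(0.8)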
*/ @Experimental class Strategy ( val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val numClassesForClassification: Int = 2, - val maxBins: Int = 32, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val minInstancesPerNode: Int = 1, - val minInfoGain: Double = 0.0, - val maxMemoryInMB: Int = 256) extends Serializable { + @BeanProperty var impurity: Impurity, + @BeanProperty var maxDepth: Int, + @BeanProperty var numClassesForClassification: Int = 2, + @BeanProperty var maxBins: Int = 32, + @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @BeanProperty var minInstancesPerNode: Int = 1, + @BeanProperty var minInfoGain: Double = 0.0, + @BeanProperty var maxMemoryInMB: Int = 256, + @BeanProperty var subsamplingRate: Double = 1, + @BeanProperty var useNodeIdCache: Boolean = false, + @BeanProperty var checkpointDir: Option[String] = None, + @BeanProperty var checkpointInterval: Int = 10) extends Serializable { if (algo == Classification) { require(numClassesForClassification >= 2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala index 937c8a2ac5836..089010c81ffb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree.impl -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Internal representation of a datapoint which belongs to several subsamples of the same dataset, * particularly for bagging (e.g., for random forests). * * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsample. + * number of times which this instance appears in each subsamplingRate. * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. * @@ -45,27 +45,71 @@ private[tree] object BaggedPoint { /** * Convert an input dataset into its BaggedPoint representation, - * choosing subsample counts for each instance. - * Each subsample has the same number of instances as the original dataset, - * and is created by subsampling with replacement. - * @param input Input dataset. - * @param numSubsamples Number of subsamples of this RDD to take. - * @param seed Random seed. - * @return BaggedPoint dataset representation + * choosing subsamplingRate counts for each instance. + * Each subsamplingRate has the same number of instances as the original dataset, + * and is created by subsampling without replacement. + * @param input Input dataset. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param withReplacement Sampling with/without replacement. + * @param seed Random seed. + * @return BaggedPoint dataset representation. 
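The two sampling paths added to BaggedPoint can be summarised per row as follows: bootstrap sampling with replacement draws Poisson counts with mean subsamplingRate, while sampling without replacement draws Bernoulli 0/1 weights. A standalone sketch, using scala.util.Random in place of Spark's internal XORShiftRandom:

    import org.apache.commons.math3.distribution.PoissonDistribution
    import scala.util.Random

    // Draw one row's weight vector across `numSubsamples` subsamples.
    def subsampleWeights(
        subsamplingRate: Double,
        numSubsamples: Int,
        withReplacement: Boolean,
        seed: Long): Array[Double] = {
      if (withReplacement) {
        // Bootstrap: Poisson(mean = subsamplingRate) counts, as in the patched BaggedPoint.
        val poisson = new PoissonDistribution(subsamplingRate)
        poisson.reseedRandomGenerator(seed)
        Array.fill(numSubsamples)(poisson.sample().toDouble)
      } else {
        // Without replacement: Bernoulli(subsamplingRate) 0/1 weights.
        val rng = new Random(seed)   // XORShiftRandom in Spark; scala.util.Random here
        Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
      }
    }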
*/ - def convertToBaggedRDD[Datum]( + def convertToBaggedRDD[Datum] ( input: RDD[Datum], + subsamplingRate: Double, numSubsamples: Int, + withReplacement: Boolean, seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + } else { + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } + } + } + + private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val rng = new XORShiftRandom + rng.setSeed(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + val x = rng.nextDouble() + subsampleWeights(subsampleIndex) = { + if (x < subsamplingRate) 1.0 else 0.0 + } + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDSamplingWithReplacement[Datum] ( + input: RDD[Datum], + subsample: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => - // TODO: Support different sampling rates, and sampling without replacement. // Use random seed = seed + partitionIndex + 1 to make generation reproducible. - val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1)) + val poisson = new PoissonDistribution(subsample) + poisson.reseedRandomGenerator(seed + partitionIndex + 1) instances.map { instance => val subsampleWeights = new Array[Double](numSubsamples) var subsampleIndex = 0 while (subsampleIndex < numSubsamples) { - subsampleWeights(subsampleIndex) = poisson.nextInt() + subsampleWeights(subsampleIndex) = poisson.sample() subsampleIndex += 1 } new BaggedPoint(instance, subsampleWeights) @@ -73,7 +117,8 @@ private[tree] object BaggedPoint { } } - def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + private def convertToBaggedRDDWithoutSampling[Datum] ( + input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { input.map(datum => new BaggedPoint(datum, Array(1.0))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala new file mode 100644 index 0000000000000..83011b48b7d9b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.model.{Bin, Node, Split} + +/** + * :: DeveloperApi :: + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +@DeveloperApi +private[tree] case class NodeIndexUpdater( + split: Split, + nodeIndex: Int) { + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeatures Binned feature values. + * @param bins Bin information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { + if (split.featureType == Continuous) { + val featureIndex = split.feature + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + if (featureValueUpperBound <= split.threshold) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } else { + if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } + } +} + +/** + * :: DeveloperApi :: + * A given TreePoint would belong to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointDir The checkpoint directory where + * the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +@DeveloperApi +private[tree] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointDir: Option[String], + val checkpointInterval: Int) { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // If a checkpoint directory is given, and there's no prior checkpoint directory, + // then set the checkpoint directory with the given one. 
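Condensed restatement of the routing rule in NodeIndexUpdater.updateNodeIndex, assuming the usual 1-based heap indexing where Node.leftChildIndex(i) = 2 * i and Node.rightChildIndex(i) = 2 * i + 1:

    // Decide which child a row moves to, given the split at `nodeIndex`.
    def childIndex(
        nodeIndex: Int,
        isContinuous: Boolean,
        featureValueUpperBound: Double,   // bins(f)(binIndex).highSplit.threshold
        threshold: Double,                // split.threshold (continuous splits only)
        categories: Set[Double],          // split.categories (categorical splits only)
        featureValue: Double): Int = {
      val goLeft =
        if (isContinuous) featureValueUpperBound <= threshold
        else categories.contains(featureValue)
      if (goLeft) 2 * nodeIndex else 2 * nodeIndex + 1
    }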
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { + nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) + } + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param bins Bin information needed to find child node indices. + */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + bins: Array[Array[Bin]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { + dataPoint => { + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null) + if (nodeIdUpdater != null) { + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeatures = dataPoint._1.datum.binnedFeatures, + bins = bins) + dataPoint._2(treeId) = newNodeIndex + } + + treeId += 1 + } + + dataPoint._2 + } + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && + (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue.get(1).get.getCheckpointFile != None) { + val old = checkpointQueue.dequeue() + + // Since the old checkpoint is not deleted by Spark, + // we'll manually delete it here. + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile != None) { + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } + } + } +} + +@DeveloperApi +private[tree] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. 
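The checkpoint bookkeeping above follows a rolling pattern: keep a queue of checkpointed RDDs and delete the oldest only after a newer checkpoint has actually materialised, since Spark does not clean these files up itself. A hedged standalone sketch of that pruning step:

    import scala.collection.mutable
    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.spark.rdd.RDD

    // Delete older checkpoints while the next-newer one is known to exist on disk.
    def pruneCheckpoints(queue: mutable.Queue[RDD[_]]): Unit = {
      var canDelete = true
      while (queue.size > 1 && canDelete) {
        if (queue(1).getCheckpointFile.isDefined) {   // the next checkpoint is materialised
          val old = queue.dequeue()
          old.getCheckpointFile.foreach { file =>
            val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
            fs.delete(new Path(file), true)           // Spark does not delete it for us
          }
        } else {
          canDelete = false
        }
      }
    }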
+ */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointDir: Option[String], + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointDir, + checkpointInterval) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala new file mode 100644 index 0000000000000..d111ffe30ed9e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least absolute error loss calculation. + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: |y - F| + * Negative gradient: sign(y - F) + */ +@DeveloperApi +object AbsoluteError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * absolute error calculation. + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + val sumOfAbsolutes = data.map { y => + val err = model.predict(y.features) - y.label + math.abs(err) + }.sum() + sumOfAbsolutes / data.count() + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala new file mode 100644 index 0000000000000..6f3d4340f0d3b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least squares error loss calculation. + * + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: log(1 + exp(-2yF)), y in {-1, 1} + * Negative gradient: 2y / ( 1 + exp(2yF)) + */ +@DeveloperApi +object LogLoss extends Loss { + + /** + * Method to calculate the loss gradients for the gradient boosting calculation for binary + * classification + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + val prediction = model.predict(point.features) + 1.0 / (1.0 + math.exp(-prediction)) - point.label + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count() + wrongPredictions / data.count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala new file mode 100644 index 0000000000000..5580866c879e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
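For reference, the expression returned by LogLoss.gradient above is, for a raw ensemble score F(x) and a label y in {0, 1}, sigmoid(F(x)) - y. A tiny restatement of what the code computes:

    // Restating the expression in LogLoss.gradient for a single example.
    def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

    def logLossGradient(rawPrediction: Double, label: Double): Double =
      sigmoid(rawPrediction) - label

    // A confident, correct positive prediction yields a gradient close to zero:
    // logLossGradient(4.0, 1.0) ≈ -0.018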
+ */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. + */ +@DeveloperApi +trait Loss extends Serializable { + + /** + * Method to calculate the gradients for the gradient boosting calculation. + * @param model Model of the weak learner. + * @param point Instance of the training dataset. + * @return Loss gradient. + */ + def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala new file mode 100644 index 0000000000000..42c9ead9884b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +object Losses { + + def fromString(name: String): Loss = name match { + case "leastSquaresError" => SquaredError + case "leastAbsoluteError" => AbsoluteError + case "logLoss" => LogLoss + case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name") + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala new file mode 100644 index 0000000000000..4349fefef2c74 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for least squares error loss calculation. + * + * The features x and the corresponding label y is predicted using the function F. + * For each instance: + * Loss: (y - F)**2/2 + * Negative gradient: y - F + */ +@DeveloperApi +object SquaredError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * squares error calculation. + * @param model Model of the weak learner + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: WeightedEnsembleModel, + point: LabeledPoint): Double = { + model.predict(point.features) - point.label + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = model.predict(y.features) - y.label + err * err + }.mean() + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala deleted file mode 100644 index 6a22e2abe59bd..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import scala.collection.mutable - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.rdd.RDD - -/** - * :: Experimental :: - * Random forest model for classification or regression. 
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make - * aggregate predictions. - * @param trees Trees which make up this forest. This cannot be empty. - * @param algo algorithm type -- classification or regression - */ -@Experimental -class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { - - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - def predict(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - trees.foreach { tree => - val prediction = tree.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - trees.map(_.predict(features)).sum / trees.size - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = { - features.map(x => predict(x)) - } - - /** - * Get number of trees in forest. - */ - def numTrees: Int = trees.size - - /** - * Get total number of nodes, summed over all trees in the forest. - */ - def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum - - /** - * Print a summary of the model. - */ - override def toString: String = algo match { - case Classification => - s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes" - case Regression => - s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes" - case _ => throw new IllegalArgumentException( - s"RandomForestModel given unknown algo parameter: $algo.") - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + trees.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - -} - -private[tree] object RandomForestModel { - - def build(trees: Array[DecisionTreeModel]): RandomForestModel = { - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - val algo: Algo = trees(0).algo - require(trees.forall(_.algo == algo), - "RandomForestModel cannot combine trees which have different output types" + - " (classification/regression).") - new RandomForestModel(trees, algo) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala new file mode 100644 index 0000000000000..7b052d9163a13 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +import scala.collection.mutable + +@Experimental +class WeightedEnsembleModel( + val weakHypotheses: Array[DecisionTreeModel], + val weakHypothesisWeights: Array[Double], + val algo: Algo, + val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + + s". Number of weakHypotheses = $weakHypotheses") + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictRaw(features: Vector): Double = { + val treePredictions = weakHypotheses.map(learner => learner.predict(features)) + if (numWeakHypotheses == 1){ + treePredictions(0) + } else { + var prediction = treePredictions(0) + var index = 1 + while (index < numWeakHypotheses) { + prediction += weakHypothesisWeights(index) * treePredictions(index) + index += 1 + } + prediction + } + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + algo match { + case Regression => predictRaw(features) + case Classification => { + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (predictRaw(features) > 0 ) 1.0 else 0.0 + } + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Predict values for a single data point. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + private def predictByAveraging(features: Vector): Double = { + algo match { + case Classification => + val predictionToCount = new mutable.HashMap[Int, Int]() + weakHypotheses.foreach { learner => + val prediction = learner.predict(features).toInt + predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 + } + predictionToCount.maxBy(_._2)._1 + case Regression => + weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size + } + } + + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + combiningStrategy match { + case Sum => predictBySumming(features) + case Average => predictByAveraging(features) + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") + } + } + + /** + * Predict values for the given data set. 
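Plain-Scala restatement of the two combining strategies used by predict: Sum (gradient boosting) adds the first tree's raw prediction to the weighted predictions of the remaining trees, while Average (random forests) majority-votes for classification and averages for regression.

    // Sum: first tree taken as-is, later trees scaled by their learner weights.
    def combineBySum(treePredictions: Array[Double], weights: Array[Double]): Double =
      treePredictions.head +
        treePredictions.tail.zip(weights.tail).map { case (p, w) => w * p }.sum

    // Average, regression: arithmetic mean of the tree predictions.
    def combineByAverageRegression(treePredictions: Array[Double]): Double =
      treePredictions.sum / treePredictions.length

    // Average, classification: majority vote over predicted labels.
    def combineByMajorityVote(treePredictions: Array[Double]): Double =
      treePredictions.map(_.toInt).groupBy(identity).maxBy(_._2.length)._1.toDouble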
+ * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" + case Regression => + s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" + case _ => throw new IllegalArgumentException( + s"WeightedEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numWeakHypotheses: Int = weakHypotheses.size + + // TODO: Remove these helpers methods once class is generalized to support any base learning + // algorithms. + + /** + * Get total number of nodes, summed over all trees in the forest. + */ + def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index ca35100aa99c6..9353351af72a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliSampler +import org.apache.spark.util.random.BernoulliCellSampler import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel @@ -76,7 +76,7 @@ object MLUtils { .map { line => val items = line.split(' ') val label = items.head.toDouble - val (indices, values) = items.tail.map { item => + val (indices, values) = items.tail.filter(_.nonEmpty).map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. val value = indexAndValue(1).toDouble @@ -196,8 +196,8 @@ object MLUtils { /** * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. + * L, f1 f2 ... + * where f1, f2 are feature values in Double and L is the corresponding label as Double. * * @param sc SparkContext * @param dir Directory to the input data files. @@ -219,8 +219,8 @@ object MLUtils { /** * Save labeled data to a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. + * L, f1 f2 ... + * where f1, f2 are feature values in Double and L is the corresponding label as Double. * * @param data An RDD of LabeledPoints containing data to be saved. * @param dir Directory to save the data. 
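The text format documented above is one label, a comma, then whitespace-separated feature values, for example "1.0, 2.0 3.5 0.0". A hedged one-off parser for that format (the names below are illustrative, not MLUtils APIs):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Parse one line of the "L, f1 f2 ..." format described above.
    def parseLabeledLine(line: String): LabeledPoint = {
      val Array(labelPart, featurePart) = line.split(",", 2)
      val features = featurePart.trim.split("\\s+").map(_.toDouble)
      LabeledPoint(labelPart.trim.toDouble, Vectors.dense(features))
    }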
@@ -244,7 +244,7 @@ object MLUtils { def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => - val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, + val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed) val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala new file mode 100644 index 0000000000000..850c9fce507cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom + +class StreamingKMeansSuite extends FunSuite with TestSuiteBase { + + override def maxWaitTimeMillis = 30000 + + test("accuracy for single center and equivalence to grand average") { + // set parameters + val numBatches = 10 + val numPoints = 50 + val k = 1 + val d = 5 + val r = 0.1 + + // create model with one cluster + val model = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // estimated center should be close to true center + assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) + + // estimated center from streaming should exactly match the arithmetic mean of all data points + // because the decay factor is set to 1.0 + val grandMean = + input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) + } + + test("accuracy for two centers") { + val numBatches = 10 + val numPoints = 5 + val k = 2 + val d = 5 + val r = 0.1 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(2, 
"batches") + .setInitialCenters( + Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), + Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)), + Array(5.0, 5.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + } + + test("detecting dying clusters") { + val numBatches = 10 + val numPoints = 5 + val k = 1 + val d = 1 + val r = 1.0 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(0.5, "points") + .setInitialCenters( + Array(Vectors.dense(0.0), Vectors.dense(1000.0)), + Array(1.0, 1.0)) + + // new data are all around the first cluster 0.0 + val (input, _) = + StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + val model = kMeans.latestModel() + val c0 = model.clusterCenters(0)(0) + val c1 = model.clusterCenters(1)(0) + + assert(c0 * c1 < 0.0, "should have one positive center and one negative center") + // 0.8 is the mean of half-normal distribution + assert(math.abs(c0) ~== 0.8 absTol 0.6) + assert(math.abs(c1) ~== 0.8 absTol 0.6) + } + + def StreamingKMeansDataGenerator( + numPoints: Int, + numBatches: Int, + k: Int, + d: Int, + r: Double, + seed: Int, + initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = { + val rand = new XORShiftRandom(seed) + val centers = initCenters match { + case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))) + case _ => initCenters + } + val data = (0 until numBatches).map { i => + (0 until numPoints).map { idx => + val center = centers(idx % k) + Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r)) + } + } + (data, centers) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala new file mode 100644 index 0000000000000..342baa0274e9c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.rdd.RDD + +class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { + test("Multilabel evaluation metrics") { + /* + * Documents true labels (5x class0, 3x class1, 4x class2): + * doc 0 - predict 0, 1 - class 0, 2 + * doc 1 - predict 0, 2 - class 0, 1 + * doc 2 - predict none - class 0 + * doc 3 - predict 2 - class 2 + * doc 4 - predict 2, 0 - class 2, 0 + * doc 5 - predict 0, 1, 2 - class 0, 1 + * doc 6 - predict 1 - class 1, 2 + * + * predicted classes + * class 0 - doc 0, 1, 4, 5 (total 4) + * class 1 - doc 0, 5, 6 (total 3) + * class 2 - doc 1, 3, 4, 5 (total 4) + * + * true classes + * class 0 - doc 0, 1, 2, 4, 5 (total 5) + * class 1 - doc 1, 5, 6 (total 3) + * class 2 - doc 0, 3, 4, 6 (total 4) + * + */ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + val metrics = new MultilabelMetrics(scoreAndLabels) + val delta = 0.00001 + val precision0 = 4.0 / (4 + 0) + val precision1 = 2.0 / (2 + 1) + val precision2 = 2.0 / (2 + 2) + val recall0 = 4.0 / (4 + 1) + val recall1 = 2.0 / (2 + 1) + val recall2 = 2.0 / (2 + 2) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val sumTp = 4 + 2 + 2 + assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1)) + val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2) + val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2) + val microF1MeasureClass = 2.0 * sumTp.toDouble / + (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2)) + val macroPrecisionDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0) + val macroRecallDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2) + val macroF1MeasureDoc = (1.0 / 7) * + 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) + + 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) ) + val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1) + val strictAccuracy = 2.0 / 7 + val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta) + 
assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta) + assert(math.abs(metrics.microRecall - microRecallClass) < delta) + assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta) + assert(math.abs(metrics.precision - macroPrecisionDoc) < delta) + assert(math.abs(metrics.recall - macroRecallDoc) < delta) + assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta) + assert(math.abs(metrics.hammingLoss - hammingLoss) < delta) + assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta) + assert(math.abs(metrics.accuracy - accuracy) < delta) + assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala new file mode 100644 index 0000000000000..5396d7b2b74fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionMetricsSuite extends FunSuite with LocalSparkContext { + + test("regression metrics") { + val predictionAndObservations = sc.parallelize( + Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics with complete fitting") { + val predictionAndObservations = sc.parallelize( + Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..93a84fe07b32a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite { throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index b781a6aed9a8c..82c327bd49fcd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite { (ata, atb) } + /** Compute the objective value */ + def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = { + val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x)) + res.get(0) + } + test("NNLS: exact solution cases") { val n = 20 val rand = new Random(12346) @@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite { assert(x(i) >= 0) } } + + test("NNLS: objective value test") { + val n = 5 + val ata = new DoubleMatrix(5, 5 + , 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283 + , 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884 + , -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049 + , 130802.84503, 
81785.36128, -45401.12659, 67457.31310, -253747.03819 + , -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814 + ) + val atb = new DoubleMatrix(5, 1, + -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017) + + /** reference solution obtained from matlab function quadprog */ + val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627)) + val refObj = computeObjectiveValue(ata, atb, refx) + + + val ws = NNLS.createWorkspace(n) + val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val obj = computeObjectiveValue(ata, atb, x) + + assert(obj < refObj + 1E-5) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8fc5e111bbc17..c579cb58549f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -493,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(rootNode1.rightNode.nonEmpty) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) // Single group second level tree construction. val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) @@ -786,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) val topNode = Node.emptyNode(nodeIndex = 1) assert(topNode.predict.predict === Double.MinValue) @@ -829,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) val topNode = Node.emptyNode(nodeIndex = 1) assert(topNode.predict.predict === Double.MinValue) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala new file mode 100644 index 0000000000000..effb7b8259ffb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
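For reference, computeObjectiveValue in the NNLS objective-value test above evaluates the standard quadratic objective f(x) = 0.5 * x'(A'A)x - (A'b)'x. A minimal jblas sketch of the same computation on a small made-up problem (the matrix values below are illustrative only):

import org.jblas.DoubleMatrix

// Illustrative only: f(x) = 0.5 * x' * (A'A) * x - (A'b)' * x, matching computeObjectiveValue.
val ata = new DoubleMatrix(2, 2, 4.0, 1.0, 1.0, 3.0)   // symmetric positive semi-definite 2x2
val atb = new DoubleMatrix(2, 1, 1.0, 2.0)
val x   = new DoubleMatrix(2, 1, 0.1, 0.6)
val obj = x.transpose().mmul(ata).mmul(x).mul(0.5).sub(atb.dot(x)).get(0)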
+ */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.util.StatCounter + +import scala.collection.mutable + +object EnsembleTestHelper { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. sample from some distribution. + */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: WeightedEnsembleModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: WeightedEnsembleModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + + def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala new file mode 100644 index 0000000000000..970fff82215e2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.{Variance, Gini} +import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss} +import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} + +import org.apache.spark.mllib.util.LocalSparkContext + +/** + * Test suite for [[GradientBoosting]]. + */ +class GradientBoostingSuite extends FunSuite with LocalSparkContext { + + test("Regression with continuous features: SquaredError") { + + GradientBoostingSuite.testCombinations.foreach { + case (numEstimators, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, + subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + + val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numEstimators) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) + + // Make sure trees are the same. + assert(gbtTree.toString == dt.toString) + } + } + + test("Regression with continuous features: Absolute Error") { + + GradientBoostingSuite.testCombinations.foreach { + case (numEstimators, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError, + subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + + val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numEstimators) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateRegressor(gbt, arr, 0.02) + + // Make sure trees are the same. 
+ assert(gbtTree.toString == dt.toString) + } + } + + + test("Binary classification with continuous features: Log Loss") { + + GradientBoostingSuite.testCombinations.foreach { + case (numEstimators, learningRate, subsamplingRate) => + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + subsamplingRate = subsamplingRate) + + val dt = DecisionTree.train(remappedInput, treeStrategy) + + val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss, + subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy) + + val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy) + assert(gbt.weakHypotheses.size === numEstimators) + val gbtTree = gbt.weakHypotheses(0) + + EnsembleTestHelper.validateClassifier(gbt, arr, 0.9) + + // Make sure trees are the same. + assert(gbtTree.toString == dt.toString) + } + } + +} + +object GradientBoostingSuite { + + // Combinations for estimators, learning rates and subsamplingRate + val testCombinations + = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75)) + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 6b13765b98f41..73c4393c3581a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -25,100 +25,91 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.tree.model.Node import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.util.StatCounter /** * Test suite for [[RandomForest]]. 
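The gradient boosting tests above check only the first weak hypothesis against a plain DecisionTree, which suggests the first tree enters the ensemble unscaled while later trees are scaled by the learning rate. A generic sketch of that recursion for squared error, not Spark's GradientBoosting implementation:

// Schematic only.
//   F_0(x) = h_0(x)
//   F_m(x) = F_{m-1}(x) + learningRate * h_m(x), where h_m is fit to the residuals y - F_{m-1}(x)
def boostedPrediction(trees: Seq[Double => Double], learningRate: Double)(x: Double): Double =
  trees.head(x) + trees.tail.map(h => learningRate * h(x)).sum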
*/ class RandomForestSuite extends FunSuite with LocalSparkContext { - - test("BaggedPoint RDD: without subsampling") { - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) - val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) - baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) - } - } - - test("BaggedPoint RDD: with subsampling") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 1.0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("Binary classification with continuous features:" + - " comparing DecisionTree vs. RandomForest(numTrees = 1)") { - - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) - val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) - val rfTree = rf.trees(0) + assert(rf.weakHypotheses.size === 1) + val rfTree = rf.weakHypotheses(0) val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateClassifier(rf, arr, 0.9) + EnsembleTestHelper.validateClassifier(rf, arr, 0.9) DecisionTreeSuite.validateClassifier(dt, arr, 0.9) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Regression with continuous features:" + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache :" + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + binaryClassificationTestWithContinuousFeatures(strategy) + } - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + def regressionTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Regression, impurity = Variance, - maxDepth = 2, maxBins = 10, numClassesForClassification = 2, - categoricalFeaturesInfo = categoricalFeaturesInfo) - val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) - assert(rf.trees.size === 1) - val rfTree = rf.trees(0) + assert(rf.weakHypotheses.size === 1) + val rfTree = rf.weakHypotheses(0) val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateRegressor(rf, arr, 0.01) + EnsembleTestHelper.validateRegressor(rf, arr, 0.01) DecisionTreeSuite.validateRegressor(dt, arr, 0.01) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Binary classification with continuous features: subsampling features") { - val numFeatures = 50 - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) - val rdd = sc.parallelize(arr) + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) + regressionTestWithContinuousFeatures(strategy) + } - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + regressionTestWithContinuousFeatures(strategy) + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr) // Select feature subset for top nodes. Return true if OK. 
def checkFeatureSubsetStrategy( @@ -174,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + test("alternating categorical and continuous features with multiclass labels to test indexing") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) @@ -187,77 +192,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - RandomForestSuite.validateClassifier(model, arr, 1.0) + EnsembleTestHelper.validateClassifier(model, arr, 1.0) } - } -object RandomForestSuite { - - /** - * Aggregates all values in data, and tests whether the empirical mean and stddev are within - * epsilon of the expected values. - * @param data Every element of the data should be an i.i.d. sample from some distribution. 
- */ - def testRandomArrays( - data: Array[Array[Double]], - numCols: Int, - expectedMean: Double, - expectedStddev: Double, - epsilon: Double) { - val values = new mutable.ArrayBuffer[Double]() - data.foreach { row => - assert(row.size == numCols) - values ++= row - } - val stats = new StatCounter(values) - assert(math.abs(stats.mean - expectedMean) < epsilon) - assert(math.abs(stats.stdev - expectedStddev) < epsilon) - } - - def validateClassifier( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - def validateRegressor( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") - } - - def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { - val numInstances = 1000 - val arr = new Array[LabeledPoint](numInstances) - for (i <- 0 until numInstances) { - val label = if (i < numInstances / 10) { - 0.0 - } else if (i < numInstances / 2) { - 1.0 - } else if (i < numInstances * 0.9) { - 0.0 - } else { - 1.0 - } - val features = Array.fill[Double](numFeatures)(i.toDouble) - arr(i) = new LabeledPoint(label, Vectors.dense(features)) - } - arr - } - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala new file mode 100644 index 0000000000000..5cb433232e714 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.EnsembleTestHelper +import org.apache.spark.mllib.util.LocalSparkContext + +/** + * Test suite for [[BaggedPoint]]. 
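The (expectedMean, expectedStddev) pairs used in the BaggedPointSuite tests below are consistent with per-point sample counts drawn as Poisson(fraction) when sampling with replacement and Bernoulli(fraction) when sampling without replacement; a minimal sketch of where those numbers come from (the sampling distributions are an assumption inferred from the expected values, not from code shown in this excerpt):

// Illustrative only.
def poissonStats(fraction: Double): (Double, Double) =
  (fraction, math.sqrt(fraction))                       // mean and stddev of Poisson(fraction)
def bernoulliStats(fraction: Double): (Double, Double) =
  (fraction, math.sqrt(fraction * (1 - fraction)))      // mean and stddev of Bernoulli(fraction)

poissonStats(1.0)     // (1.0, 1.0)          with replacement, fraction = 1.0
poissonStats(0.5)     // (0.5, sqrt(0.5))    with replacement, fraction = 0.5
bernoulliStats(1.0)   // (1.0, 0.0)          without replacement, fraction = 1.0
bernoulliStats(0.5)   // (0.5, 0.5)          without replacement, fraction = 0.5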
+ */ +class BaggedPointSuite extends FunSuite with LocalSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } +} diff --git a/network/common/pom.xml b/network/common/pom.xml new file mode 100644 index 0000000000000..ea887148d98ba --- /dev/null +++ b/network/common/pom.xml @@ -0,0 +1,110 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-common_2.10 + jar + Spark Project Common Network Code + http://spark.apache.org/ + + network-common + + + + + + io.netty + 
netty-all + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + org.apache.maven.plugins + maven-jar-plugin + 2.2 + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java new file mode 100644 index 0000000000000..a271841e4e56c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.Channel; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportChannelHandler; +import org.apache.spark.network.server.TransportRequestHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to + * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. + * + * There are two communication protocols that the TransportClient provides, control-plane RPCs and + * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the + * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams + * which can be streamed through the data plane in chunks using zero-copy IO. + * + * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each + * channel. As each TransportChannelHandler contains a TransportClient, this enables server + * processes to send messages back to the client on an existing channel. 
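A rough sketch of how the pieces described in the TransportContext javadoc fit together, written in Scala against the Java API. It assumes TransportClientFactory exposes createClient(host, port) and TransportServer exposes getPort(), neither of which appears in this excerpt:

import org.apache.spark.network.TransportContext
import org.apache.spark.network.client.TransportClient
import org.apache.spark.network.server.RpcHandler
import org.apache.spark.network.util.TransportConf

// Illustrative only: wire up a server and a client talking to it over the loopback interface.
def connectLoopback(conf: TransportConf, rpcHandler: RpcHandler): TransportClient = {
  val context = new TransportContext(conf, rpcHandler)
  val server  = context.createServer()                 // binds to an ephemeral port
  val factory = context.createClientFactory()
  factory.createClient("localhost", server.getPort())  // assumed API, see note above
}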
+ */ +public class TransportContext { + private final Logger logger = LoggerFactory.getLogger(TransportContext.class); + + private final TransportConf conf; + private final RpcHandler rpcHandler; + + private final MessageEncoder encoder; + private final MessageDecoder decoder; + + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this.conf = conf; + this.rpcHandler = rpcHandler; + this.encoder = new MessageEncoder(); + this.decoder = new MessageDecoder(); + } + + public TransportClientFactory createClientFactory() { + return new TransportClientFactory(this); + } + + /** Create a server which will attempt to bind to a specific port. */ + public TransportServer createServer(int port) { + return new TransportServer(this, port); + } + + /** Creates a new server, binding to any available ephemeral port. */ + public TransportServer createServer() { + return new TransportServer(this, 0); + } + + /** + * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and + * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or + * response messages. + * + * @return Returns the created TransportChannelHandler, which includes a TransportClient that can + * be used to communicate on this channel. The TransportClient is directly associated with a + * ChannelHandler to ensure all users of the same channel get the same TransportClient object. + */ + public TransportChannelHandler initializePipeline(SocketChannel channel) { + try { + TransportChannelHandler channelHandler = createChannelHandler(channel); + channel.pipeline() + .addLast("encoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("decoder", decoder) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. + .addLast("handler", channelHandler); + return channelHandler; + } catch (RuntimeException e) { + logger.error("Error while initializing Netty pipeline", e); + throw e; + } + } + + /** + * Creates the server- and client-side handler which is used to handle both RequestMessages and + * ResponseMessages. The channel is expected to have been successfully created, though certain + * properties (such as the remoteAddress()) may not be available yet. + */ + private TransportChannelHandler createChannelHandler(Channel channel) { + TransportResponseHandler responseHandler = new TransportResponseHandler(channel); + TransportClient client = new TransportClient(channel, responseHandler); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, + rpcHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler); + } + + public TransportConf getConf() { return conf; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..89ed79bc63903 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.base.Objects; +import com.google.common.io.ByteStreams; +import io.netty.channel.DefaultFileRegion; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A {@link ManagedBuffer} backed by a segment in a file. + */ +public final class FileSegmentManagedBuffer extends ManagedBuffer { + + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. + */ + // TODO: Make this configurable + private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + + private final File file; + private final long offset; + private final long length; + + public FileSegmentManagedBuffer(File file, long offset, long length) { + this.file = file; + this.offset = offset; + this.length = length; + } + + @Override + public long size() { + return length; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + FileChannel channel = null; + try { + channel = new RandomAccessFile(file, "r").getChannel(); + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. 
+ if (length < MIN_MEMORY_MAP_BYTES) { + ByteBuffer buf = ByteBuffer.allocate((int) length); + channel.position(offset); + while (buf.remaining() != 0) { + if (channel.read(buf) == -1) { + throw new IOException(String.format("Reached EOF before filling buffer\n" + + "offset=%s\nfile=%s\nbuf.remaining=%s", + offset, file.getAbsoluteFile(), buf.remaining())); + } + } + buf.flip(); + return buf; + } else { + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + } + } catch (IOException e) { + try { + if (channel != null) { + long size = channel.size(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } + throw new IOException("Error in opening " + this, e); + } finally { + JavaUtils.closeQuietly(channel); + } + } + + @Override + public InputStream createInputStream() throws IOException { + FileInputStream is = null; + try { + is = new FileInputStream(file); + ByteStreams.skipFully(is, offset); + return ByteStreams.limit(is, length); + } catch (IOException e) { + try { + if (is != null) { + long size = file.length(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } finally { + JavaUtils.closeQuietly(is); + } + throw new IOException("Error in opening " + this, e); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(is); + throw e; + } + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } + + public File getFile() { return file; } + + public long getOffset() { return offset; } + + public long getLength() { return length; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("offset", offset) + .add("length", length) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java new file mode 100644 index 0000000000000..a415db593a788 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * This interface provides an immutable view for data in the form of bytes. 
The implementation + * should specify how the data is provided: + * + * - {@link FileSegmentManagedBuffer}: data backed by part of a file + * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer + * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. + */ +public abstract class ManagedBuffer { + + /** Number of bytes of the data. */ + public abstract long size(); + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + // TODO: Deprecate this, usage may require expensive memory mapping or allocation. + public abstract ByteBuffer nioByteBuffer() throws IOException; + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + public abstract InputStream createInputStream() throws IOException; + + /** + * Increment the reference count by one if applicable. + */ + public abstract ManagedBuffer retain(); + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + public abstract ManagedBuffer release(); + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + public abstract Object convertToNetty() throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java new file mode 100644 index 0000000000000..c806bfa45bef3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; + +/** + * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. 
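A small sketch of the retain/release discipline described in the ManagedBuffer javadoc above: retain a buffer before it escapes the current call (for example, when handing it to another thread) and release it once the consumer is done; for the non-reference-counted implementations both calls are no-ops:

import org.apache.spark.network.buffer.ManagedBuffer

// Illustrative only.
def handOff(buf: ManagedBuffer)(consume: ManagedBuffer => Unit): Unit = {
  buf.retain()
  try consume(buf)
  finally buf.release()
}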
+ */ +public final class NettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public NettyManagedBuffer(ByteBuf buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.nioBuffer(); + } + + @Override + public InputStream createInputStream() throws IOException { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return buf.duplicate(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000000000..f55b884bc45ce --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
+ */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream createInputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} + diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000000000..1fbdcd6780785 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + public ChunkFetchFailureException(String errorMsg, Throwable cause) { + super(errorMsg, cause); + } + + public ChunkFetchFailureException(String errorMsg) { + super(errorMsg); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000000000..519e6cb470d0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk result. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000000000..6ec960d795420 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side. */ + void onFailure(Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java new file mode 100644 index 0000000000000..01c143fff423c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.SettableFuture; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.NettyUtils; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of the transport layer. The convenience + * method "sendRPC" is provided to enable control plane communication between the client and server + * to perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of TransportClient using {@link TransportClientFactory}. A single + * TransportClient may be used for multiple streams, but any given stream must be restricted to a + * single client, in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +public class TransportClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClient.class); + + private final Channel channel; + private final TransportResponseHandler handler; + + public TransportClient(Channel channel, TransportResponseHandler handler) { + this.channel = Preconditions.checkNotNull(channel); + this.handler = Preconditions.checkNotNull(handler); + } + + public boolean isActive() { + return channel.isOpen() || channel.isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. 
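+   *
+   * A minimal illustrative sketch of a caller (the hand-off method and the streamId are
+   * hypothetical; the stream would normally be negotiated via an earlier RPC):
+   * <pre>{@code
+   *   client.fetchChunk(streamId, 0, new ChunkReceivedCallback() {
+   *     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+   *       buffer.retain();          // buffer is release()'d as soon as this call returns
+   *       handOffToWorker(buffer);  // hypothetical: process on another thread, then release()
+   *     }
+   *     public void onFailure(int chunkIndex, Throwable e) {
+   *       // surface the failure, e.g. fail a pending future for this chunk
+   *     }
+   *   });
+   * }</pre>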
+ * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single + * TransportClient is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); + channel.close(); + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.trace("Sending RPC to {}", serverAddr); + + final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + handler.addRpcRequest(requestId, callback); + + channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(requestId); + callback.onFailure(new RuntimeException(errorMsg, future.cause())); + channel.close(); + } + } + }); + } + + /** + * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to + * a specified timeout for a response. 
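+   *
+   * For example (the request bytes and the 5 second timeout are illustrative only):
+   * <pre>{@code
+   *   byte[] response = client.sendRpcSync(serializedOpenStreamRequest, 5000);
+   * }</pre>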
+ */ + public byte[] sendRpcSync(byte[] message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); + + sendRpc(message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + result.set(response); + } + + @Override + public void onFailure(Throwable e) { + result.setException(e); + } + }); + + try { + return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public void close() { + // close is a local operation and should finish with milliseconds; timeout just to be safe + channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java new file mode 100644 index 0000000000000..0b4a1d8286407 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.server.TransportChannelHandler; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Factory for creating {@link TransportClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for + * all {@link TransportClient}s. 
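+ *
+ * A rough usage sketch (the host, port, and prior construction of the TransportContext are
+ * placeholders, not prescribed values):
+ * <pre>{@code
+ *   TransportClientFactory factory = new TransportClientFactory(context);
+ *   TransportClient client = factory.createClient("shuffle-server-host", 7077);
+ *   // ... issue fetchChunk / sendRpc calls ...
+ *   factory.close();
+ * }</pre>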
+ */
+public class TransportClientFactory implements Closeable {
+  private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+  private final TransportContext context;
+  private final TransportConf conf;
+  private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+
+  private final Class<? extends Channel> socketChannelClass;
+  private EventLoopGroup workerGroup;
+
+  public TransportClientFactory(TransportContext context) {
+    this.context = context;
+    this.conf = context.getConf();
+    this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+
+    IOMode ioMode = IOMode.valueOf(conf.ioMode());
+    this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+    // TODO: Make thread pool name configurable.
+    this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+  }
+
+  /**
+   * Create a new TransportClient connecting to the given remote host / port.
+   *
+   * This blocks until a connection is successfully established.
+   *
+   * Concurrency: This method is safe to call from multiple threads.
+   */
+  public TransportClient createClient(String remoteHost, int remotePort) {
+    // Get connection from the connection pool first.
+    // If it is not found or not active, create a new one.
+    final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+    TransportClient cachedClient = connectionPool.get(address);
+    if (cachedClient != null) {
+      if (cachedClient.isActive()) {
+        return cachedClient;
+      } else {
+        connectionPool.remove(address, cachedClient); // Remove inactive clients.
+      }
+    }
+
+    logger.debug("Creating new connection to " + address);
+
+    Bootstrap bootstrap = new Bootstrap();
+    bootstrap.group(workerGroup)
+      .channel(socketChannelClass)
+      // Disable Nagle's Algorithm since we don't want packets to wait
+      .option(ChannelOption.TCP_NODELAY, true)
+      .option(ChannelOption.SO_KEEPALIVE, true)
+      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+    // Use pooled buffers to reduce temporary buffer allocation
+    bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+    final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+
+    bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+      @Override
+      public void initChannel(SocketChannel ch) {
+        TransportChannelHandler clientHandler = context.initializePipeline(ch);
+        client.set(clientHandler.getClient());
+      }
+    });
+
+    // Connect to the remote server
+    ChannelFuture cf = bootstrap.connect(address);
+    if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+      throw new RuntimeException(
+        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+    } else if (cf.cause() != null) {
+      throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+    }
+
+    // Successful connection -- in the event that two threads raced to create a client, we will
+    // use the first one that was put into the connectionPool and close the one we made here.
+    assert client.get() != null : "Channel future completed successfully with null client";
+    TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+    if (oldClient == null) {
+      return client.get();
+    } else {
+      logger.debug("Two clients were created concurrently, second one will be disposed.");
+      client.get().close();
+      return oldClient;
+    }
+  }
+
+  /** Close all connections in the connection pool, and shut down the worker thread pool.
*/ + @Override + public void close() { + for (TransportClient client : connectionPool.values()) { + try { + client.close(); + } catch (RuntimeException e) { + logger.warn("Ignoring exception during close", e); + } + } + connectionPool.clear(); + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + workerGroup = null; + } + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private PooledByteBufAllocator createPooledByteBufAllocator() { + return new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ); + } + + /** Used to get defaults from Netty's private static fields. */ + private int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java new file mode 100644 index 0000000000000..d8965590b34da --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.util.NettyUtils; + +/** + * Handler that processes server responses, in response to requests issued from a + * [[TransportClient]]. 
It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+  private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+  private final Channel channel;
+
+  private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+  private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+  public TransportResponseHandler(Channel channel) {
+    this.channel = channel;
+    this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+    this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+  }
+
+  public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+    outstandingFetches.put(streamChunkId, callback);
+  }
+
+  public void removeFetchRequest(StreamChunkId streamChunkId) {
+    outstandingFetches.remove(streamChunkId);
+  }
+
+  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+    outstandingRpcs.put(requestId, callback);
+  }
+
+  public void removeRpcRequest(long requestId) {
+    outstandingRpcs.remove(requestId);
+  }
+
+  /**
+   * Fire the failure callback for all outstanding requests. This is called when we have an
+   * uncaught exception or premature connection termination.
+   */
+  private void failOutstandingRequests(Throwable cause) {
+    for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+      entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+    }
+    for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+      entry.getValue().onFailure(cause);
+    }
+
+    // It's OK if new fetches appear, as they will fail immediately.
+    outstandingFetches.clear();
+    outstandingRpcs.clear();
+  }
+
+  @Override
+  public void channelUnregistered() {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        numOutstandingRequests(), remoteAddress);
+      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+    }
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause) {
+    if (numOutstandingRequests() > 0) {
+      String remoteAddress = NettyUtils.getRemoteAddress(channel);
+      logger.error("Still have {} requests outstanding when connection from {} is closed",
+        numOutstandingRequests(), remoteAddress);
+      failOutstandingRequests(cause);
+    }
+  }
+
+  @Override
+  public void handle(ResponseMessage message) {
+    String remoteAddress = NettyUtils.getRemoteAddress(channel);
+    if (message instanceof ChunkFetchSuccess) {
+      ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+          resp.streamChunkId, remoteAddress);
+        resp.buffer.release();
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+        resp.buffer.release();
+      }
+    } else if (message instanceof ChunkFetchFailure) {
+      ChunkFetchFailure resp = (ChunkFetchFailure) message;
+      ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+      if (listener == null) {
+        logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+          resp.streamChunkId, remoteAddress, resp.errorString);
+      } else {
+        outstandingFetches.remove(resp.streamChunkId);
+        listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException("Failure while
fetching " + resp.streamChunkId + ": " + resp.errorString)); + } + } else if (message instanceof RpcResponse) { + RpcResponse resp = (RpcResponse) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", + resp.requestId, remoteAddress, resp.response.length); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", + resp.requestId, remoteAddress, resp.errorString); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } else { + throw new IllegalStateException("Unknown response type: " + message.type()); + } + } + + /** Returns total number of outstanding requests (fetch requests + rpcs) */ + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size() + outstandingRpcs.size(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java new file mode 100644 index 0000000000000..152af98ced7ce --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
+ */ +public final class ChunkFetchFailure implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java new file mode 100644 index 0000000000000..980947cf13f6b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). 
+ */ +public final class ChunkFetchRequest implements RequestMessage { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java new file mode 100644 index 0000000000000..ff4936470c697 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000000000..b4e299471b41a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + * + * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by + * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than + * just copying data from it), then you must retain() the ByteBuf. + * + * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java new file mode 100644 index 0000000000000..d568370125fd4 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** An on-the-wire transmittable message. */ +public interface Message extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), + RpcRequest(3), RpcResponse(4), RpcFailure(5); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch (id) { + case 0: return ChunkFetchRequest; + case 1: return ChunkFetchSuccess; + case 2: return ChunkFetchFailure; + case 3: return RpcRequest; + case 4: return RpcResponse; + case 5: return RpcFailure; + default: throw new IllegalArgumentException("Unknown message type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java new file mode 100644 index 0000000000000..81f8d7f96350f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Decoder used by the client side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. 
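+ *
+ * Concretely, it reads the {@link Message.Type} byte at the head of each frame and dispatches to
+ * the matching static decode() method, turning the incoming ByteBuf back into a typed
+ * {@link Message}; {@link MessageEncoder} is the server-side counterpart.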
+ */ +@ChannelHandler.Sharable +public final class MessageDecoder extends MessageToMessageDecoder { + + private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + Message.Type msgType = Message.Type.decode(in); + Message decoded = decode(msgType, in); + assert decoded.type() == msgType; + logger.trace("Received message " + msgType + ": " + decoded); + out.add(decoded); + } + + private Message decode(Message.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java new file mode 100644 index 0000000000000..4cb8becc3ed22 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class MessageEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + + /*** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. 
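+    // The resulting frame layout is, roughly:
+    //   [ 8-byte frame length ][ 1-byte message type ][ encoded message ][ optional zero-copy body ]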
+ if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. + logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + Message.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java similarity index 77% rename from core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala rename to network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java index 0d7695072a7b1..31b15bb17a327 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -15,11 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty +package org.apache.spark.network.protocol; -import org.apache.spark.storage.{BlockId, FileSegment} +import org.apache.spark.network.protocol.Message; -trait PathResolver { - /** Get the file segment in which the given block resides. */ - def getBlockLocation(blockId: BlockId): FileSegment +/** Messages from the client to the server. */ +public interface RequestMessage extends Message { + // token interface } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java similarity index 75% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala rename to network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java index e28219dd7745b..6edffd11cf1e2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -15,15 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty.client +package org.apache.spark.network.protocol; -import java.util.EventListener - - -trait BlockClientListener extends EventListener { - - def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit - - def onFetchFailure(blockId: String, errorMsg: String): Unit +import org.apache.spark.network.protocol.Message; +/** Messages from the server to the client. 
*/ +public interface ResponseMessage extends Message { + // token interface } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java new file mode 100644 index 0000000000000..e239d4ffbd29c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a failed RPC. */ +public final class RpcFailure implements ResponseMessage { + public final long requestId; + public final String errorString; + + public RpcFailure(long requestId, String errorString) { + this.requestId = requestId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static RpcFailure decode(ByteBuf buf) { + long requestId = buf.readLong(); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return requestId == o.requestId && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java new file mode 100644 index 0000000000000..099e934ae018c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). + */ +public final class RpcRequest implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long requestId, byte[] message) { + this.requestId = requestId; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + 4 + message.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(message.length); + buf.writeBytes(message); + } + + public static RpcRequest decode(ByteBuf buf) { + long requestId = buf.readLong(); + int messageLen = buf.readInt(); + byte[] message = new byte[messageLen]; + buf.readBytes(message); + return new RpcRequest(requestId, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return requestId == o.requestId && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java new file mode 100644 index 0000000000000..ed479478325b6 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a successful RPC. 
*/ +public final class RpcResponse implements ResponseMessage { + public final long requestId; + public final byte[] response; + + public RpcResponse(long requestId, byte[] response) { + this.requestId = requestId; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + 4 + response.length; } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(response.length); + buf.writeBytes(response); + } + + public static RpcResponse decode(ByteBuf buf) { + long requestId = buf.readLong(); + int responseLen = buf.readInt(); + byte[] response = new byte[responseLen]; + buf.readBytes(response); + return new RpcResponse(requestId, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return requestId == o.requestId && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("response", response) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000000000..d46a263884807 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. 
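+* Serialized as the 8-byte streamId followed by the 4-byte chunkIndex (12 bytes in total).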
+*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java similarity index 53% rename from core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala rename to network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java index 162e9cc6828d4..b80c15106ecbd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -15,18 +15,22 @@ * limitations under the License. */ -package org.apache.spark.network.netty.server +package org.apache.spark.network.server; + +import org.apache.spark.network.protocol.Message; /** - * Header describing a block. This is used only in the server pipeline. - * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block + * Handles either request or response messages coming off of Netty. A MessageHandler instance + * is associated with a single Netty Channel (though it may have multiple clients on the same + * Channel.) */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) +public abstract class MessageHandler { + /** Handles the receipt of a single message. */ + public abstract void handle(T message); + + /** Invoked when an exception was caught on the Channel. */ + public abstract void exceptionCaught(Throwable cause); + + /** Invoked when the channel this MessageHandler is on has been unregistered. 
*/ + public abstract void channelUnregistered(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java new file mode 100644 index 0000000000000..5a3f003726fc1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -0,0 +1,38 @@ +package org.apache.spark.network.server; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ +public class NoOpRpcHandler implements RpcHandler { + private final StreamManager streamManager; + + public NoOpRpcHandler() { + streamManager = new OneForOneStreamManager(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException("Cannot handle messages"); + } + + @Override + public StreamManager getStreamManager() { return streamManager; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java new file mode 100644 index 0000000000000..731d48d4d9c6c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator, which are individually + * fetched as chunks by the client. Each registered buffer is one chunk. + */ +public class OneForOneStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); + + private final AtomicLong nextStreamId; + private final Map streams; + + /** State of a single stream. */ + private static class StreamState { + final Iterator buffers; + + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure + // that the caller only requests each chunk one at a time, in order. + int curChunk = 0; + + StreamState(Iterator buffers) { + this.buffers = buffers; + } + } + + public OneForOneStreamManager() { + // For debugging purposes, start with a random stream id to help identifying different streams. + // This does not need to be globally unique, only unique to this class. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + ManagedBuffer nextChunk = state.buffers.next(); + + if (!state.buffers.hasNext()) { + logger.trace("Removing stream id {}", streamId); + streams.remove(streamId); + } + + return nextChunk; + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. + StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + /** + * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to + * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a + * client connection is closed before the iterator is fully drained, then the remaining buffers + * will all be release()'d. + */ + public long registerStream(Iterator buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..2369dc6203944 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. + */ +public interface RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + */ + void receive(TransportClient client, byte[] message, RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + StreamManager getStreamManager(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000000000..5a9a14a180c10 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read + * by only one client connection, meaning that getChunk() for a particular stream will be called + * serially and that once the connection associated with the stream is closed, that stream will + * never be used again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. 
A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. + */ + public void connectionTerminated(long streamId) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java new file mode 100644 index 0000000000000..e491367fa4528 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. + * + * All channels created in the transport layer are bidirectional. When the Client initiates a Netty + * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server + * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server + * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the + * Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. 
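+ *
+ * <p>For illustration (modeled on the integration tests added alongside this class, not a
+ * prescribed usage), both ends of such a channel are typically wired up from a single
+ * TransportContext; {@code rpcHandler}, {@code remoteHost}, {@code messageBytes} and
+ * {@code callback} below are placeholders:
+ * <pre>{@code
+ *   TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ *   TransportContext context = new TransportContext(conf, rpcHandler);
+ *   TransportServer server = context.createServer();
+ *   TransportClientFactory factory = context.createClientFactory();
+ *   TransportClient client = factory.createClient(remoteHost, server.getPort());
+ *   client.sendRpc(messageBytes, callback);   // or client.fetchChunk(streamId, chunkIndex, cb)
+ * }</pre>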
+ */ +public class TransportChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); + + private final TransportClient client; + private final TransportResponseHandler responseHandler; + private final TransportRequestHandler requestHandler; + + public TransportChannelHandler( + TransportClient client, + TransportResponseHandler responseHandler, + TransportRequestHandler requestHandler) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + } + + public TransportClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java new file mode 100644 index 0000000000000..17fe9001b35cc --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.util.NettyUtils; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler is + * attached to a single Netty channel, and keeps track of which streams have been fetched via this + * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline set up by {@link TransportServer}. + */ +public class TransportRequestHandler extends MessageHandler<RequestMessage> { + private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); + + /** The Netty channel that this handler is associated with. */ + private final Channel channel; + + /** Client on the same channel allowing us to talk back to the requester. */ + private final TransportClient reverseClient; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** Serves the individual chunks that make up a stream. */ + private final StreamManager streamManager; + + /** Set of all stream ids that have been read on this handler, used for cleanup. */ + private final Set<Long> streamIds; + + public TransportRequestHandler( + Channel channel, + TransportClient reverseClient, + RpcHandler rpcHandler) { + this.channel = channel; + this.reverseClient = reverseClient; + this.rpcHandler = rpcHandler; + this.streamManager = rpcHandler.getStreamManager(); + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(Throwable cause) { + } + + @Override + public void channelUnregistered() { + // Inform the StreamManager that these streams will no longer be read from.
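+    // Note that streamIds accumulates every stream fetched over this channel and is never
+    // pruned, so connectionTerminated() may also be called for streams that have already been
+    // fully consumed; OneForOneStreamManager simply ignores stream ids it no longer tracks.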
+ for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + } + + @Override + public void handle(RequestMessage request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest((ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest((RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChunkFetchRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final RpcRequest req) { + try { + rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(new RpcResponse(req.requestId, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. + */ + private void respond(final Encodable result) { + final String remoteAddress = channel.remoteAddress().toString(); + channel.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); + } + } + } + ); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java new file mode 100644 index 0000000000000..70da48ca8ee79 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Server for the efficient, low-level streaming service. + */ +public class TransportServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportServer.class); + + private final TransportContext context; + private final TransportConf conf; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port = -1; + + /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + public TransportServer(TransportContext context, int portToBind) { + this.context = context; + this.conf = context.getConf(); + + init(portToBind); + } + + public int getPort() { + if (port == -1) { + throw new IllegalStateException("Server not initialized"); + } + return port; + } + + private void init(int portToBind) { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + context.initializePipeline(ch); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 
0000000000000..d944d9da1c7f8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000000000..6b208d95bbfbc --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on Linux. + * AUTO is used to select EPOLL if it's available, or NIO otherwise. 
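+ *
+ * <p>Note that only {@code NIO} and {@code EPOLL} are defined here, so any "auto" resolution has
+ * to happen before {@code IOMode.valueOf} is called. An illustrative use together with the
+ * {@code TransportConf} and {@code NettyUtils} helpers in this module (the thread-name prefix is
+ * arbitrary):
+ * <pre>{@code
+ *   IOMode ioMode = IOMode.valueOf(conf.ioMode());   // "NIO" (the default) or "EPOLL"
+ *   EventLoopGroup group = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "example-client");
+ * }</pre>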
+ */ +public enum IOMode { + NIO, EPOLL +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java new file mode 100644 index 0000000000000..40b71b0c87a47 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JavaUtils { + private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + + /** Closes the given object, ignoring IOExceptions. */ + public static void closeQuietly(Closeable closeable) { + try { + closeable.close(); + } catch (IOException e) { + logger.error("IOException should not have been thrown.", e); + } + } + + // TODO: Make this configurable, do not use Java serialization! + public static T deserialize(byte[] bytes) { + try { + ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); + Object out = is.readObject(); + is.close(); + return (T) out; + } catch (ClassNotFoundException e) { + throw new RuntimeException("Could not deserialize object", e); + } catch (IOException e) { + throw new RuntimeException("Could not deserialize object", e); + } + } + + // TODO: Make this configurable, do not use Java serialization! + public static byte[] serialize(Object object) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream os = new ObjectOutputStream(baos); + os.writeObject(object); + os.close(); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Could not serialize object", e); + } + } + + /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ + public static int nonNegativeHash(Object obj) { + if (obj == null) { return 0; } + int hash = obj.hashCode(); + return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000000000..b1872341198e0 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. + */ +public class NettyUtils { + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); + + switch (mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class getClientChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. */ + public static Class getServerChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns the remote address on the channel or "" if none exists. 
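+   * <p>Intended for log and error messages; an illustrative use (this mirrors
+   * {@code TransportChannelHandler} in this module):
+   * <pre>{@code
+   *   logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), cause);
+   * }</pre>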
*/ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); + } + return ""; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java new file mode 100644 index 0000000000000..5f20b70678d1e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +import org.apache.spark.network.util.ConfigProvider; + +/** Uses System properties to obtain config values. */ +public class SystemPropertyConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java new file mode 100644 index 0000000000000..a68f38e0e94c9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * A central location that tracks all the settings we expose to users. + */ +public class TransportConf { + private final ConfigProvider conf; + + public TransportConf(ConfigProvider conf) { + this.conf = conf; + } + + /** IO mode: nio or epoll */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + + /** Connect timeout in secs. Default 120 secs. 
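+   * <p>Backed by {@code spark.shuffle.io.connectionTimeout} (in seconds) and returned in
+   * milliseconds. An illustrative sketch using the {@code SystemPropertyConfigProvider} from
+   * this package (the value 60 is arbitrary):
+   * <pre>{@code
+   *   System.setProperty("spark.shuffle.io.connectionTimeout", "60");
+   *   TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+   *   int timeoutMs = conf.connectionTimeoutMs();   // 60000
+   * }</pre>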
*/ + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; + } + + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + + /** Send buffer size (SO_SNDBUF). */ + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java new file mode 100644 index 0000000000000..c4158833976aa --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ChunkFetchIntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static TransportServer server; + static TransportClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + }; + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set successChunks; + public Set failedChunks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) 
{ + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List chunkIndices) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java 
b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000000000..43dc0cf8c7194 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(Message msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(Message msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void requests() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + testClientToServer(new RpcRequest(12345, new byte[0])); + testClientToServer(new RpcRequest(12345, new byte[100])); + } + + @Test + public void responses() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + testServerToClient(new 
RpcResponse(12345, new byte[0])); + testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcFailure(0, "this is an error")); + testServerToClient(new RpcFailure(0, "")); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java new file mode 100644 index 0000000000000..64b457b4b3f01 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Charsets; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class RpcIntegrationSuite { + static TransportServer server; + static TransportClientFactory clientFactory; + static RpcHandler rpcHandler; + + @BeforeClass + public static void setUp() throws Exception { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + rpcHandler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + String msg = new String(message, Charsets.UTF_8); + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); + } + } + + @Override + public StreamManager getStreamManager() { return new OneForOneStreamManager(); } + }; + TransportContext context = new TransportContext(conf, rpcHandler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + } + + class 
RpcResult { + public Set successMessages; + public Set errorMessages; + } + + private RpcResult sendRPC(String ... commands) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(byte[] message) { + res.successMessages.add(new String(message, Charsets.UTF_8)); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + }; + + for (String command : commands) { + client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + } + + if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void singleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void doubleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void returnErrorRPC() throws Exception { + RpcResult res = sendRPC("return error/OK"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK")); + } + + @Test + public void throwErrorRPC() throws Exception { + RpcResult res = sendRPC("throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh")); + } + + @Test + public void doubleTrouble() throws Exception { + RpcResult res = sendRPC("return error/OK", "throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh")); + } + + @Test + public void sendSuccessAndFailure() throws Exception { + RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); + } + + private void assertErrorsContain(Set errors, Set contains) { + assertEquals(contains.size(), errors.size()); + + Set remainingErrors = Sets.newHashSet(errors); + for (String contain : contains) { + Iterator it = remainingErrors.iterator(); + boolean foundMatch = false; + while (it.hasNext()) { + if (it.next().contains(contain)) { + it.remove(); + foundMatch = true; + break; + } + } + assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + } + + assertTrue(remainingErrors.isEmpty()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java new file mode 100644 index 0000000000000..38113a918f795 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. + */ +public class TestManagedBuffer extends ManagedBuffer { + + private final int len; + private NettyManagedBuffer underlying; + + public TestManagedBuffer(int len) { + Preconditions.checkArgument(len <= Byte.MAX_VALUE); + this.len = len; + byte[] byteArray = new byte[len]; + for (int i = 0; i < len; i ++) { + byteArray[i] = (byte) i; + } + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + } + + + @Override + public long size() { + return underlying.size(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return underlying.nioByteBuffer(); + } + + @Override + public InputStream createInputStream() throws IOException { + return underlying.createInputStream(); + } + + @Override + public ManagedBuffer retain() { + underlying.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + underlying.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return underlying.convertToNetty(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ManagedBuffer) { + try { + ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); + if (nioBuf.remaining() != len) { + return false; + } else { + for (int i = 0; i < len; i ++) { + if (nioBuf.get() != i) { + return false; + } + } + return true; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java new file mode 100644 index 0000000000000..56a2b805f154c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.net.InetAddress; + +public class TestUtils { + public static String getLocalHost() { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java new file mode 100644 index 0000000000000..5a10fdb3842ef --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class TransportClientFactorySuite { + private TransportConf conf; + private TransportContext context; + private TransportServer server1; + private TransportServer server2; + + @Before + public void setUp() { + conf = new TransportConf(new SystemPropertyConfigProvider()); + RpcHandler rpcHandler = new NoOpRpcHandler(); + context = new TransportContext(conf, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws TimeoutException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws Exception { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws TimeoutException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java new file mode 100644 index 0000000000000..17a03ebe88a93 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.local.LocalChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; + +public class TransportResponseHandlerSuite { + @Test + public void handleSuccessfulFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void clearAllOutstandingRequests() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleSuccessfulRPC() { + 
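+    // Responses are matched to callbacks by RPC request id: a response carrying an unknown id
+    // (54321 below) is ignored, while the one for request 12345 completes the registered callback.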
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + byte[] arr = new byte[10]; + handler.handle(new RpcResponse(12345, arr)); + verify(callback, times(1)).onSuccess(eq(arr)); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedRPC() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(12345, "oh no")); + verify(callback, times(1)).onFailure((Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml new file mode 100644 index 0000000000000..d271704d98a7a --- /dev/null +++ b/network/shuffle/pom.xml @@ -0,0 +1,96 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.2.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-shuffle_2.10 + jar + Spark Project Shuffle Streaming Service Code + http://spark.apache.org/ + + network-shuffle + + + + + + org.apache.spark + spark-network-common_2.10 + ${project.version} + + + org.slf4j + slf4j-api + + + + + com.google.guava + guava + provided + + + + + org.apache.spark + spark-network-common_2.10 + ${project.version} + test-jar + test + + + junit + junit + test + + + com.novocode + junit-interface + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java new file mode 100644 index 0000000000000..138fd5389c20a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.util.EventListener; + +import org.apache.spark.network.buffer.ManagedBuffer; + +public interface BlockFetchingListener extends EventListener { + /** + * Called once per successfully fetched block. After this call returns, data will be released + * automatically. If the data will be passed to another thread, the receiver should retain() + * and release() the buffer on their own, or copy the data to a new buffer. + */ + void onBlockFetchSuccess(String blockId, ManagedBuffer data); + + /** + * Called at least once per block upon failures. + */ + void onBlockFetchFailure(String blockId, Throwable exception); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java new file mode 100644 index 0000000000000..d45e64656a0e3 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** Contains all configuration necessary for locating the shuffle files of an executor. */ +public class ExecutorShuffleInfo implements Serializable { + /** The base set of local directories that the executor stores its shuffle files in. */ + final String[] localDirs; + /** Number of subdirectories created within each localDir. */ + final int subDirsPerLocalDir; + /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
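The value is the fully qualified class name, e.g. "org.apache.spark.shuffle.sort.SortShuffleManager", which ExternalShuffleBlockManager matches on when deciding how to resolve a block.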
*/ + final String shuffleManager; + + public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + this.localDirs = localDirs; + this.subDirsPerLocalDir = subDirsPerLocalDir; + this.shuffleManager = shuffleManager; + } + + @Override + public int hashCode() { + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("localDirs", Arrays.toString(localDirs)) + .add("subDirsPerLocalDir", subDirsPerLocalDir) + .add("shuffleManager", shuffleManager) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ExecutorShuffleInfo) { + ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; + return Arrays.equals(localDirs, o.localDirs) + && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java new file mode 100644 index 0000000000000..a9dff31decc83 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; + +/** + * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. + * + * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered + * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- + * level shuffle block. 
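+ * Two message types are understood: OpenShuffleBlocks, which registers the requested blocks as a
+ * stream and replies with a ShuffleStreamHandle, and RegisterExecutor, which records an executor's
+ * ExecutorShuffleInfo so its shuffle files can be located later.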
+ */ +public class ExternalShuffleBlockHandler implements RpcHandler { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); + + private final ExternalShuffleBlockManager blockManager; + private final OneForOneStreamManager streamManager; + + public ExternalShuffleBlockHandler() { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager()); + } + + /** Enables mocking out the StreamManager and BlockManager. */ + @VisibleForTesting + ExternalShuffleBlockHandler( + OneForOneStreamManager streamManager, + ExternalShuffleBlockManager blockManager) { + this.streamManager = streamManager; + this.blockManager = blockManager; + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + Object msgObj = JavaUtils.deserialize(message); + + logger.trace("Received message: " + msgObj); + + if (msgObj instanceof OpenShuffleBlocks) { + OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + List blocks = Lists.newArrayList(); + + for (String blockId : msg.blockIds) { + blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); + } + long streamId = streamManager.registerStream(blocks.iterator()); + logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); + callback.onSuccess(JavaUtils.serialize( + new ShuffleStreamHandle(streamId, msg.blockIds.length))); + + } else if (msgObj instanceof RegisterExecutor) { + RegisterExecutor msg = (RegisterExecutor) msgObj; + blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); + callback.onSuccess(new byte[0]); + + } else { + throw new UnsupportedOperationException(String.format( + "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + } + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + + /** For testing, clears all executors registered with "RegisterExecutor". */ + @VisibleForTesting + public void clearRegisteredExecutors() { + blockManager.clearRegisteredExecutors(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java new file mode 100644 index 0000000000000..6589889fe1be7 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
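An ExternalShuffleBlockHandler does not listen for connections on its own; it has to be wired into a TransportServer through a TransportContext. The bootstrap below is a minimal editorial sketch mirroring ExternalShuffleIntegrationSuite.beforeAll() later in this patch; the class name is invented and the service would normally run alongside, not inside, the executors.

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ExternalShuffleServerSketch {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    // The handler serves both RegisterExecutor and OpenShuffleBlocks messages.
    ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler();
    TransportServer server = new TransportContext(conf, handler).createServer();
    System.out.println("External shuffle service listening on port " + server.getPort());
  }
}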
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.JavaUtils; + +/** + * Manages converting shuffle BlockIds into physical segments of local files, from a process outside + * of Executors. Each Executor must register its own configuration about where it stores its files + * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated + * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager. + * + * Executors with shuffle file consolidation are not currently supported, as the index is stored in + * the Executor's memory, unlike the IndexShuffleBlockManager. + */ +public class ExternalShuffleBlockManager { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); + + // Map from "appId-execId" to the executor's configuration. + private final ConcurrentHashMap executors = + new ConcurrentHashMap(); + + // Returns an id suitable for a single executor within a single application. + private String getAppExecId(String appId, String execId) { + return appId + "-" + execId; + } + + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + public void registerExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + String fullId = getAppExecId(appId, execId); + logger.info("Registered executor {} with {}", fullId, executorInfo); + executors.put(fullId, executorInfo); + } + + /** + * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the + * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make + * assumptions about how the hash and sort based shuffles store their data. + */ + public ManagedBuffer getBlockData(String appId, String execId, String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length < 4) { + throw new IllegalArgumentException("Unexpected block id format: " + blockId); + } else if (!blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); + } + int shuffleId = Integer.parseInt(blockIdParts[1]); + int mapId = Integer.parseInt(blockIdParts[2]); + int reduceId = Integer.parseInt(blockIdParts[3]); + + ExecutorShuffleInfo executor = executors.get(getAppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + + if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle manager: " + executor.shuffleManager); + } + } + + /** + * Hash-based shuffle data is simply stored as one file per block. + * This logic is from FileShuffleBlockManager. 
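+ * The file on disk is named by the block id itself (e.g. "shuffle_1_0_0") and is located by
+ * hashing that name into one of the executor's local directories via getFile() below.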
+ */ + // TODO: Support consolidated hash shuffle files + private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { + File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + return new FileSegmentManagedBuffer(shuffleFile, 0, shuffleFile.length()); + } + + /** + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file + * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockManager, + * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + */ + private ManagedBuffer getSortBasedShuffleBlockData( + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + + DataInputStream in = null; + try { + in = new DataInputStream(new FileInputStream(indexFile)); + in.skipBytes(reduceId * 8); + long offset = in.readLong(); + long nextOffset = in.readLong(); + return new FileSegmentManagedBuffer( + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + offset, + nextOffset - offset); + } catch (IOException e) { + throw new RuntimeException("Failed to open file: " + indexFile, e); + } finally { + if (in != null) { + JavaUtils.closeQuietly(in); + } + } + } + + /** + * Hashes a filename into the corresponding local directory, in a manner consistent with + * Spark's DiskBlockManager.getFile(). + */ + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + int hash = JavaUtils.nonNegativeHash(filename); + String localDir = localDirs[hash % localDirs.length]; + int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; + return new File(new File(localDir, String.format("%02x", subDirId)), filename); + } + + /** For testing, clears all registered executors. */ + @VisibleForTesting + void clearRegisteredExecutors() { + executors.clear(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java new file mode 100644 index 0000000000000..6bbabc44b958b --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
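For reference, the ".index" file consumed by getSortBasedShuffleBlockData above is a flat array of 8-byte longs: entry r is the byte offset at which reduce partition r begins in the companion ".data" file, and entry r + 1 marks where it ends. A minimal editorial sketch of that lookup, with an invented helper name:

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

class SortShuffleIndexSketch {
  /** Returns {offset, length} of the given reduce partition within the matching .data file. */
  static long[] segmentFor(File indexFile, int reduceId) throws IOException {
    DataInputStream in = new DataInputStream(new FileInputStream(indexFile));
    try {
      in.skipBytes(reduceId * 8);       // each index entry is one 8-byte long
      long offset = in.readLong();      // where partition `reduceId` starts
      long nextOffset = in.readLong();  // where the next partition starts
      return new long[] { offset, nextOffset - offset };
    } finally {
      in.close();
    }
  }
}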
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.Closeable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Client for reading shuffle blocks which points to an external (outside of executor) server. + * This is instead of reading shuffle blocks directly from other executors (via + * BlockTransferService), which has the downside of losing the shuffle data if we lose the + * executors. + */ +public class ExternalShuffleClient implements ShuffleClient { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); + + private final TransportClientFactory clientFactory; + private final String appId; + + public ExternalShuffleClient(TransportConf conf, String appId) { + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + this.clientFactory = context.createClientFactory(); + this.appId = appId; + } + + @Override + public void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener) { + logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); + try { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, blockIds, listener) + .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + } catch (Exception e) { + logger.error("Exception while beginning fetchBlocks", e); + for (String blockId : blockIds) { + listener.onBlockFetchFailure(blockId, e); + } + } + } + + /** + * Registers this executor with an external shuffle server. This registration is required to + * inform the shuffle server about where and how we store our shuffle files. + * + * @param host Host of shuffle server. + * @param port Port of shuffle server. + * @param execId This Executor's id. + * @param executorInfo Contains all info necessary for the service to find our shuffle files. + */ + public void registerWithShuffleServer( + String host, + int port, + String execId, + ExecutorShuffleInfo executorInfo) { + TransportClient client = clientFactory.createClient(host, port); + byte[] registerExecutorMessage = + JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); + client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + } + + @Override + public void close() { + clientFactory.close(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java new file mode 100644 index 0000000000000..e79420ed8254f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
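Putting the ExternalShuffleClient API above together, the expected lifecycle is: create one client per application, register each executor's shuffle layout once, then fetch blocks asynchronously through a BlockFetchingListener. The snippet below is an editorial sketch only; the host, port, directories, and block id are placeholders.

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.ExternalShuffleClient;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class ExternalShuffleClientSketch {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    ExternalShuffleClient client = new ExternalShuffleClient(conf, "app-id");

    // Tell the shuffle service where this executor keeps its files (values are illustrative).
    ExecutorShuffleInfo info = new ExecutorShuffleInfo(
        new String[] { "/tmp/spark-local" }, 64,
        "org.apache.spark.shuffle.sort.SortShuffleManager");
    client.registerWithShuffleServer("shuffle-host", 7337, "exec-0", info);

    // Fetch asynchronously; the listener is invoked once per block.
    client.fetchBlocks("shuffle-host", 7337, "exec-0", new String[] { "shuffle_0_0_0" },
        new BlockFetchingListener() {
          @Override
          public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
            System.out.println("Fetched " + blockId + " (" + data.size() + " bytes)");
          }

          @Override
          public void onBlockFetchFailure(String blockId, Throwable exception) {
            System.err.println("Failed to fetch " + blockId + ": " + exception);
          }
        });
  }
}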
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ +public class ExternalShuffleMessages { + + /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ + public static class OpenShuffleBlocks implements Serializable { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenShuffleBlocks) { + OpenShuffleBlocks o = (OpenShuffleBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + } + + /** Initial registration message between an executor and its local shuffle server. */ + public static class RegisterExecutor implements Serializable { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java new file mode 100644 index 0000000000000..39b6f30f92baf --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
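Both message classes travel as plain Java-serialized byte arrays, which is also how the handler decodes them. A quick editorial round-trip sketch with illustrative values:

import org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks;
import org.apache.spark.network.util.JavaUtils;

class MessageRoundTripSketch {
  public static void main(String[] args) {
    // Encode an open-blocks request the way ExternalShuffleClient does...
    byte[] wire = JavaUtils.serialize(
        new OpenShuffleBlocks("app-id", "exec-0", new String[] { "shuffle_0_0_0" }));
    // ...and decode it the way ExternalShuffleBlockHandler.receive() does.
    OpenShuffleBlocks decoded = (OpenShuffleBlocks) JavaUtils.deserialize(wire);
    System.out.println(decoded);
  }
}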
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.util.JavaUtils; + +/** + * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and + * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC + * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * and Java serialization is used. + * + * Note that this typically corresponds to a + * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. + */ +public class OneForOneBlockFetcher { + private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + + private final TransportClient client; + private final String[] blockIds; + private final BlockFetchingListener listener; + private final ChunkReceivedCallback chunkCallback; + + private ShuffleStreamHandle streamHandle = null; + + public OneForOneBlockFetcher( + TransportClient client, + String[] blockIds, + BlockFetchingListener listener) { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + this.client = client; + this.blockIds = blockIds; + this.listener = listener; + this.chunkCallback = new ChunkCallback(); + } + + /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ + private class ChunkCallback implements ChunkReceivedCallback { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, e); + } + } + + /** + * Begins the fetching process, calling the listener with every block fetched. + * The given message will be serialized with the Java serializer, and the RPC must return a + * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + */ + public void start(Object openBlocksMessage) { + client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + try { + streamHandle = JavaUtils.deserialize(response); + logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. 
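+          // Chunk index i corresponds to blockIds[i]; ChunkCallback relies on this ordering when
+          // reporting success or failure for individual blocks.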
+ for (int i = 0; i < streamHandle.numChunks; i++) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } + } catch (Exception e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + }); + } + + /** Invokes the "onBlockFetchFailure" callback for every listed block id. */ + private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { + for (String blockId : failedBlockIds) { + try { + listener.onBlockFetchFailure(blockId, e); + } catch (Exception e2) { + logger.error("Error in block fetch failure callback", e2); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java new file mode 100644 index 0000000000000..d46a562394557 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Closeable; + +/** Provides an interface for reading shuffle files, either from an Executor or external service. */ +public interface ShuffleClient extends Closeable { + /** + * Fetch a sequence of blocks from a remote node asynchronously, + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. + */ + public void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java new file mode 100644 index 0000000000000..9c94691224328 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Serializable; +import java.util.Arrays; + +import com.google.common.base.Objects; + +/** + * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" + * message. This is used by {@link OneForOneBlockFetcher}. + */ +public class ShuffleStreamHandle implements Serializable { + public final long streamId; + public final int numChunks; + + public ShuffleStreamHandle(long streamId, int numChunks) { + this.streamId = streamId; + this.numChunks = numChunks; + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, numChunks); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("numChunks", numChunks) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ShuffleStreamHandle) { + ShuffleStreamHandle o = (ShuffleStreamHandle) other; + return Objects.equal(streamId, o.streamId) + && Objects.equal(numChunks, o.numChunks); + } + return false; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java new file mode 100644 index 0000000000000..7939cb4d32690 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.util.JavaUtils; + +public class ExternalShuffleBlockHandlerSuite { + TransportClient client = mock(TransportClient.class); + + OneForOneStreamManager streamManager; + ExternalShuffleBlockManager blockManager; + RpcHandler handler; + + @Before + public void beforeEach() { + streamManager = mock(OneForOneStreamManager.class); + blockManager = mock(ExternalShuffleBlockManager.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockManager); + } + + @Test + public void testRegisterExecutor() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); + byte[] registerMessage = JavaUtils.serialize( + new RegisterExecutor("app0", "exec1", config)); + handler.receive(client, registerMessage, callback); + verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); + + verify(callback, times(1)).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } + + @SuppressWarnings("unchecked") + @Test + public void testOpenShuffleBlocks() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + byte[] openBlocksMessage = JavaUtils.serialize( + new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" })); + handler.receive(client, openBlocksMessage, callback); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); + + ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure((Throwable) any()); + + ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue()); + assertEquals(2, handle.numChunks); + + ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(stream.capture()); + Iterator buffers = (Iterator) stream.getValue(); + assertEquals(block0Marker, buffers.next()); + assertEquals(block1Marker, buffers.next()); + assertFalse(buffers.hasNext()); + } + + @Test + public void testBadMessages() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 }; + try { + handler.receive(client, unserializableMessage, 
callback); + fail("Should have thrown"); + } catch (Exception e) { + // pass + } + + byte[] unexpectedMessage = JavaUtils.serialize( + new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort")); + try { + handler.receive(client, unexpectedMessage, callback); + fail("Should have thrown"); + } catch (UnsupportedOperationException e) { + // pass + } + + verify(callback, never()).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java new file mode 100644 index 0000000000000..da54797e8923c --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.google.common.io.CharStreams; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockManagerSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. 
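+    // The sort data becomes blocks shuffle_0_0_0 and shuffle_0_0_1, and the hash data becomes
+    // shuffle_1_0_0 and shuffle_1_0_1, matching the block ids requested by the tests below.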
+ dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + // Unregistered executor + try { + manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + manager.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + manager.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + try { + manager.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java new file mode 100644 index 0000000000000..b3bcf5fd68e73 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleIntegrationSuite { + + static String APP_ID = "app-id"; + static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + + // Executor 0 is sort-based + static TestShuffleDataContext dataContext0; + // Executor 1 is hash-based + static TestShuffleDataContext dataContext1; + + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + + static byte[][] exec0Blocks = new byte[][] { + new byte[123], + new byte[12345], + new byte[1234567], + }; + + static byte[][] exec1Blocks = new byte[][] { + new byte[321], + new byte[54321], + }; + + @BeforeClass + public static void beforeAll() throws IOException { + Random rand = new Random(); + + for (byte[] block : exec0Blocks) { + rand.nextBytes(block); + } + for (byte[] block: exec1Blocks) { + rand.nextBytes(block); + } + + dataContext0 = new TestShuffleDataContext(2, 5); + dataContext0.create(); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + + dataContext1 = new TestShuffleDataContext(6, 2); + dataContext1.create(); + dataContext1.insertHashShuffleData(1, 0, exec1Blocks); + + conf = new TransportConf(new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(); + TransportContext transportContext = new TransportContext(conf, handler); + server = transportContext.createServer(); + } + + @AfterClass + public static void afterAll() { + dataContext0.cleanup(); + dataContext1.cleanup(); + server.close(); + } + + @After + public void afterEach() { + handler.clearRegisteredExecutors(); + } + + class FetchResult { + public Set successBlocks; + public Set failedBlocks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + // Fetch a set of blocks from a pre-registered executor. 
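+  // This overload targets the test server's own port; the three-argument variant below also lets
+  // tests point at an invalid port (see testFetchNoServer).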
+ private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, server.getPort()); + } + + // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, + // to allow connecting to invalid servers. + private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + final FetchResult res = new FetchResult(); + res.successBlocks = Collections.synchronizedSet(new HashSet()); + res.failedBlocks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + final Semaphore requestsRemaining = new Semaphore(0); + + ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } + } + } + }); + + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + return res; + } + + @Test + public void testFetchOneSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchHash() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); + execFetch.releaseBuffers(); + } + + @Test + public void testFetchWrongShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + @Test + public void 
testFetchInvalidShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongBlockId() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "rdd_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNonexistent() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_2_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); + // Both still fail, as we start by checking for all blocks. + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchUnregisteredExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-2", + new String[] { "shuffle_0_0_0", "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNoServer() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) { + ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID); + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), + executorId, executorInfo); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i)))); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java new file mode 100644 index 0000000000000..c18346f6966d6 --- /dev/null +++
b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.util.JavaUtils; + +public class OneForOneBlockFetcherSuite { + @Test + public void testFetchOne() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); + } + + @Test + public void testFetchThree() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); + } + } + + @Test + public void testFailure() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", null); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // Each failure will cause a failure to be invoked in all remaining block fetches. 
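+    // b1's failure also fails the remaining block b2; b2's own null buffer then fails it a
+    // second time, hence times(2) for b2 below.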
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testFailureAndSuccess() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // We may call both success and failure for the same block. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testEmptyBlockFetch() { + try { + fetchBlocks(Maps.newLinkedHashMap()); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Zero-sized blockIds array", e.getMessage()); + } + } + + /** + * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which + * simply returns the given (BlockId, Block) pairs. + * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order + * that they were inserted in. + * + * If a block's buffer is "null", an exception will be thrown instead. + */ + private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { + TransportClient client = mock(TransportClient.class); + BlockFetchingListener listener = mock(BlockFetchingListener.class); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener); + + // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123 + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size()))); + assertEquals("OpenZeBlocks", message); + return null; + } + }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any()); + + // Respond to each chunk request with a single buffer from our blocks array.
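+    // Chunks are served in the order the blocks were inserted; a null buffer is reported back
+    // to the callback as a chunk fetch failure.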
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0); + final Iterator blockIterator = blocks.values().iterator(); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); + } + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); + } + return null; + } + }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + + fetcher.start("OpenZeBlocks"); + return listener; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java new file mode 100644 index 0000000000000..ee9482b49cfc3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.util.JavaUtils; + +import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; + +public class ShuffleMessagesSuite { + @Test + public void serializeOpenShuffleBlocks() { + OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", + new String[] { "block0", "block1" }); + OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } + + @Test + public void serializeRegisterExecutor() { + RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); + RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } + + @Test + public void serializeShuffleStreamHandle() { + ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); + ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + assertEquals(msg, msg2); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java new file mode 100644 index 0000000000000..442b756467442 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import com.google.common.io.Files; + +/** + * Manages some sort- and hash-based shuffle data, including the creation + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. + */ +public class TestShuffleDataContext { + private final String[] localDirs; + private final int subDirsPerLocalDir; + + public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { + this.localDirs = new String[numLocalDirs]; + this.subDirsPerLocalDir = subDirsPerLocalDir; + } + + public void create() { + for (int i = 0; i < localDirs.length; i ++) { + localDirs[i] = Files.createTempDir().getAbsolutePath(); + + for (int p = 0; p < subDirsPerLocalDir; p ++) { + new File(localDirs[i], String.format("%02x", p)).mkdirs(); + } + } + } + + public void cleanup() { + for (String localDir : localDirs) { + deleteRecursively(new File(localDir)); + } + } + + /** Creates reducer blocks in a sort-based data format within our local dirs. 
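+ * Sort-based output is written as a single .data file for the map task plus an .index file of block offsets.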
*/ + public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + + OutputStream dataStream = new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; + indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + + dataStream.close(); + indexStream.close(); + } + + /** Creates reducer blocks in a hash-based data format within our local dirs. */ + public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + for (int i = 0; i < blocks.length; i ++) { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; + Files.write(blocks[i], + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId)); + } + } + + /** + * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this + * context's directories. + */ + public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } + + private static void deleteRecursively(File f) { + assert f != null; + if (f.isDirectory()) { + File[] children = f.listFiles(); + if (children != null) { + for (File child : children) { + deleteRecursively(child); + } + } + } + f.delete(); + } +} diff --git a/pom.xml b/pom.xml index 030bea948b5ce..eb613531b8a5f 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,8 @@ graphx mllib tools + network/common + network/shuffle streaming sql/catalyst sql/core @@ -128,11 +130,11 @@ 1.4.0 3.4.5 - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 - 1.4.3 + 1.6.0rc3 1.2.3 8.1.14.v20131031 0.3.6 @@ -143,6 +145,7 @@ 1.8.3 1.1.0 4.2.6 + 3.1.1 64m 512m @@ -239,6 +242,18 @@ false + + + spark-staging-hive13 + Spring Staging Repository Hive 13 + https://oss.sonatype.org/content/repositories/orgspark-project-1089/ + + true + + + false + + @@ -304,14 +319,19 @@ org.apache.commons commons-math3 - 3.3 - test + ${commons.math3.version} com.google.code.findbugs jsr305 1.3.9 + + org.seleniumhq.selenium + selenium-java + 2.42.2 + test + org.slf4j slf4j-api @@ -346,7 +366,7 @@ org.xerial.snappy snappy-java - 1.1.1.5 + 1.1.1.3 net.jpountz.lz4 @@ -425,11 +445,6 @@ akka-testkit_${scala.binary.version} ${akka.version} - - colt - colt - 1.2.0 - org.apache.mesos mesos @@ -445,7 +460,7 @@ org.roaringbitmap RoaringBitmap - 0.4.1 + 0.4.5 commons-net @@ -520,7 +535,7 @@ org.scalatest scalatest_${scala.binary.version} - 2.1.5 + 2.2.1 test @@ -907,9 +922,9 @@ by Spark SQL for code generation. 
--> - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} @@ -1161,6 +1176,10 @@ + + hadoop-0.23 @@ -1190,6 +1209,7 @@ 2.3.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1200,6 +1220,7 @@ 2.4.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1313,14 +1334,19 @@ - hive-0.12.0 + hive false - sql/hive-thriftserver + + + hive-0.12.0 + + false + 0.12.0-protobuf-2.5 0.12.0 @@ -1333,7 +1359,7 @@ false - 0.13.1 + 0.13.1a 0.13.1 10.10.1.1 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c58666af84f24..6a0495f8fd540 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -51,9 +51,16 @@ object MimaExcludes { // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), + // TaskContext was promoted to Abstract class ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext") + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") ) ++ Seq( // Adding new methods to the JavaRDDLike trait: ProblemFilters.exclude[MissingMethodProblem]( @@ -66,6 +73,10 @@ object MimaExcludes { "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.collectAsync") + ) ++ Seq( + // SPARK-3822 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") ) case v if v.startsWith("1.1") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ea04473854007..33618f5401768 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,6 @@ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -32,11 +31,12 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, - streamingTwitter, streamingZeromq) = + sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", - "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", + "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") @@ -111,7 +111,7 @@ object 
SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq ( + lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( javaHome := Properties.envOrNone("JAVA_HOME").map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, @@ -143,7 +143,9 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach { + x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) diff --git a/project/plugins.sbt b/project/plugins.sbt index 9d50a50b109af..ee45b6a51905e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/python/docs/make.bat b/python/docs/make.bat index c011e82b4a35a..cc29acdc19686 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,6 +1,6 @@ -@ECHO OFF - -rem This is the entry point for running Sphinx documentation. To avoid polluting the -rem environment, it just launches a new cmd to do the real work. - -cmd /V /E /C %~dp0make2.bat %* +@ECHO OFF + +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat index 7bcaeafad13d7..05d22eb5cdd23 100644 --- a/python/docs/make2.bat +++ b/python/docs/make2.bat @@ -1,243 +1,243 @@ -@ECHO OFF - -REM Command file for Sphinx documentation - - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. 
pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. 
- echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. 
The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 5f8dcedb1eea2..a0e4821728c8b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -63,7 +63,6 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, @@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer - elif batchSize == 0: + if batchSize == 0: self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, @@ -305,12 +302,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +321,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -405,7 +397,7 @@ def _dictToJavaMap(self, d): return jm def 
sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -427,17 +419,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -458,18 +448,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -487,18 +475,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. 
(default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -519,18 +505,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. @@ -548,15 +532,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, :param conf: Hadoop configuration, passed in as a dict (None by default) :param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + Java object. 
(default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -836,7 +818,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index dbb34775d9ac5..f09587f211708 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -62,8 +62,7 @@ def worker(sock): exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - if exit_code: - os._exit(exit_code) + return exit_code # Cleanup zombie children @@ -160,10 +159,13 @@ def handle_sigterm(*args): outfile.flush() outfile.close() while True: - worker(sock) - if not reuse: + code = worker(sock) + if not reuse or code: # wait for closing - while sock.recv(1024): + try: + while sock.recv(1024): + pass + except Exception: pass break gc.collect() diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e295c9d0954d9..297a2bf37d2cf 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -20,8 +20,8 @@ import numpy from numpy import array -from pyspark import SparkContext, PickleSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -102,14 +102,11 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, training data (i.e. whether bias features are activated or not). """ - sc = data.context + def train(rdd, i): + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step, + miniBatchFraction, i, regParam, regType, intercept) - def train(jdata, i): - return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept) - - return _regression_train_wrapper(sc, train, LogisticRegressionModel, data, - initialWeights) + return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) class SVMModel(LinearModel): @@ -174,13 +171,11 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, training data (i.e. whether bias features are activated or not). 
""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) + def train(rdd, i): + return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i, regType, intercept) - return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights) + return _regression_train_wrapper(train, SVMModel, data, initialWeights) class NaiveBayesModel(object): @@ -243,14 +238,13 @@ def train(cls, data, lambda_=1.0): (e.g. a count vector). :param lambda_: The smoothing parameter """ - sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_) - labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) + labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) def _test(): import doctest + from pyspark import SparkContext globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 5ee7997104d21..fe4c4cc5094d8 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -16,8 +16,8 @@ # from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd +from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -80,14 +80,11 @@ class KMeans(object): @classmethod def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - sc = rdd.context - ser = PickleSerializer() # cache serialized data to avoid objects over head in JVM - cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() - model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) - bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) - centers = ser.loads(str(bytes)) + jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True) + model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs, + initializationMode) + centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py new file mode 100644 index 0000000000000..dbe5f698b7345 --- /dev/null +++ b/python/pyspark/mllib/common.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j.protocol +from py4j.protocol import Py4JJavaError +from py4j.java_gateway import JavaObject +from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList + +from pyspark import RDD, SparkContext +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + + +# Hack for support float('inf') in Py4j +_old_smart_decode = py4j.protocol.smart_decode + +_float_str_mapping = { + 'nan': 'NaN', + 'inf': 'Infinity', + '-inf': '-Infinity', +} + + +def _new_smart_decode(obj): + if isinstance(obj, float): + s = unicode(obj) + return _float_str_mapping.get(s, s) + return _old_smart_decode(obj) + +py4j.protocol.smart_decode = _new_smart_decode + + +_picklable_classes = [ + 'LinkedList', + 'SparseVector', + 'DenseVector', + 'DenseMatrix', + 'Rating', + 'LabeledPoint', +] + + +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd, cache=False): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + if cache: + rdd.cache() + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + +def _py2java(sc, obj): + """ Convert Python object into Java """ + if isinstance(obj, RDD): + obj = _to_java_object_rdd(obj) + elif isinstance(obj, SparkContext): + obj = obj._jsc + elif isinstance(obj, dict): + obj = MapConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, (list, tuple)): + obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, JavaObject): + pass + elif isinstance(obj, (int, long, float, bool, basestring)): + pass + else: + bytes = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(bytes) + return obj + + +def _java2py(sc, r): + if isinstance(r, JavaObject): + clsName = r.getClass().getSimpleName() + # convert RDD into JavaRDD + if clsName != 'JavaRDD' and clsName.endswith("RDD"): + r = r.toJavaRDD() + clsName = 'JavaRDD' + + if clsName == 'JavaRDD': + jrdd = sc._jvm.SerDe.javaToPython(r) + return RDD(jrdd, sc) + + elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes: + r = sc._jvm.SerDe.dumps(r) + + if isinstance(r, bytearray): + r = PickleSerializer().loads(str(r)) + return r + + +def callJavaFunc(sc, func, *args): + """ Call Java Function """ + args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*args)) + + +def callMLlibFunc(name, *args): + """ Call API in PythonMLLibAPI """ + sc = SparkContext._active_spark_context + api = getattr(sc._jvm.PythonMLLibAPI(), name) + return callJavaFunc(sc, api, *args) + + +class JavaModelWrapper(object): + """ + Wrapper for the model in JVM + """ + def __init__(self, java_model): + self._sc = SparkContext._active_spark_context + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def call(self, name, *a): + """Call method of java_model""" + return callJavaFunc(self._sc, getattr(self._java_model, name), *a) diff --git 
a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index b5a3f22c6907e..44bf6f269d7a3 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -18,59 +18,276 @@ """ Python package for feature in MLlib. """ -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd +import sys +import warnings -__all__ = ['Word2Vec', 'Word2VecModel'] +from py4j.protocol import Py4JJavaError +from pyspark import RDD, SparkContext +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import Vectors -class Word2VecModel(object): +__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', + 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] + + +class VectorTransformer(object): """ - class for Word2Vec model + :: DeveloperApi :: + + Base class for transformation of a vector or RDD of vector """ - def __init__(self, sc, java_model): + def transform(self, vector): """ - :param sc: Spark context - :param java_model: Handle to Java model object + Applies transformation on a vector. + + :param vector: vector to be transformed. """ - self._sc = sc - self._java_model = java_model + raise NotImplementedError - def __del__(self): - self._sc._gateway.detach(self._java_model) - def transform(self, word): +class Normalizer(VectorTransformer): + """ + :: Experimental :: + + Normalizes samples individually to unit L\ :sup:`p`\ norm + + For any 1 <= `p` <= float('inf'), normalizes samples using + sum(abs(vector). :sup:`p`) :sup:`(1/p)` as norm. + + For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + + >>> v = Vectors.dense(range(3)) + >>> nor = Normalizer(1) + >>> nor.transform(v) + DenseVector([0.0, 0.3333, 0.6667]) + + >>> rdd = sc.parallelize([v]) + >>> nor.transform(rdd).collect() + [DenseVector([0.0, 0.3333, 0.6667])] + + >>> nor2 = Normalizer(float("inf")) + >>> nor2.transform(v) + DenseVector([0.0, 0.5, 1.0]) + """ + def __init__(self, p=2.0): """ - :param word: a word - :return: vector representation of word + :param p: Normalization in L^p^ space, p = 2 by default. + """ + assert p >= 1.0, "p should be greater than 1.0" + self.p = float(p) + + def transform(self, vector): + """ + Applies unit length normalization on a vector. + + :param vector: vector to be normalized. + :return: normalized vector. If the norm of the input is zero, it + will return the input vector. + """ + sc = SparkContext._active_spark_context + assert sc is not None, "SparkContext should be initialized first" + return callMLlibFunc("normalizeVector", self.p, vector) + + +class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): + """ + Wrapper for the model in JVM + """ + + def transform(self, dataset): + return self.call("transform", dataset) + + +class StandardScalerModel(JavaVectorTransformer): + """ + :: Experimental :: + + Represents a StandardScaler model that can transform vectors. + """ + def transform(self, vector): + """ + Applies standardization transformation on a vector. + + :param vector: Vector to be standardized. + :return: Standardized vector. If the variance of a column is zero, + it will return default `0.0` for the column with zero variance. 
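A minimal usage sketch of the zero-variance behaviour described just above (not part of the patch; it assumes a local SparkContext, and the app name and toy rows are illustrative):

```python
from pyspark import SparkContext
from pyspark.mllib.feature import StandardScaler
from pyspark.mllib.linalg import Vectors

sc = SparkContext("local[2]", "scaler-sketch")
# the third column is constant, so its variance is zero
rows = sc.parallelize([Vectors.dense([-2.0, 2.3, 1.0]),
                       Vectors.dense([3.8, 0.0, 1.0])])
model = StandardScaler(withMean=True, withStd=True).fit(rows)
scaled = model.transform(rows).collect()
# per the docstring above, the zero-variance column should come back as 0.0
constant_column = [v.toArray()[2] for v in scaled]
sc.stop()
```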
+ """ + return JavaVectorTransformer.transform(self, vector) + + +class StandardScaler(object): + """ + :: Experimental :: + + Standardizes features by removing the mean and scaling to unit + variance using column summary statistics on the samples in the + training set. + + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] + >>> dataset = sc.parallelize(vs) + >>> standardizer = StandardScaler(True, True) + >>> model = standardizer.fit(dataset) + >>> result = model.transform(dataset) + >>> for r in result.collect(): r + DenseVector([-0.7071, 0.7071, -0.7071]) + DenseVector([0.7071, -0.7071, 0.7071]) + """ + def __init__(self, withMean=False, withStd=True): + """ + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an exception. + :param withStd: True by default. Scales the data to unit standard + deviation. + """ + if not (withMean or withStd): + warnings.warn("Both withMean and withStd are false. The model does nothing.") + self.withMean = withMean + self.withStd = withStd + + def fit(self, dataset): + """ + Computes the mean and variance and stores as a model to be used for later scaling. + + :param data: The data used to compute the mean and variance to build + the transformation model. + :return: a StandardScalarModel + """ + jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) + return StandardScalerModel(jmodel) + + +class HashingTF(object): + """ + :: Experimental :: + + Maps a sequence of terms to their term frequencies using the hashing trick. + Note: the terms must be hashable (can not be dict/set/list...). + + >>> htf = HashingTF(100) + >>> doc = "a a b b c d".split(" ") + >>> htf.transform(doc) + SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + """ + def __init__(self, numFeatures=1 << 20): + """ + :param numFeatures: number of features (default: 2^20) + """ + self.numFeatures = numFeatures + + def indexOf(self, term): + """ Returns the index of the input term. """ + return hash(term) % self.numFeatures + + def transform(self, document): + """ + Transforms the input document (list of terms) to term frequency vectors, + or transform the RDD of document to RDD of term frequency vectors. + """ + if isinstance(document, RDD): + return document.map(self.transform) + + freq = {} + for term in document: + i = self.indexOf(term) + freq[i] = freq.get(i, 0) + 1.0 + return Vectors.sparse(self.numFeatures, freq.items()) + + +class IDFModel(JavaVectorTransformer): + """ + Represents an IDF model that can transform term frequency vectors. + """ + def transform(self, dataset): + """ + Transforms term frequency (TF) vectors to TF-IDF vectors. + + If `minDocFreq` was set for the IDF calculation, + the terms which occur in fewer than `minDocFreq` + documents will have an entry of 0. + + :param dataset: an RDD of term frequency vectors + :return: an RDD of TF-IDF vectors + """ + return JavaVectorTransformer.transform(self, dataset) + + +class IDF(object): + """ + :: Experimental :: + + Inverse document frequency (IDF). + + The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, + where `m` is the total number of documents and `d(t)` is the number + of documents that contain term `t`. + + This implementation supports filtering out terms which do not appear + in a minimum number of documents (controlled by the variable `minDocFreq`). 
+ For terms that are not in at least `minDocFreq` documents, the IDF is + found as 0, resulting in TF-IDFs of 0. + + >>> n = 4 + >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), + ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), + ... Vectors.sparse(n, [1], [1.0])] + >>> data = sc.parallelize(freqs) + >>> idf = IDF() + >>> model = idf.fit(data) + >>> tfidf = model.transform(data) + >>> for r in tfidf.collect(): r + SparseVector(4, {1: 0.0, 3: 0.5754}) + DenseVector([0.0, 0.0, 1.3863, 0.863]) + SparseVector(4, {1: 0.0}) + """ + def __init__(self, minDocFreq=0): + """ + :param minDocFreq: minimum of documents in which a term + should appear for filtering + """ + self.minDocFreq = minDocFreq + + def fit(self, dataset): + """ + Computes the inverse document frequency. + + :param dataset: an RDD of term frequency vectors + """ + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset) + return IDFModel(jmodel) + + +class Word2VecModel(JavaVectorTransformer): + """ + class for Word2Vec model + """ + def transform(self, word): + """ Transforms a word to its vector representation Note: local use only + + :param word: a word + :return: vector representation of word(s) """ - # TODO: make transform usable in RDD operations from python side - result = self._java_model.transform(word) - return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result))) + try: + return self.call("transform", word) + except Py4JJavaError: + raise ValueError("%s not found" % word) - def findSynonyms(self, x, num): + def findSynonyms(self, word, num): """ - :param x: a word or a vector representation of word + Find synonyms of a word + + :param word: a word or a vector representation of word :param num: number of synonyms to find :return: array of (word, cosineSimilarity) - Find synonyms of a word - Note: local use only """ - # TODO: make findSynonyms usable in RDD operations from python side - ser = PickleSerializer() - if type(x) == str: - jlist = self._java_model.findSynonyms(x, num) - else: - bytes = bytearray(ser.dumps(_convert_to_vector(x))) - vec = self._sc._jvm.SerDe.loads(bytes) - jlist = self._java_model.findSynonyms(vec, num) - words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist))) + words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) @@ -85,6 +302,7 @@ class Word2Vec(object): We used skip-gram model in our implementation and hierarchical softmax method to train the model. The variable names in the implementation matches the original C implementation. 
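One behavioural detail of the Word2VecModel change above worth noting: an out-of-vocabulary lookup in transform() now surfaces as a ValueError instead of a raw Py4JJavaError. A hedged sketch, assuming a local SparkContext; the tiny corpus and vector size are illustrative:

```python
from pyspark import SparkContext
from pyspark.mllib.feature import Word2Vec

sc = SparkContext("local[2]", "word2vec-sketch")
corpus = sc.parallelize([["a", "b", "c"] * 20] * 10)
model = Word2Vec().setVectorSize(5).setSeed(42L).fit(corpus)

vec = model.transform("a")              # vector for an in-vocabulary word
syms = model.findSynonyms("a", 1)       # list of (word, cosineSimilarity) pairs
try:
    model.transform("zzz")              # not in the vocabulary
except ValueError:
    pass                                # rewrapped from the JVM's Py4JJavaError
sc.stop()
```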
+ For original C implementation, see https://code.google.com/p/word2vec/ For research papers, see Efficient Estimation of Word Representations in Vector Space @@ -95,33 +313,26 @@ class Word2Vec(object): >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> syms = model.findSynonyms("a", 2) - >>> str(syms[0][0]) - 'b' - >>> str(syms[1][0]) - 'c' - >>> len(syms) - 2 + >>> [s[0] for s in syms] + [u'b', u'c'] >>> vec = model.transform("a") - >>> len(vec) - 10 >>> syms = model.findSynonyms(vec, 2) - >>> str(syms[0][0]) - 'b' - >>> str(syms[1][0]) - 'c' - >>> len(syms) - 2 + >>> [s[0] for s in syms] + [u'b', u'c'] """ def __init__(self): """ Construct Word2Vec instance """ + import random # this can't be on the top because of mllib.random + self.vectorSize = 100 self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = 42L + self.seed = random.randint(0, sys.maxint) def setVectorSize(self, vectorSize): """ @@ -164,20 +375,12 @@ def fit(self, data): Computes the vector representation of each word in vocabulary. :param data: training data. RDD of subtype of Iterable[String] - :return: python Word2VecModel instance + :return: Word2VecModel instance """ - sc = data.context - ser = PickleSerializer() - vectorSize = self.vectorSize - learningRate = self.learningRate - numPartitions = self.numPartitions - numIterations = self.numIterations - seed = self.seed - - model = sc._jvm.PythonMLLibAPI().trainWord2Vec( - _to_java_object_rdd(data), vectorSize, - learningRate, numPartitions, numIterations, seed) - return Word2VecModel(sc, model) + jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + float(self.learningRate), int(self.numPartitions), + int(self.numIterations), long(self.seed)) + return Word2VecModel(jmodel) def _test(): @@ -191,4 +394,8 @@ def _test(): exit(-1) if __name__ == "__main__": + # remove current path from list of search paths to avoid importing mllib.random + # for C{import random}, which is done in an external dependency of pyspark during doctests. + import sys + sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 773d8d393805d..c0c3dff31e7f8 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,9 @@ import numpy as np -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -52,17 +54,6 @@ def fast_pickle_array(ar): _have_scipy = False -# this will call the MLlib version of pythonToJava() -def _to_java_object_rdd(rdd): - """ Return an JavaRDD of Object by unpickling - - It will convert each Python object into Java object by Pyrolite, whenever the - RDD is serialized in batch or not. - """ - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) - - def _convert_to_vector(l): if isinstance(l, Vector): return l @@ -111,7 +102,61 @@ def _vector_size(v): raise TypeError("Cannot treat type %s as a vector" % type(v)) +def _format_float(f, digits=4): + s = str(round(f, digits)) + if '.' in s: + s = s[:s.index('.') + 1 + digits] + return s + + +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. 
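The UDT's serialize/deserialize contract can be checked locally without a SparkContext. A sketch of the round trip, with the tuple layouts taken from the serialize() implementation in this patch (the same check appears in the new VectorUDTTests further down):

```python
from pyspark.mllib.linalg import DenseVector, SparseVector, VectorUDT

udt = VectorUDT()
dv = DenseVector([1.0, 2.0])
sv = SparseVector(2, [1], [2.0])

# dense vectors serialize as (1, None, None, values),
# sparse vectors as (0, size, indices, values)
assert udt.deserialize(udt.serialize(dv)) == dv
assert udt.deserialize(udt.serialize(sv)) == sv
```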
+ """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ @@ -228,7 +273,7 @@ def __str__(self): return "[" + ",".join([str(v) for v in self.array]) + "]" def __repr__(self): - return "DenseVector(%r)" % self.array + return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): return isinstance(other, DenseVector) and self.array == other.array @@ -416,7 +461,7 @@ def toArray(self): Returns a copy of this SparseVector as a 1-dimensional NumPy array. """ arr = np.zeros((self.size,), dtype=np.float64) - for i in xrange(self.indices.size): + for i in xrange(len(self.indices)): arr[self.indices[i]] = self.values[i] return arr @@ -431,7 +476,8 @@ def __str__(self): def __repr__(self): inds = self.indices vals = self.values - entries = ", ".join(["{0}: {1}".format(inds[i], vals[i]) for i in xrange(len(inds))]) + entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) + for i in xrange(len(inds))]) return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): @@ -491,7 +537,7 @@ def dense(elements): returns a NumPy array. >>> Vectors.dense([1, 2, 3]) - DenseVector(array('d', [1.0, 2.0, 3.0])) + DenseVector([1.0, 2.0, 3.0]) """ return DenseVector(elements) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 73baba4ace5f6..7eebfc6bcd894 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,22 +21,12 @@ from functools import wraps -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.mllib.common import callMLlibFunc __all__ = ['RandomRDDs', ] -def serialize(f): - @wraps(f) - def func(sc, *a, **kw): - jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc, - BatchedSerializer(PickleSerializer(), 1024)) - return func - - def toArray(f): @wraps(f) def func(sc, *a, **kw): @@ -52,7 +42,6 @@ class RandomRDDs(object): """ @staticmethod - @serialize def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. 
samples from the @@ -74,10 +63,9 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): >>> parts == sc.defaultParallelism True """ - return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) + return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @serialize def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -97,17 +85,16 @@ def normalRDD(sc, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - 1.0) < 0.1 True """ - return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) + return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod - @serialize def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) >>> stats = x.stats() >>> stats.count() 1000L @@ -117,11 +104,10 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): >>> abs(stats.stdev() - sqrt(mean)) < 0.5 True """ - return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) + return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed) @staticmethod @toArray - @serialize def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -136,12 +122,10 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ - return sc._jvm.PythonMLLibAPI() \ - .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @staticmethod @toArray - @serialize def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -156,12 +140,10 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - 1.0) < 0.1 True """ - return sc._jvm.PythonMLLibAPI() \ - .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) + return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @staticmethod @toArray - @serialize def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. 
samples drawn @@ -179,8 +161,8 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> abs(mat.std() - sqrt(mean)) < 0.5 True """ - return sc._jvm.PythonMLLibAPI() \ - .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) + return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols, + numPartitions, seed) def _test(): diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 22872dbbe3b55..e8b998414d319 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,9 +16,8 @@ # from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD -from pyspark.mllib.linalg import _to_java_object_rdd +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd __all__ = ['MatrixFactorizationModel', 'ALS'] @@ -36,7 +35,7 @@ def __repr__(self): return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) -class MatrixFactorizationModel(object): +class MatrixFactorizationModel(JavaModelWrapper): """A matrix factorisation model trained by regularized alternating least-squares. @@ -71,48 +70,21 @@ class MatrixFactorizationModel(object): >>> len(latents) == 4 True """ - - def __init__(self, sc, java_model): - self._context = sc - self._java_model = java_model - - def __del__(self): - self._context._gateway.detach(self._java_model) - def predict(self, user, product): return self._java_model.predict(user, product) def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() - if isinstance(first, list): - user_product = user_product.map(tuple) - first = tuple(first) - assert type(first) is tuple and len(first) == 2, \ - "user_product should be RDD of (user, product)" - if any(isinstance(x, str) for x in first): - user_product = user_product.map(lambda (u, p): (int(x), int(p))) - first = tuple(map(int, first)) - assert all(type(x) is int for x in first), "user and product in user_product shoul be int" - sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd()) - jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, - AutoBatchedSerializer(PickleSerializer())) + assert len(first) == 2, "user_product should be RDD of (user, product)" + user_product = user_product.map(lambda (u, p): (int(u), int(p))) + return self.call("predict", user_product) def userFeatures(self): - sc = self._context - juf = self._java_model.userFeatures() - juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc, - AutoBatchedSerializer(PickleSerializer())) + return self.call("getUserFeatures") def productFeatures(self): - sc = self._context - jpf = self._java_model.productFeatures() - jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc, - AutoBatchedSerializer(PickleSerializer())) + return self.call("getProductFeatures") class ALS(object): @@ -126,32 +98,26 @@ def _prepare(cls, ratings): ratings = ratings.map(lambda x: Rating(*x)) else: raise ValueError("rating should be RDD of Rating or tuple/list") - # serialize them by AutoBatchedSerializer before cache to reduce the - # objects overhead in JVM - cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return 
_to_java_object_rdd(cached) + return _to_java_object_rdd(ratings, True) @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks) - return MatrixFactorizationModel(sc, mod) + model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, + lambda_, blocks) + return MatrixFactorizationModel(model) @classmethod def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel( - jrating, rank, iterations, lambda_, blocks, alpha) - return MatrixFactorizationModel(sc, mod) + model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, + iterations, lambda_, blocks, alpha) + return MatrixFactorizationModel(model) def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 93e17faf5cd51..43c1a2fc101dd 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,9 +18,8 @@ import numpy as np from numpy import array -from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd +from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -124,17 +123,11 @@ class LinearRegressionModel(LinearRegressionModelBase): # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights): +def _regression_train_wrapper(train_func, modelClass, data, initial_weights): initial_weights = initial_weights or [0.0] * len(data.first().features) - ser = PickleSerializer() - initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights))) - # use AutoBatchedSerializer before cache to reduce the memory - # overhead in JVM - cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(_to_java_object_rdd(cached), initial_bytes) - assert len(ans) == 2, "JVM call result had unexpected length" - weights = ser.loads(str(ans[0])) - return modelClass(weights, ans[1]) + weights, intercept = train_func(_to_java_object_rdd(data, cache=True), + _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): @@ -168,13 +161,12 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, training data (i.e. whether bias features are activated or not). 
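From user code the refactored entry point is called exactly as before. A sketch, assuming a local SparkContext; the toy points, iteration count, and initial weights are illustrative:

```python
import numpy as np
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

sc = SparkContext("local[2]", "regression-sketch")
points = sc.parallelize([
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(3.0, [2.0]),
    LabeledPoint(2.0, [3.0]),
])
# the RDD is pickled and cached on the JVM side by _regression_train_wrapper
model = LinearRegressionWithSGD.train(points, iterations=10,
                                      initialWeights=np.array([1.0]))
pred = model.predict(np.array([2.5]))   # weights.dot(x) + intercept
sc.stop()
```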
""" - sc = data.context + def train(rdd, i): + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step, + miniBatchFraction, i, regParam, regType, intercept) - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) - - return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, + data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -216,12 +208,10 @@ class LassoWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) - return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights) + def train(rdd, i): + return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) class RidgeRegressionModel(LinearRegressionModelBase): @@ -263,17 +253,17 @@ class RidgeRegressionWithSGD(object): def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) + def train(rdd, i): + return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam, + miniBatchFraction, i) - return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, + data, initialWeights) def _test(): import doctest + from pyspark import SparkContext globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 84baf12b906df..15f0652f833d7 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,66 +19,36 @@ Python package for statistical functions in MLlib. """ -from functools import wraps - -from pyspark import PickleSerializer -from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector __all__ = ['MultivariateStatisticalSummary', 'Statistics'] -def serialize(f): - ser = PickleSerializer() - - @wraps(f) - def func(self): - jvec = f(self) - bytes = self._sc._jvm.SerDe.dumps(jvec) - return ser.loads(str(bytes)).toArray() - - return func - - -class MultivariateStatisticalSummary(object): +class MultivariateStatisticalSummary(JavaModelWrapper): """ Trait for multivariate statistical summary of a data matrix. 
""" - def __init__(self, sc, java_summary): - """ - :param sc: Spark context - :param java_summary: Handle to Java summary object - """ - self._sc = sc - self._java_summary = java_summary - - def __del__(self): - self._sc._gateway.detach(self._java_summary) - - @serialize def mean(self): - return self._java_summary.mean() + return self.call("mean").toArray() - @serialize def variance(self): - return self._java_summary.variance() + return self.call("variance").toArray() def count(self): - return self._java_summary.count() + return self.call("count") - @serialize def numNonzeros(self): - return self._java_summary.numNonzeros() + return self.call("numNonzeros").toArray() - @serialize def max(self): - return self._java_summary.max() + return self.call("max").toArray() - @serialize def min(self): - return self._java_summary.min() + return self.call("min").toArray() class Statistics(object): @@ -106,10 +76,8 @@ def colStats(rdd): >>> cStats.min() array([ 2., 0., 0., -2.]) """ - sc = rdd.ctx - jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector)) - cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) - return MultivariateStatisticalSummary(sc, cStats) + cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector)) + return MultivariateStatisticalSummary(cStats) @staticmethod def corr(x, y=None, method=None): @@ -156,7 +124,6 @@ def corr(x, y=None, method=None): ... except TypeError: ... pass """ - sc = x.ctx # Check inputs to determine whether a single value or a matrix is needed for output. # Since it's legal for users to use the method name as the second argument, we need to # check if y is used to specify the method name instead. @@ -164,15 +131,9 @@ def corr(x, y=None, method=None): raise TypeError("Use 'method=' to specify method name.") if not y: - jx = _to_java_object_rdd(x.map(_convert_to_vector)) - resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) - bytes = sc._jvm.SerDe.dumps(resultMat) - ser = PickleSerializer() - return ser.loads(str(bytes)).toArray() + return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray() else: - jx = _to_java_object_rdd(x.map(float)) - jy = _to_java_object_rdd(y.map(float)) - return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) + return callMLlibFunc("corr", x.map(float), y.map(float), method) def _test(): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378b4a..9fa4d6f6a2f5f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ def test_col_with_different_rdds(self): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, 
self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 64ee79d83e849..5d1a3c0962796 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,36 +15,22 @@ # limitations under the License. # -from py4j.java_collections import MapConverter - from pyspark import SparkContext, RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] -class DecisionTreeModel(object): +class DecisionTreeModel(JavaModelWrapper): """ A decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. + It will probably be modified in future. """ - - def __init__(self, sc, java_model): - """ - :param sc: Spark context - :param java_model: Handle to Java model object - """ - self._sc = sc - self._java_model = java_model - - def __del__(self): - self._sc._gateway.detach(self._java_model) - def predict(self, x): """ Predict the label of one or more examples. @@ -52,24 +38,11 @@ def predict(self, x): :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ - SerDe = self._sc._jvm.SerDe - ser = PickleSerializer() if isinstance(x, RDD): - # Bulk prediction - first = x.take(1) - if not first: - return self._sc.parallelize([]) - if not isinstance(first[0], Vector): - x = x.map(_convert_to_vector) - jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() - jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) - return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) + return self.call("predict", x.map(_convert_to_vector)) else: - # Assume x is a single data point. 
- bytes = bytearray(ser.dumps(_convert_to_vector(x))) - vec = self._sc._jvm.SerDe.loads(bytes) - return self._java_model.predict(vec) + return self.call("predict", _convert_to_vector(x)) def numNodes(self): return self._java_model.numNodes() @@ -98,19 +71,13 @@ class DecisionTree(object): """ @staticmethod - def _train(data, type, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, - minInfoGain=0.0): + def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, + minInstancesPerNode=1, minInfoGain=0.0): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" - sc = data.context - jrdd = _to_java_object_rdd(data) - cfiMap = MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - jrdd, type, numClasses, cfiMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - return DecisionTreeModel(sc, model) + model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return DecisionTreeModel(model) @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 84b39a48619d2..96aef8f510fa6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,8 +18,7 @@ import numpy as np import warnings -from pyspark.rdd import RDD -from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -173,9 +172,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): (0.0,[1.01,2.02,3.03]) """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) - jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.SerDe.javaToPython(jrdd) - return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer())) + return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 126ee77616dca..053b342b0329e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -316,9 +312,6 @@ def sample(self, withReplacement, fraction, seed=None): """ Return a sampled subset of this RDD (relies on numpy and falls back on default random generator if numpy is unavailable). 
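Since the deterministic doctest is removed here (the sampler's per-split seed handling changes in rddsampler.py below), a non-deterministic usage sketch of sample() for reference; the exact elements returned depend on the seed and on whether numpy is installed:

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "sample-sketch")
rdd = sc.parallelize(range(100), 4)
subset = rdd.sample(False, 0.1, seed=17)
# the fraction is an expected size, not an exact one: roughly 10 of the 100
# elements come back, and the exact set varies with the seed and backend
n = subset.count()
sc.stop()
```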
- - >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -449,12 +442,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -1123,9 +1115,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1150,9 +1141,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1169,9 +1159,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1198,9 +1187,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1218,9 +1206,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): :param path: path to sequence file :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - 
self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1235,8 +1222,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1807,13 +1797,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1823,12 +1810,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1897,11 +1888,11 @@ def setName(self, name): Assign a name to this RDD. >>> rdd1 = sc.parallelize([1,2]) - >>> rdd1.setName('RDD1') - >>> rdd1.name() + >>> rdd1.setName('RDD1').name() 'RDD1' """ self._jrdd.setName(name) + return self def toDebugString(self): """ @@ -1967,25 +1958,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
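The zip() rewrite above re-batches both sides to the smaller batch size (falling back to 100 when either side is auto- or unlimited-batched) before pairing, so RDDs with different serializers can be zipped directly. A sketch based on the doctest that is kept, assuming a local SparkContext:

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "zip-sketch")
x = sc.parallelize(range(0, 5))
y = sc.parallelize(range(1000, 1005)).map(lambda v: v)
# x and the pipelined y may carry different batch sizes; zip() now
# rebatches both sides to a common size instead of failing
pairs = x.zip(y).collect()   # [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
sc.stop()
```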
""" - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2165,7 +2145,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 528a181e8905a..f5c3cfd259a5b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -40,14 +40,13 @@ def __init__(self, withReplacement, seed=None): def initRandomGenerator(self, split): if self._use_numpy: import numpy - self._random = numpy.random.RandomState(self._seed) + self._random = numpy.random.RandomState(self._seed ^ split) else: - self._random = random.Random(self._seed) + self._random = random.Random(self._seed ^ split) - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, 2 ** 32 - 1) + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) self._split = split self._rand_initialized = True diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 30187f378edfe..8bf8c86a0ed92 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,7 +62,7 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): @@ -113,7 +102,7 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ def __hash__(self): return hash(str(self)) @@ -181,6 +170,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -213,10 +203,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer 
== self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) def __repr__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class FlattedValuesSerializer(BatchedSerializer): @@ -245,7 +235,7 @@ class AutoBatchedSerializer(BatchedSerializer): """ def __init__(self, serializer, bestSize=1 << 16): - BatchedSerializer.__init__(self, serializer, -1) + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -268,10 +258,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def __str__(self): - return "AutoBatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -304,7 +294,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "CartesianDeserializer<%s, %s>" % \ + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -331,7 +321,7 @@ def __eq__(self, other): self.key_ser == other.key_ser and self.val_ser == other.val_ser) def __repr__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -450,7 +440,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 6d52bc5ad807f..4210d730dbb71 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -16,7 +16,6 @@ # import os -import sys import platform import shutil import warnings @@ -27,7 +26,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer, \ - CompressedSerializer + CompressedSerializer, AutoBatchedSerializer try: @@ -162,6 +161,12 @@ def iteritems(self): return self.data.iteritems() +def _compressed_serializer(self, serializer=None): + # always use PickleSerializer to simplify implementation + ser = PickleSerializer() + return AutoBatchedSerializer(CompressedSerializer(ser)) + + class ExternalMerger(Merger): """ @@ -222,7 +227,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit - self.serializer = self._compressed_serializer(serializer) + self.serializer = _compressed_serializer(serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -239,18 +244,6 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, # randomize the hash of key, id(o) is the address of o (aligned by 8) self._seed = id(self) + 7 - def _compressed_serializer(self, serializer=None): - # default serializer is only used for tests - ser = serializer or PickleSerializer() - # add compression - if isinstance(ser, BatchedSerializer): - if not isinstance(ser.serializer, 
CompressedSerializer): - ser = BatchedSerializer(CompressedSerializer(ser.serializer), ser.batchSize) - else: - if not isinstance(ser, CompressedSerializer): - ser = BatchedSerializer(CompressedSerializer(ser), 1024) - return ser - def _get_spill_dir(self, n): """ Choose one directory for spill by number n """ return os.path.join(self.localdirs[n % len(self.localdirs)], str(n)) @@ -460,8 +453,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer( - CompressedSerializer(PickleSerializer()), 1024) + self.serializer = _compressed_serializer(serializer) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -696,12 +688,9 @@ class ExternalGroupBy(ExternalMerger): SORT_KEY_LIMIT = 1000 def _flatted_serializer(self): + assert isinstance(self.serializer, BatchedSerializer) ser = self.serializer - if not isinstance(ser, (BatchedSerializer, FlattedValuesSerializer)): - ser = BatchedSerializer(ser, 1024) - if not isinstance(ser, FlattedValuesSerializer): - ser = FlattedValuesSerializer(ser, 20) - return ser + return FlattedValuesSerializer(ser, 20) def _object_size(self, obj): return len(obj) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 7daf306f68479..e5d62a466cab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -35,6 +35,7 @@ import keyword import warnings import json +import re from array import array from operator import itemgetter from itertools import imap @@ -43,13 +44,14 @@ from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync __all__ = [ - "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", "SQLContext", "HiveContext", "SchemaRDD", "Row"] @@ -108,6 +110,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. + """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -132,6 +143,14 @@ class BooleanType(PrimitiveType): """ +class DateType(PrimitiveType): + + """Spark SQL DateType + + The data type representing datetime.date values. + """ + + class TimestampType(PrimitiveType): """Spark SQL TimestampType @@ -140,13 +159,30 @@ class TimestampType(PrimitiveType): """ -class DecimalType(PrimitiveType): +class DecimalType(DataType): """Spark SQL DecimalType The data type representing decimal.Decimal values. 
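A quick illustration of the parameterised DecimalType introduced here (the precision/scale handling and its JSON form are defined a few lines further down in this hunk); no SparkContext is needed, and the precision/scale values are illustrative:

```python
from pyspark.sql import DecimalType

unbounded = DecimalType()
fixed = DecimalType(10, 2)

unbounded.jsonValue()   # 'decimal'
fixed.jsonValue()       # 'decimal(10,2)'
repr(fixed)             # 'DecimalType(10,2)'
```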
""" + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + class DoubleType(PrimitiveType): @@ -305,12 +341,15 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. + :param metadata: metadata of this field, which is a map from string + to simple type that can be serialized to JSON + automatically >>> (StructField("f1", StringType, True) ... == StructField("f1", StringType, True)) @@ -322,6 +361,7 @@ def __init__(self, name, dataType, nullable): self.name = name self.dataType = dataType self.nullable = nullable + self.metadata = metadata or {} def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, @@ -330,13 +370,15 @@ def __repr__(self): def jsonValue(self): return {"name": self.name, "type": self.dataType.jsonValue(), - "nullable": self.nullable} + "nullable": self.nullable, + "metadata": self.metadata} @classmethod def fromJson(cls, json): return StructField(json["name"], _parse_datatype_json_value(json["type"]), - json["nullable"]) + json["nullable"], + json["metadata"]) class StructType(DataType): @@ -376,6 +418,75 @@ def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. 
+ """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) + + _all_primitive_types = dict((v.typeName(), v) for v in globals().itervalues() if type(v) is PrimitiveTypeSingleton and @@ -415,7 +526,8 @@ def _parse_datatype_json_string(json_string): ... StructField("simpleArray", simple_arraytype, True), ... StructField("simpleMap", simple_maptype, True), ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False)]) + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) True >>> # Complex ArrayType. @@ -427,19 +539,43 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ return _parse_datatype_json_value(json.loads(json_string)) +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") + + def _parse_datatype_json_value(json_value): - if type(json_value) is unicode and json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) else: - return _all_complex_types[json_value["type"]].fromJson(json_value) + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) -# Mapping Python types to Spark SQL DateType +# Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -448,30 +584,41 @@ def _parse_datatype_json_value(json_value): unicode: StringType, bytearray: BinaryType, decimal.Decimal: DecimalType, + datetime.date: DateType, datetime.datetime: TimestampType, - datetime.date: TimestampType, datetime.time: TimestampType, } def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ if obj is None: raise ValueError("Can not infer type for None") + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can 
not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), _infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -504,60 +651,180 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = _create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in 
dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") - - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] + raise ValueError("Unexpected obj: %s" % obj) - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) - - return nested_conv - - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -656,10 +923,10 @@ def _infer_schema_type(obj, dataType): """ Fill the dataType with types infered from obj - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType... + StructType...IntegerType...DoubleType...StringType...DateType... >>> row = [[1], {"key": (1, 2.0)}] >>> schema = _parse_schema_abstract("a[] b{c d}") >>> _infer_schema_type(row, schema) @@ -669,7 +936,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -703,6 +970,7 @@ def _infer_schema_type(obj, dataType): DecimalType: (decimal.Decimal,), StringType: (str, unicode), BinaryType: (bytearray,), + DateType: (datetime.date,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -730,17 +998,28 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept abject in type %s" + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): @@ -767,7 +1046,7 @@ def _restore_object(dataType, obj): """ Restore object during unpickling. """ # use id(dataType) as key to speed up lookup in dict # Because of batched pickling, dataType will be the - # same object in mose cases. + # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) if cls is None: @@ -782,6 +1061,10 @@ def _restore_object(dataType, obj): def _create_object(cls, v): """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() return cls(v) if v is not None else v @@ -795,14 +1078,18 @@ def getter(self): return getter -def _has_struct(dt): - """Return whether `dt` is or has StructType in it""" +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" if isinstance(dt, StructType): return True elif isinstance(dt, ArrayType): - return _has_struct(dt.elementType) + return _has_struct_or_date(dt.elementType) elif isinstance(dt, MapType): - return _has_struct(dt.valueType) + return _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -815,7 +1102,7 @@ def _create_properties(fields): or keyword.iskeyword(name)): warnings.warn("field name %s can not be accessed in Python," "use position to access it instead" % name) - if _has_struct(f.dataType): + if _has_struct_or_date(f.dataType): # delay creating object until accessing it getter = _create_getter(f.dataType, i) else: @@ -870,6 +1157,12 @@ def Dict(d): return Dict + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -941,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -971,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M @@ -995,18 +1287,20 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) - def inferSchema(self, rdd): + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a 
schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... [Row(field1=1, field2="row1"), @@ -1043,8 +1337,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1068,8 +1377,9 @@ def applySchema(self, rdd, schema): >>> srdd2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - >>> from datetime import datetime + >>> from datetime import date, datetime >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ @@ -1079,6 +1389,7 @@ def applySchema(self, rdd, schema): ... StructField("short2", ShortType(), False), ... StructField("int", IntegerType(), False), ... StructField("float", FloatType(), False), + ... StructField("date", DateType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", ... MapType(StringType(), IntegerType(), False), False), @@ -1088,10 +1399,11 @@ def applySchema(self, rdd, schema): ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema) >>> results = srdd.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, - ... x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] - (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... 
x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) >>> srdd.registerTempTable("table2") >>> sqlCtx.sql( @@ -1127,8 +1439,11 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1162,7 +1477,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1170,8 +1485,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1217,20 +1532,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1287,7 +1602,7 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) @@ -1379,33 +1694,6 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): - """Starts up an instance of hive where metadata is stored locally. - - An in-process metadata data is created with data stored in ./metadata. - Warehouse data is stored in in ./warehouse. - - >>> import os - >>> hiveCtx = LocalHiveContext(sc) - >>> try: - ... supress = hiveCtx.sql("DROP TABLE src") - ... except Exception: - ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], - ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.sql( - ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" - ... % kv1) - >>> results = hiveCtx.sql("FROM src SELECT value" - ... 
).map(lambda r: int(r.value.split('_')[1])) - >>> num = results.count() - >>> reduce_sum = results.reduce(lambda x, y: x + y) - >>> num - 500 - >>> reduce_sum - 130091 - """ - def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " @@ -1552,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -1782,15 +2070,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( @@ -1798,6 +2084,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 437a4a387f59b..ba86f913c9638 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -50,7 +50,8 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -271,7 +272,7 @@ class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() @@ -282,7 +283,7 @@ class ReusedPySparkTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + cls.sc = SparkContext('local[4]', cls.__name__) @classmethod def tearDownClass(cls): @@ -691,6 +692,21 @@ def test_external_group_by_key(self): self.assertEqual(N/3, len(result)) self.assertTrue(isinstance(result.it, shuffle.ChainedIterable)) + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 
21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) + class ProfilerTests(PySparkTestCase): @@ -698,7 +714,7 @@ def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -722,8 +738,65 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -733,10 +806,20 @@ def test_udf(self): self.assertEqual(row[0], 5) def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() - self.assertEqual(u"4", res[0]) + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} @@ -814,6 +897,25 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, 
srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) @@ -823,6 +925,39 @@ def test_convert_row_to_dict(self): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): @@ -920,16 +1055,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) - self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1249,51 +1387,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, 
len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected diff --git a/python/run-tests b/python/run-tests index 80acd002ab7eb..a4f0cac059ff3 100755 --- a/python/run-tests +++ b/python/run-tests @@ -41,7 +41,7 @@ function run_test() { # Fail and exit on the first test failure. if [[ $FAILED != 0 ]]; then - cat unit-tests.log | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. + cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. echo -en "\033[31m" # Red echo "Had test failures; see logs." echo -en "\033[0m" # No color @@ -87,7 +87,7 @@ function run_streaming_tests() { run_test "pyspark/streaming/tests.py" } -echo "Running PySpark tests. Output is in python/unit-tests.log." +echo "Running PySpark tests. Output is in python/$LOG_FILE." export PYSPARK_PYTHON="python" diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index cba475e2dd8c8..89608bc41b71d 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -21,8 +21,8 @@ # # Environment Variables # -# SPARK_CONF_DIR Alternate conf dir. Default is ${SPARK_PREFIX}/conf. -# SPARK_LOG_DIR Where log files are stored. PWD by default. +# SPARK_CONF_DIR Alternate conf dir. Default is ${SPARK_HOME}/conf. +# SPARK_LOG_DIR Where log files are stored. ${SPARK_HOME}/logs by default. # SPARK_MASTER host:path where spark code should be rsync'd from # SPARK_PID_DIR The pid files are stored. /tmp by default. # SPARK_IDENT_STRING A string representing this instance of spark. 
$USER by default diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3d4296f9d7068..9cda373623cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,82 +19,133 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} +import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * Provides experimental support for generating catalyst schemas for scala objects. */ object ScalaReflection { + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map import scala.reflect.runtime.universe._ case class Schema(dataType: DataType, nullable: Boolean) - /** Converts Scala objects to catalyst rows / types */ - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.map(convertToCatalyst).orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other + /** + * Converts Scala objects to catalyst rows / types. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. 
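+ *
+ * A rough usage sketch (`Person` here is a hypothetical case class used only for
+ * illustration, not something defined in this patch):
+ * {{{
+ *   case class Person(name: String, age: Int)          // hypothetical example class
+ *   val schema = schemaFor[Person].dataType            // StructType with name/age fields
+ *   convertToCatalyst(Person("Ada", 36), schema)       // a GenericRow holding "Ada" and 36
+ *   convertToCatalyst(BigDecimal("1.5"), DecimalType.Unlimited)  // a catalyst Decimal
+ * }}}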
+ */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) + case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + case (p: Product, structType: StructType) => + new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (d: BigDecimal, _) => Decimal(d) + case (other, _) => other + } + + /** Converts Catalyst types used internally in rows to standard Scala types */ + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => udt.deserialize(d) + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + case (r: Row, s: StructType) => convertRowToScala(r, s) + case (d: Decimal, _: DecimalType) => d.toBigDecimal + case (other, _) => other + } + + def convertRowToScala(r: Row, schema: StructType): Row = { + new GenericRow( + r.zip(schema.fields.map(_.dataType)) + .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) } /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_,_]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + def schemaFor(tpe: `Type`): Schema = { + val className: String = tpe.erasure.typeSymbol.asClass.fullName + tpe match { + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, + // whereas className is from Scala reflection. This can make it hard to find classes + // in some cases, such as when a class is enclosed in an object (in which case + // Java appends a '$' to the object name but Scala does not). + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + Schema(udt, nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. 
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + } } def typeOfObject: PartialFunction[Any, DataType] = { @@ -108,7 +159,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DecimalType.JvmType => DecimalType + case obj: DateType.JvmType => DateType + case obj: BigDecimal => DecimalType.Unlimited + case obj: Decimal => DecimalType.Unlimited case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala index 04467342e6ab5..f5c19ee69c37a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -61,7 +61,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." 
+ ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~" ) override lazy val token: Parser[Token] = @@ -75,6 +75,8 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { { case chars => StringLit(chars mkString "") } | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ { case chars => StringLit(chars mkString "") } + | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ + { case chars => Identifier(chars mkString "") } | EofCh ^^^ EOF | '\'' ~> failure("unclosed string literal") | '"' ~> failure("unclosed string literal") @@ -135,7 +137,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") protected val TABLE = Keyword("TABLE") - protected val SOURCE = Keyword("SOURCE") protected val UNCACHE = Keyword("UNCACHE") protected implicit def asParser(k: Keyword): Parser[String] = @@ -150,8 +151,7 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr override val lexical = new SqlLexical(reservedWords) - override protected lazy val start: Parser[LogicalPlan] = - cache | uncache | set | shell | source | others + override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -169,16 +169,6 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case input => SetCommandParser(input) } - private lazy val shell: Parser[LogicalPlan] = - "!" ~> restInput ^^ { - case input => ShellCommand(input.trim) - } - - private lazy val source: Parser[LogicalPlan] = - SOURCE ~> restInput ^^ { - case input => SourceCommand(input.trim) - } - private lazy val others: Parser[LogicalPlan] = wholeInput ^^ { case input => fallback(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a277684f6327c..5e613e0f18ba6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser { protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") + protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") + protected val DOUBLE = Keyword("DOUBLE") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") @@ -142,7 +144,7 @@ class SqlParser extends AbstractSparkSQLParser { (LIMIT ~> expression).? ^^ { case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => val base = r.getOrElse(NoRelation) - val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g .map(Aggregate(_, assignAliases(p), withFilter)) .getOrElse(Project(assignAliases(p), withFilter)) @@ -166,7 +168,8 @@ class SqlParser extends AbstractSparkSQLParser { // Based very loosely on the MySQL Grammar. 
// http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + ( relation ~ rep1("," ~> relation) ^^ { + case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } } | relation ) @@ -231,12 +234,15 @@ class SqlParser extends AbstractSparkSQLParser { | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { - case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ not ~ el ~ eu => + val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + not.fold(betweenExpr)(f=> Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { case e1 ~ e2 => In(e1, e2) } @@ -260,6 +266,9 @@ class SqlParser extends AbstractSparkSQLParser { ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } + | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) } + | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) } + | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) } ) protected lazy val function: Parser[Expression] = @@ -303,33 +312,74 @@ class SqlParser extends AbstractSparkSQLParser { CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } protected lazy val literal: Parser[Literal] = - ( numericLit ^^ { - case i if i.toLong > Int.MaxValue => Literal(i.toLong) - case i => Literal(i.toInt) - } - | NULL ^^^ Literal(null, NullType) - | floatLit ^^ {case f => Literal(f.toDouble) } + ( numericLiteral + | booleanLiteral | stringLit ^^ {case s => Literal(s, StringType) } + | NULL ^^^ Literal(null, NullType) ) + protected lazy val booleanLiteral: Parser[Literal] = + ( TRUE ^^^ Literal(true, BooleanType) + | FALSE ^^^ Literal(false, BooleanType) + ) + + protected lazy val numericLiteral: Parser[Literal] = + signedNumericLiteral | unsignedNumericLiteral + + protected lazy val sign: Parser[String] = + "+" | "-" + + protected lazy val signedNumericLiteral: Parser[Literal] = + ( sign ~ numericLit ^^ { case s ~ l => Literal(toNarrowestIntegerType(s + l)) } + | sign ~ floatLit ^^ { case s ~ f => Literal((s + f).toDouble) } + ) + + protected lazy val unsignedNumericLiteral: Parser[Literal] = + ( numericLit ^^ { n => Literal(toNarrowestIntegerType(n)) } + | floatLit ^^ { f => Literal(f.toDouble) } + ) + + private val longMax = BigDecimal(s"${Long.MaxValue}") + private val longMin = BigDecimal(s"${Long.MinValue}") + private val intMax = BigDecimal(s"${Int.MaxValue}") + private val intMin = BigDecimal(s"${Int.MinValue}") + + private def toNarrowestIntegerType(value: String) = 
{ + val bigIntValue = BigDecimal(value) + + bigIntValue match { + case v if v < longMin || v > longMax => v + case v if v < intMin || v > intMax => v.toLong + case v => v.toInt + } + } + protected lazy val floatLit: Parser[String] = - elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + ( "." ~> unsignedNumericLiteral ^^ { u => "0." + u } + | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + ) + + protected lazy val baseExpression: Parser[Expression] = + ( "*" ^^^ Star(None) + | primary + ) + + protected lazy val signedPrimary: Parser[Expression] = + sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e} - protected lazy val baseExpression: PackratParser[Expression] = - ( expression ~ ("[" ~> expression <~ "]") ^^ + protected lazy val primary: PackratParser[Expression] = + ( literal + | expression ~ ("[" ~> expression <~ "]") ^^ { case base ~ ordinal => GetItem(base, ordinal) } | (expression <~ ".") ~ ident ^^ { case base ~ fieldName => GetField(base, fieldName) } - | TRUE ^^^ Literal(true, BooleanType) - | FALSE ^^^ Literal(false, BooleanType) | cast | "(" ~> expression <~ ")" | function - | "-" ~> literal ^^ UnaryMinus | dotExpressionHeader | ident ^^ UnresolvedAttribute - | "*" ^^^ Star(None) - | literal + | signedPrimary + | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = @@ -338,5 +388,15 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val dataType: Parser[DataType] = - STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType + ( STRING ^^^ StringType + | TIMESTAMP ^^^ TimestampType + | DOUBLE ^^^ DoubleType + | fixedDecimalType + | DECIMAL ^^^ DecimalType.Unlimited + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 2059a91ba0612..0415d74bd8141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -28,6 +28,8 @@ trait Catalog { def caseSensitive: Boolean + def tableExists(db: Option[String], tableName: String): Boolean + def lookupRelation( databaseName: Option[String], tableName: String, @@ -82,6 +84,14 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tables.clear() } + override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + tables.get(tblName) match { + case Some(_) => true + case None => false + } + } + override def lookupRelation( databaseName: Option[String], tableName: String, @@ -107,6 +117,14 @@ trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... 
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + abstract override def tableExists(db: Option[String], tableName: String): Boolean = { + val (dbName, tblName) = processDatabaseAndTableName(db, tableName) + overrides.get((dbName, tblName)) match { + case Some(_) => true + case None => super.tableExists(db, tableName) + } + } + abstract override def lookupRelation( databaseName: Option[String], tableName: String, @@ -149,6 +167,10 @@ object EmptyCatalog extends Catalog { val caseSensitive: Boolean = true + def tableExists(db: Option[String], tableName: String): Boolean = { + throw new UnsupportedOperationException + } + def lookupRelation( databaseName: Option[String], tableName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2b69c02b28285..e38114ab3cf25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil + private val numericPrecedence = + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + /** + * Find the tightest common type of two types that might be used in a binary expression. + * This handles all numeric types except fixed-precision decimals interacting with each other or + * with primitive types, because in that case the precision and scale of the result depends on + * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + */ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { val valueTypes = Seq(t1, t2).filter(t => t != NullType) if (valueTypes.distinct.size > 1) { - // Try and find a promotion rule that contains both types in question. 
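Backing up to the Catalog change above: tableExists follows the same stackable-trait shape as lookupRelation, with OverrideCatalog answering from its in-memory overrides first and only then delegating to super. A self-contained sketch of that fallback pattern, using plain strings in place of LogicalPlan purely for illustration:

import scala.collection.mutable

trait Catalog {
  def tableExists(db: Option[String], tableName: String): Boolean
}

class SimpleCatalog extends Catalog {
  private val tables = mutable.HashMap.empty[String, String]
  def registerTable(name: String, plan: String): Unit = tables(name) = plan
  override def tableExists(db: Option[String], tableName: String): Boolean =
    tables.contains(tableName)
}

// Stackable modification: answer from the overrides first, then delegate to the base catalog.
trait OverrideCatalog extends Catalog {
  private val overrides = mutable.HashMap.empty[(Option[String], String), String]
  def registerOverride(db: Option[String], name: String, plan: String): Unit =
    overrides((db, name)) = plan
  abstract override def tableExists(db: Option[String], tableName: String): Boolean =
    overrides.contains((db, tableName)) || super.tableExists(db, tableName)
}

object CatalogSketch {
  def main(args: Array[String]): Unit = {
    val catalog = new SimpleCatalog with OverrideCatalog
    catalog.registerTable("people", "base plan")
    catalog.registerOverride(None, "temp_people", "override plan")
    println(catalog.tableExists(None, "people"))       // true, via the base catalog
    println(catalog.tableExists(None, "temp_people"))  // true, via the overrides
    println(catalog.tableExists(None, "missing"))      // false
  }
}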
- val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) { + Some(numericPrecedence.filter(t => t == t1 || t == t2).last) + } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) { + // Fixed-precision decimals can up-cast into unlimited + if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) { + Some(DecimalType.Unlimited) + } else { + None + } + } else { + None + } } else { Some(if (valueTypes.size == 0) NullType else valueTypes.head) } @@ -59,6 +71,7 @@ trait HiveTypeCoercion { ConvertNaNs :: WidenTypes :: PromoteStrings :: + DecimalPrecision :: BooleanComparisons :: BooleanCasts :: StringToIntegralCasts :: @@ -151,6 +164,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. @@ -265,6 +279,110 @@ trait HiveTypeCoercion { } } + // scalastyle:off + /** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s2 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * sum(e1) p1 + 10 s1 + * avg(e1) p1 + 4 s1 + 4 + * + * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. + * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. 
+ * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * but note that unlimited decimals are considered bigger than doubles in WidenTypes) + */ + // scalastyle:on + object DecimalPrecision extends Rule[LogicalPlan] { + import scala.math.{max, min} + + // Conversion rules for integer types into fixed-precision decimals + val intTypeToFixed: Map[DataType, DecimalType] = Map( + ByteType -> DecimalType(3, 0), + ShortType -> DecimalType(5, 0), + IntegerType -> DecimalType(10, 0), + LongType -> DecimalType(20, 0) + ) + + def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } + } + /** * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. 
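Working the DecimalPrecision table above through one concrete pair of operands: DECIMAL(5,2) + DECIMAL(7,1) comes out as DECIMAL(9,2), and DECIMAL(5,2) * DECIMAL(7,1) as DECIMAL(13,3). A small standalone sketch of just those two result-type formulas (precision/scale pairs only, no Catalyst types), assuming both operands are fixed-precision:

object DecimalPrecisionRules {
  import scala.math.max

  // e1 + e2: precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1, scale = max(s1, s2)
  def addResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    (max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))

  // e1 * e2: precision = p1 + p2 + 1, scale = s1 + s2
  def multiplyResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    (p1 + p2 + 1, s1 + s2)

  def main(args: Array[String]): Unit = {
    println(addResult(5, 2, 7, 1))      // (9,2)  -> DECIMAL(5,2) + DECIMAL(7,1) is DECIMAL(9,2)
    println(multiplyResult(5, 2, 7, 1)) // (13,3) -> DECIMAL(5,2) * DECIMAL(7,1) is DECIMAL(13,3)
  }
}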
*/ @@ -330,7 +448,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType), t) + Cast(Cast(e, DecimalType.Unlimited), t) } } @@ -383,10 +501,12 @@ trait HiveTypeCoercion { // Decimal and Double remain the same case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) - case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) + case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => + Divide(l, Cast(r, DecimalType.Unlimited)) + case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => + Divide(Cast(l, DecimalType.Unlimited), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java new file mode 100644 index 0000000000000..e966aeea1cb23 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.annotation; + +import java.lang.annotation.*; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.sql.catalyst.types.UserDefinedType; + +/** + * ::DeveloperApi:: + * A user-defined type which can be automatically recognized by a SQLContext and registered. + * + * WARNING: This annotation will only work if both Java and Scala reflection return the same class + * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class + * is enclosed in an object (a singleton). + * + * WARNING: UDTs are currently only supported from Scala. + */ +// TODO: Should I used @Documented ? +@DeveloperApi +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface SQLUserDefinedType { + + /** + * Returns an instance of the UserDefinedType which can serialize and deserialize the user + * class to and from Catalyst built-in types. 
+ */ + Class<? extends UserDefinedType<?> > udt(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75b6e37c2a1f9..3314e15477016 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -62,12 +65,16 @@ package object dsl { def unary_- = UnaryMinus(expr) def unary_! = Not(expr) + def unary_~ = BitwiseNot(expr) def + (other: Expression) = Add(expr, other) def - (other: Expression) = Subtract(expr, other) def * (other: Expression) = Multiply(expr, other) def / (other: Expression) = Divide(expr, other) def % (other: Expression) = Remainder(expr, other) + def & (other: Expression) = BitwiseAnd(expr, other) + def | (other: Expression) = BitwiseOr(expr, other) + def ^ (other: Expression) = BitwiseXor(expr, other) def && (other: Expression) = And(expr, other) def || (other: Expression) = Or(expr, other) @@ -120,7 +127,8 @@ package object dsl { implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) implicit def dateToLiteral(d: Date) = Literal(d) - implicit def decimalToLiteral(d: BigDecimal) = Literal(d) + implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) + implicit def decimalToLiteral(d: Decimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -179,7 +187,11 @@ package object dsl { def date = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType, nullable = true)() + def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int) = + AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() @@ -274,4 +286,62 @@ package object dsl { def writeToFile(path: String) = WriteToFile(path, logicalPlan) } } + + case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { + def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) + } + + // scalastyle:off + /** functionToUdfBuilder 1-22 were generated by this script + + (1 to 22).map { x => + val argTypes = Seq.fill(x)("_").mkString(", ") + s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)" + } + */ + + implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) =
ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + + implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func) + // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 8364379644c90..82e760b6c6916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -23,8 +23,7 @@ package org.apache.spark.sql.catalyst.expressions * of the name, or the expected nullability). 
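Stepping back to the dsl additions above: all 22 generated functionToUdfBuilder implicits reduce to one pattern, an implicit view that lifts a plain Scala function into a builder whose call(...) packages the function, its TypeTag-derived return type, and the argument expressions. A stripped-down, self-contained sketch of that wiring, with toy stand-ins for Expression and ScalaUdf (the names here are illustrative, not Spark's API):

import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

object UdfBuilderSketch {
  // Toy stand-ins for Catalyst's Expression and ScalaUdf, only to show the wiring.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Udf(function: AnyRef, returnType: String, children: Seq[Expr]) extends Expr

  case class UdfBuilder[T: TypeTag](f: AnyRef) {
    // The TypeTag recovers the declared return type, analogous to ScalaReflection.schemaFor.
    def call(args: Expr*): Udf = Udf(f, typeTag[T].tpe.toString, args)
  }

  implicit def function1ToUdfBuilder[T: TypeTag](func: Function1[_, T]): UdfBuilder[T] =
    UdfBuilder(func)
  implicit def function2ToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): UdfBuilder[T] =
    UdfBuilder(func)

  def main(args: Array[String]): Unit = {
    val plusOne = (x: Int) => x + 1
    val expr = plusOne.call(Attr("age"))  // the implicit view supplies `call`
    println(expr)                         // e.g. Udf(<function1>,Int,WrappedArray(Attr(age)))
  }
}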
*/ object AttributeMap { - def apply[A](kvs: Seq[(Attribute, A)]) = - new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap) + def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) @@ -32,10 +31,9 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) - override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = - (baseMap.map(_._2) + kv).toMap + override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv - override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator + override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator - override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap + override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8e5baf0eb82d6..22009666196a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { @@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, DateType) => true case (DateType, _: NumericType) => true case (DateType, BooleanType) => true + case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => child.nullable } @@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Short](_, _ != 0) case ByteType => buildCast[Byte](_, _ != 0) - case DecimalType => - buildCast[BigDecimal](_, _ != 0) + case DecimalType() => + buildCast[Decimal](_, _ != 0) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp - case DecimalType => - buildCast[BigDecimal](_, d => decimalToTimestamp(d)) + case DecimalType() => + buildCast[Decimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp case DoubleType => - buildCast[Double](_, d => decimalToTimestamp(d)) + buildCast[Double](_, d => decimalToTimestamp(Decimal(d))) // TimestampWritable.floatToTimestamp case FloatType => - buildCast[Float](_, f => decimalToTimestamp(f)) + buildCast[Float](_, f => decimalToTimestamp(Decimal(f))) } - private[this] def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: Decimal) = { val seconds = Math.floor(d.toDouble).toLong - val bd = (d - seconds) * 1000000000 + val bd = (d.toBigDecimal - seconds) * 1000000000 val nanos = bd.intValue() val millis = seconds * 1000 @@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toLong) + case DecimalType() => + buildCast[Decimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => - buildCast[BigDecimal](_, _.toInt) + case DecimalType() => + buildCast[Decimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => - buildCast[BigDecimal](_, _.toShort) + case DecimalType() => + buildCast[Decimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => - buildCast[BigDecimal](_, _.toByte) + case DecimalType() => + buildCast[Decimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } - // DecimalConverter - private[this] def castToDecimal: Any => Any = child.dataType match { + /** + * Change the precision / scale in a given decimal to those set in `decimalType` (if any), + * returning null if it overflows or modifying `value` in-place and returning it if successful. + * + * NOTE: this modifies `value` in-place, so don't call it on external data. + */ + private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + decimalType match { + case DecimalType.Unlimited => + value + case DecimalType.Fixed(precision, scale) => + if (value.changePrecision(precision, scale)) value else null + } + } + + private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match { case StringType => - buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
- buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => - b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + case DecimalType() => + b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + case LongType => + b => changePrecision(Decimal(b.asInstanceOf[Long]), target) + case x: NumericType => // All other numeric types can be represented precisely as Doubles + b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) } // DoubleConverter @@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toDouble) + case DecimalType() => + buildCast[Decimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => - buildCast[BigDecimal](_, _.toFloat) + case DecimalType() => + buildCast[Decimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } @@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case DecimalType => castToDecimal case DateType => castToDate + case decimal: DecimalType => castToDecimal(decimal) case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1eb260efa6387..39b120e8de485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} +import org.apache.spark.sql.catalyst.util.Metadata abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 1b687a443ef8b..18c96da2f87fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner +/** + * User-defined function. + * @param dataType Return type of function. 
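Back in Cast, the behavioural point of castToDecimal above is that a fixed-precision target can reject a value: changePrecision rounds to the requested scale and, if the rounded value needs more digits than the target precision allows, the cast evaluates to null rather than overflowing. A rough standalone approximation of that check with java.math rounding (the real Decimal class also has a mutable Long-backed fast path, which is omitted here, and the exact rounding mode is an assumption):

object FixedDecimalCastSketch {
  import java.math.RoundingMode

  // Round to the requested scale; give up (None) if the result needs more than `precision` digits.
  def changePrecision(value: BigDecimal, precision: Int, scale: Int): Option[BigDecimal] = {
    val rounded = value.bigDecimal.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision > precision) None else Some(BigDecimal(rounded))
  }

  def main(args: Array[String]): Unit = {
    println(changePrecision(BigDecimal("123.456"), 5, 1)) // Some(123.5): fits DECIMAL(5,1)
    println(changePrecision(BigDecimal("123.456"), 4, 2)) // None: CAST(... AS DECIMAL(4,2)) yields null
  }
}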
+ */ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { @@ -30,323 +34,369 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def toString = s"scalaUDF(${children.mkString(",")})" + // scalastyle:off + /** This method has been generated by this script (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) - val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + val evals = (0 to x - 1).map(x => s" ScalaReflection.convertToScala(children($x).eval(input), children($x).dataType)").reduce(_ + ",\n " + _) s""" case $x => function.asInstanceOf[($anys) => Any]( - $evals) + $evals) """ - } + }.foreach(println) */ - // scalastyle:off override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() - case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) + case 1 => + function.asInstanceOf[(Any) => Any]( + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType)) + + case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType)) + + case 3 => function.asInstanceOf[(Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType)) + + case 4 => function.asInstanceOf[(Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType)) + + case 5 => function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType)) + + case 6 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + 
ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType)) + + case 7 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType)) + + case 8 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType)) + + case 9 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType)) + + case 10 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + 
ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType)) + + case 11 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType)) + + case 12 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType)) + + case 13 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - 
children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType)) + + case 14 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType)) + + case 15 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + 
ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType)) + + case 16 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType)) + + case 17 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - 
children(14).eval(input), - children(15).eval(input), - children(16).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType)) + + case 18 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), 
children(17).dataType)) + + case 19 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType)) + + case 20 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + 
ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType)) + + case 21 => function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType)) + + case 22 => 
function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( - children(0).eval(input), - children(1).eval(input), - children(2).eval(input), - children(3).eval(input), - children(4).eval(input), - children(5).eval(input), - children(6).eval(input), - children(7).eval(input), - children(8).eval(input), - children(9).eval(input), - children(10).eval(input), - children(11).eval(input), - children(12).eval(input), - children(13).eval(input), - children(14).eval(input), - children(15).eval(input), - children(16).eval(input), - children(17).eval(input), - children(18).eval(input), - children(19).eval(input), - children(20).eval(input), - children(21).eval(input)) + ScalaReflection.convertToScala(children(0).eval(input), children(0).dataType), + ScalaReflection.convertToScala(children(1).eval(input), children(1).dataType), + ScalaReflection.convertToScala(children(2).eval(input), children(2).dataType), + ScalaReflection.convertToScala(children(3).eval(input), children(3).dataType), + ScalaReflection.convertToScala(children(4).eval(input), children(4).dataType), + ScalaReflection.convertToScala(children(5).eval(input), children(5).dataType), + ScalaReflection.convertToScala(children(6).eval(input), children(6).dataType), + ScalaReflection.convertToScala(children(7).eval(input), children(7).dataType), + ScalaReflection.convertToScala(children(8).eval(input), children(8).dataType), + ScalaReflection.convertToScala(children(9).eval(input), children(9).dataType), + ScalaReflection.convertToScala(children(10).eval(input), children(10).dataType), + ScalaReflection.convertToScala(children(11).eval(input), children(11).dataType), + ScalaReflection.convertToScala(children(12).eval(input), children(12).dataType), + ScalaReflection.convertToScala(children(13).eval(input), children(13).dataType), + ScalaReflection.convertToScala(children(14).eval(input), children(14).dataType), + ScalaReflection.convertToScala(children(15).eval(input), children(15).dataType), + ScalaReflection.convertToScala(children(16).eval(input), children(16).dataType), + ScalaReflection.convertToScala(children(17).eval(input), children(17).dataType), + ScalaReflection.convertToScala(children(18).eval(input), children(18).dataType), + ScalaReflection.convertToScala(children(19).eval(input), children(19).dataType), + ScalaReflection.convertToScala(children(20).eval(input), children(20).dataType), + ScalaReflection.convertToScala(children(21).eval(input), children(21).dataType)) + } // scalastyle:on - ScalaReflection.convertToCatalyst(result) + ScalaReflection.convertToCatalyst(result, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 1b4d892625dbb..2b364fc1df1d8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = DoubleType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive + case 
DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) + child.dataType match { + case DecimalType.Fixed(_, _) => + // Turn the results to unlimited decimals for the divsion, before going back to fixed + val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) + val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) + SplitEvaluation( + Cast(Divide(castedSum, castedCount), dataType), + partialCount :: partialSum :: Nil) + + case _ => + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + SplitEvaluation( + Divide(castedSum, castedCount), + partialCount :: partialSum :: Nil) + } } override def newInstance() = new AverageFunction(child, this) @@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { @@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM(DISTINCT $child)" override def newInstance() = new SumDistinctFunction(child, this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe825fdcdae37..8574cabc43525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - + def dataType = DoubleType override def foldable = child.foldable def nullable = child.nullable @@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression { def nullable = left.nullable || right.nullable override lazy val resolved = - left.resolved && right.resolved && left.dataType == right.dataType + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) def dataType = { if (!resolved) { @@ -64,6 
+66,23 @@ abstract class BinaryArithmetic extends BinaryExpression { } left.dataType } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if(evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + evalInternal(evalE1, evalE2) + } + } + } + + def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = + sys.error(s"BinaryExpressions must either override eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { @@ -87,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "/" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = dataType match { case _: FractionalType => f2(input, left, right, _.div(_, _)) case _: IntegralType => i2(input, left , right, _.quot(_, _)) @@ -97,9 +118,83 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "%" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } +/** + * A function that calculates bitwise and(&) of two numbers. + */ +case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { + def symbol = "&" + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { + case ByteType => (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte + case ShortType => (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort + case IntegerType => evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int] + case LongType => evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long] + case other => sys.error(s"Unsupported bitwise & operation on ${other}") + } +} + +/** + * A function that calculates bitwise or(|) of two numbers. + */ +case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { + def symbol = "|" + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { + case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte + case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort + case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int] + case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long] + case other => sys.error(s"Unsupported bitwise | operation on ${other}") + } +} + +/** + * A function that calculates bitwise xor(^) of two numbers. + */ +case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { + def symbol = "^" + + override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { + case ByteType => (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte + case ShortType => (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort + case IntegerType => evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int] + case LongType => evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long] + case other => sys.error(s"Unsupported bitwise ^ operation on ${other}") + } +} + +/** + * A function that calculates bitwise not(~) of a number.
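The new BinaryArithmetic.eval above short-circuits on a null operand and only then delegates to evalInternal. A minimal standalone sketch of that null-propagation shape, using plain Scala Options rather than Catalyst's Row and Expression types (names here are illustrative, not from the patch):

// Sketch only: the null-short-circuiting pattern used by BinaryArithmetic.eval above,
// with Option standing in for a nullable child result.
def nullSafeBinary[A, B](left: Option[A], right: Option[A])(op: (A, A) => B): Option[B] =
  left match {
    case None    => None                      // left operand is null
    case Some(l) => right match {
      case None    => None                    // right operand is null
      case Some(r) => Some(op(l, r))          // both present: apply the arithmetic/bitwise op
    }
  }

nullSafeBinary(Some(0x0F), Some(0x09))(_ & _)           // Some(9)
nullSafeBinary(Some(0x0F), None: Option[Int])(_ & _)    // None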
+ */ +case class BitwiseNot(child: Expression) extends UnaryExpression { + type EvaluatedType = Any + + def dataType = child.dataType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"~$child" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + dataType match { + case ByteType => (~(evalE.asInstanceOf[Byte])).toByte + case ShortType => (~(evalE.asInstanceOf[Short])).toShort + case IntegerType => ~(evalE.asInstanceOf[Int]) + case LongType => ~(evalE.asInstanceOf[Long]) + case other => sys.error(s"Unsupported bitwise ~ operation on ${other}") + } + } + } +} + case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5a3f013c34579..67f8d411b6bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import com.google.common.cache.{CacheLoader, CacheBuilder} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.existentials @@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case UnscaledValue(child) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: Long = if (!$nullTerm) { + ${childEval.primitiveTerm}.toUnscaledLong + } else { + ${defaultPrimitive(LongType)} + } + """.children + + case MakeDecimal(child, precision, scale) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal = + ${defaultPrimitive(DecimalType())} + + if (!$nullTerm) { + $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal() + $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) + $nullTerm = $primitiveTerm == null + } + """.children } // If there was no match in the partial function above, we fall back on calling the interpreted @@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case LongType => ru.Literal(Constant(1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala new file mode 100644 index 0000000000000..d1eab2eb4ed56 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType} + +/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +case class UnscaledValue(child: Expression) extends UnaryExpression { + override type EvaluatedType = Any + + override def dataType: DataType = LongType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"UnscaledValue($child)" + + override def eval(input: Row): Any = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + childResult.asInstanceOf[Decimal].toUnscaledLong + } + } +} + +/** Create a Decimal from an unscaled Long value */ +case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { + override type EvaluatedType = Decimal + + override def dataType: DataType = DecimalType(precision, scale) + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"MakeDecimal($child,$precision,$scale)" + + override def eval(input: Row): Decimal = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9c865254e0be9..ab0701fd9a80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -43,7 +43,7 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable)))) + ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) override def nullable = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ba240233cae61..93c19325151bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal object Literal { def apply(v: Any): Literal = v match { @@ -31,7 +32,8 @@ object Literal 
{ case s: Short => Literal(s, ShortType) case s: String => Literal(s, StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(d, DecimalType) + case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) + case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) @@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { } // TODO: Specialize -case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) +case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index d023db44d8543..fc90a54a58259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util.Metadata object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() @@ -43,6 +44,9 @@ abstract class NamedExpression extends Expression { def toAttribute: Attribute + /** Returns the metadata when an expression is a reference to another expression with metadata. */ + def metadata: Metadata = Metadata.empty + protected def typeSuffix = if (resolved) { dataType match { @@ -88,10 +92,16 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable + override def metadata: Metadata = { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } + } override def toAttribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { UnresolvedAttribute(name) } @@ -108,18 +118,23 @@ case class Alias(child: Expression, name: String) * @param name The name of this attribute, should only be used during analysis or for debugging. * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute. + * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. * @param qualifiers a list of strings that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. 
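Given the metadata plumbing above, and assuming Alias keeps its usual second parameter list (defaulted exprId and qualifiers) together with the four-argument AttributeReference introduced just below, a sketch of how column metadata is expected to survive aliasing (column name and metadata keys here are made up):

// Sketch only: metadata attached to an attribute should flow through an Alias
// and onto the AttributeReference produced by toAttribute.
val md  = new MetadataBuilder().putString("comment", "age in years").build()
val col = AttributeReference("age", IntegerType, nullable = true, md)()
val aliased = Alias(col, "age_alias")()

aliased.metadata               // == md, taken from the child NamedExpression
aliased.toAttribute.metadata   // == md, carried onto the new AttributeReference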
*/ -case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) - extends Attribute with trees.LeafNode[Expression] { +case class AttributeReference( + name: String, + dataType: DataType, + nullable: Boolean = true, + override val metadata: Metadata = Metadata.empty)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { override def equals(other: Any) = other match { - case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType + case ar: AttributeReference => name == ar.name && exprId == ar.exprId && dataType == ar.dataType case _ => false } @@ -128,10 +143,12 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea var h = 17 h = h * 37 + exprId.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + metadata.hashCode() h } - override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance() = + AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -140,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) } } @@ -156,10 +173,10 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea * Returns a copy of this [[AttributeReference]] with new qualifiers. */ override def withQualifiers(newQualifiers: Seq[String]) = { - if (newQualifiers == qualifiers) { + if (newQualifiers.toSet == qualifiers.toSet) { this } else { - AttributeReference(name, dataType, nullable)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index c2a3a5ca3ca8b..f6349767764a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -102,31 +102,27 @@ case class Like(left: Expression, right: Expression) // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character - override def escape(v: String) = { - val sb = new StringBuilder() - var i = 0; - while (i < v.length) { - // Make a special case for "\\_" and "\\%" - val n = v.charAt(i); - if (n == '\\' && i + 1 < v.length && (v.charAt(i + 1) == '_' || v.charAt(i + 1) == '%')) { - sb.append(v.charAt(i + 1)) - i += 1 - } else { - if (n == '_') { - sb.append("."); - } else if (n == '%') { - sb.append(".*"); - } else { - sb.append(Pattern.quote(Character.toString(n))); - } - } - - i += 1 + override def escape(v: String) = + if (!v.isEmpty) { + "(?s)" + (' ' +: v.init).zip(v).flatMap { + case (prev, '\\') => "" + case ('\\', c) => + c match { + case '_' => "_" + case '%' => "%" + case _ => Pattern.quote("\\" + c) + } + case (prev, c) => + c match { + case '_' => "." 
+ case '%' => ".*" + case _ => Pattern.quote(Character.toString(c)) + } + }.mkString + } else { + v } - sb.toString() - } - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9ce7c78195830..a4aa322fc52d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal abstract class Optimizer extends RuleExecutor[LogicalPlan] @@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer { SimplifyCasts, SimplifyCaseConversionExpressions, OptimizeIn) :: + Batch("Decimal Optimizations", FixedPoint(100), + DecimalAggregates) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the - * attributes of the left or right side of sub query when applicable. - * + * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * attributes of the left or right side of sub query when applicable. 
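The rewritten escape above pairs each character with its predecessor, so a backslash only suppresses the wildcard meaning of the character that follows it. Illustrative input/output pairs implied by that code (assuming java.util.regex.Pattern.quote renders as \Q...\E):

// Illustration only, not from the patch: LIKE pattern => regex produced by escape above.
//   "a_b"        =>  "(?s)\Qa\E.\Qb\E"       // _ matches exactly one character
//   "a%b"        =>  "(?s)\Qa\E.*\Qb\E"      // % matches zero or more characters
//   """a\_b"""   =>  "(?s)\Qa\E_\Qb\E"       // escaped _ stays a literal underscore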
+ * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { @@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = condition.partition(_.references subsetOf left.outputSet) - val (rightEvaluateCondition, commonCondition) = + val (rightEvaluateCondition, commonCondition) = rest.partition(_.references subsetOf right.outputSet) (leftEvaluateCondition, rightEvaluateCondition, commonCondition) @@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => - val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = + val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // push down the join filter into sub query scanning if applicable case f @ Join(left, right, joinType, joinCondition) => - val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { @@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } } + +/** + * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. + * + * This uses the same rules for increasing the precision and scale of the output as + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]]. + */ +object DecimalAggregates extends Rule[LogicalPlan] { + import Decimal.MAX_LONG_DIGITS + + /** Maximum number of decimal digits representable precisely in a Double */ + val MAX_DOUBLE_DIGITS = 15 + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + Cast( + Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index bdd07bbeb2230..a38079ced34b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +/** + * Catalyst is a library for manipulating relational query plans. All classes in catalyst are + * considered an internal API to Spark SQL and are subject to change between minor releases. 
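A concrete illustration of the DecimalAggregates rewrite above for a hypothetical column e of type decimal(8, 2), combining it with the widened Sum/Average result types introduced earlier in this patch:

// Illustration only: rewrites produced by DecimalAggregates for e: DecimalType(8, 2).
//
//   Sum(e)      =>  MakeDecimal(Sum(UnscaledValue(e)), 18, 2)
//                   // 8 + 10 = 18 <= MAX_LONG_DIGITS, so sum the unscaled Longs
//                   // and reattach scale 2 at the end
//
//   Average(e)  =>  Cast(Divide(Average(UnscaledValue(e)), Literal(100.0, DoubleType)),
//                        DecimalType(12, 6))
//                   // 8 + 4 = 12 <= MAX_DOUBLE_DIGITS; average the unscaled values,
//                   // divide by 10^2, and cast back to decimal(12, 6)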
+ */ package object catalyst { /** * A JVM-global lock that should be used to prevent thread safety issues when using things in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..51b5699affed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -21,6 +21,15 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode +/** + * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * be used for execution. If this strategy does not apply to the give logical operation then an + * empty list should be returned. + */ +abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { + def apply(plan: LogicalPlan): Seq[PhysicalPlan] +} + /** * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. * Child classes are responsible for specifying a list of [[Strategy]] objects that each of which @@ -35,16 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode */ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ - def strategies: Seq[Strategy] - - /** - * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can - * be used for execution. If this strategy does not apply to the give logical operation then an - * empty list should be returned. - */ - abstract protected class Strategy extends Logging { - def apply(plan: LogicalPlan): Seq[PhysicalPlan] - } + def strategies: Seq[GenericStrategy[PhysicalPlan]] /** * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 882e9c6110089..ed578e081be73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -26,25 +26,24 @@ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees +/** + * Estimates of various statistics. The default estimation logic simply lazily multiplies the + * corresponding statistic produced by the children. To override this behavior, override + * `statistics` and assign it an overriden version of `Statistics`. + * + * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the + * performance of the implementations. The reason is that estimations might get triggered in + * performance-critical processes, such as query plan planning. + * + * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it + * defaults to the product of children's `sizeInBytes`. + */ +private[sql] case class Statistics(sizeInBytes: BigInt) + abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { self: Product => - /** - * Estimates of various statistics. 
The default estimation logic simply lazily multiplies the - * corresponding statistic produced by the children. To override this behavior, override - * `statistics` and assign it an overriden version of `Statistics`. - * - * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the - * performance of the implementations. The reason is that estimations might get triggered in - * performance-critical processes, such as query plan planning. - * - * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it - * defaults to the product of children's `sizeInBytes`. - */ - case class Statistics( - sizeInBytes: BigInt - ) - lazy val statistics: Statistics = { + def statistics: Statistics = { if (children.size == 0) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 14b03c7445c13..00bdf108a8398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -114,11 +114,13 @@ case class InsertIntoTable( } } -case class CreateTableAsSelect( +case class CreateTableAsSelect[T]( databaseName: Option[String], tableName: String, - child: LogicalPlan) extends UnaryNode { - override def output = child.output + child: LogicalPlan, + allowExisting: Boolean, + desc: Option[T] = None) extends UnaryNode { + override def output = Seq.empty[Attribute] override lazy val resolved = (databaseName != None && childrenResolved) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index b8ba2ee428a20..1d513d7789763 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.types.StringType /** @@ -41,6 +41,15 @@ case class NativeCommand(cmd: String) extends Command { /** * Commands of the form "SET [key [= value] ]". */ +case class DFSCommand(kv: Option[(String, Option[String])]) extends Command { + override def output = Seq( + AttributeReference("DFS output", StringType, nullable = false)()) +} + +/** + * + * Commands of the form "SET [key [= value] ]". + */ case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) @@ -81,14 +90,3 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } - -/** - * Returned for the "! 
shellCommand" command - */ -case class ShellCommand(cmd: String) extends Command - - -/** - * Returned for the "SOURCE file" command - */ -case class SourceCommand(filePath: String) extends Command diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 0cf139ebde417..5dd19dd12d8dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,21 +19,23 @@ package org.apache.spark.sql.catalyst.types import java.sql.{Date, Timestamp} -import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} +import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers -import org.json4s.JsonAST.JValue import org.json4s._ +import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} +import org.apache.spark.sql.catalyst.types.decimal._ +import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils - object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -66,17 +68,25 @@ object DataType { ("fields", JArray(fields)), ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { case JSortedObject( + ("metadata", metadata: JObject), ("name", JString(name)), ("nullable", JBool(nullable)), ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) + StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) } - @deprecated("Use DataType.fromJson instead") + @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) private object CaseClassStringParser extends RegexParsers { @@ -90,10 +100,17 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType - | "DecimalType" ^^^ DecimalType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType | "TimestampType" ^^^ TimestampType ) + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) @@ -198,9 +215,18 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all - - private[sql] val nameToType = all.map(t => t.typeName -> t).toMap + private 
val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap + + /** Given the string representation of a type, return its DataType */ + private[sql] def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } } abstract class NativeType extends DataType { @@ -324,18 +350,64 @@ object FractionalType { case _ => false } } + abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] } -case object DecimalType extends FractionalType { - private[sql] type JvmType = BigDecimal +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + +/** A Decimal that might have fixed precision and scale, or unlimited values for these */ +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[BigDecimal]] - private[sql] val fractional = implicitly[Fractional[BigDecimal]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - private[sql] val asIntegral = BigDecimalAsIfIntegral + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } +} + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } } case object DoubleType extends FractionalType { @@ -386,24 +458,34 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * @param name The name of this field. * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g, in selection. 
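A small sketch (not from the patch) of how the DecimalType factories and extractors defined above compose when matching on a DataType:

// Sketch only: fixed vs. unlimited precision and the matching helpers above.
val fixed: DataType     = DecimalType(10, 2)       // DecimalType(Some(PrecisionInfo(10, 2)))
val unlimited: DataType = DecimalType.Unlimited    // DecimalType(None)

def describe(t: DataType): String = t match {
  case DecimalType.Fixed(p, s) => s"fixed decimal($p,$s)"
  case DecimalType()           => "decimal with unlimited precision"
  case other                   => other.toString
}

describe(fixed)              // "fixed decimal(10,2)"
describe(unlimited)          // "decimal with unlimited precision"
DecimalType.isFixed(fixed)   // true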
*/ -case class StructField(name: String, dataType: DataType, nullable: Boolean) { +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + private[sql] def jsonValue: JValue = { ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) } case class StructType(fields: Seq[StructField]) extends DataType { @@ -437,7 +519,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { } protected[sql] def toAttributes = - fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) def treeString: String = { val builder = new StringBuilder @@ -492,3 +574,50 @@ case class MapType( ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) } + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + * where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala new file mode 100644 index 0000000000000..708362acf32dc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.apache.spark.annotation.DeveloperApi + +/** + * A mutable implementation of BigDecimal that can hold a Long if values are small enough. + * + * The semantics of the fields are as follows: + * - _precision and _scale represent the SQL precision and scale we are looking for + * - If decimalVal is set, it represents the whole decimal value + * - Otherwise, the decimal value is longVal / (10 ** _scale) + */ +final class Decimal extends Ordered[Decimal] with Serializable { + import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} + + private var decimalVal: BigDecimal = null + private var longVal: Long = 0L + private var _precision: Int = 1 + private var _scale: Int = 0 + + def precision: Int = _precision + def scale: Int = _scale + + /** + * Set this Decimal to the given Long. Will have precision 20 and scale 0. + */ + def set(longVal: Long): Decimal = { + if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + this.decimalVal = null + this.longVal = longVal + } + this._precision = 20 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given Int. Will have precision 10 and scale 0. + */ + def set(intVal: Int): Decimal = { + this.decimalVal = null + this.longVal = intVal + this._precision = 10 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale. + */ + def set(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (setOrNull(unscaled, precision, scale) == null) { + throw new IllegalArgumentException("Unscaled value too large for precision") + } + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale, + * and return it, or return null if it cannot be set due to overflow. + */ + def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + if (precision < 19) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (unscaled <= -p || unscaled >= p) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = null + this.longVal = unscaled + } + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, with a given precision and scale. 
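A short sketch of the compact representation described above, using the companion factories defined later in this file; the values are only illustrative:

// Sketch only: small unscaled values stay on the Long fast path,
// and setOrNull reports an impossible precision as null rather than throwing.
val d = Decimal(12345L, 5, 2)           // unscaled 12345 at scale 2, i.e. 123.45
d.toUnscaledLong                        // 12345L, still held in the compact Long
d.toBigDecimal                          // 123.45

new Decimal().setOrNull(12345L, 3, 2)   // null: five unscaled digits cannot fit precision 3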
+ */ + def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + require(decimalVal.precision <= precision, "Overflowed precision") + this.longVal = 0L + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + */ + def set(decimal: BigDecimal): Decimal = { + this.decimalVal = decimal + this.longVal = 0L + this._precision = decimal.precision + this._scale = decimal.scale + this + } + + /** + * Set this Decimal to the given Decimal value. + */ + def set(decimal: Decimal): Decimal = { + this.decimalVal = decimal.decimalVal + this.longVal = decimal.longVal + this._precision = decimal._precision + this._scale = decimal._scale + this + } + + def toBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + + def toUnscaledLong: Long = { + if (decimalVal.ne(null)) { + decimalVal.underlying().unscaledValue().longValue() + } else { + longVal + } + } + + override def toString: String = toBigDecimal.toString() + + @DeveloperApi + def toDebugString: String = { + if (decimalVal.ne(null)) { + s"Decimal(expanded,$decimalVal,$precision,$scale})" + } else { + s"Decimal(compact,$longVal,$precision,$scale})" + } + } + + def toDouble: Double = toBigDecimal.doubleValue() + + def toFloat: Float = toBigDecimal.floatValue() + + def toLong: Long = { + if (decimalVal.eq(null)) { + longVal / POW_10(_scale) + } else { + decimalVal.longValue() + } + } + + def toInt: Int = toLong.toInt + + def toShort: Short = toLong.toShort + + def toByte: Byte = toLong.toByte + + /** + * Update precision and scale while keeping our value the same, and return true if successful. + * + * @return true if successful, false if overflow would occur + */ + def changePrecision(precision: Int, scale: Int): Boolean = { + // First, update our longVal if we can, or transfer over to using a BigDecimal + if (decimalVal.eq(null)) { + if (scale < _scale) { + // Easier case: we just need to divide our scale down + val diff = _scale - scale + val droppedDigits = longVal % POW_10(diff) + longVal /= POW_10(diff) + if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { + longVal += (if (longVal < 0) -1L else 1L) + } + } else if (scale > _scale) { + // We might be able to multiply longVal by a power of 10 and not overflow, but if not, + // switch to using a BigDecimal + val diff = scale - _scale + val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) + if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { + // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS + longVal *= POW_10(diff) + } else { + // Give up on using Longs; switch to BigDecimal, which we'll modify below + decimalVal = BigDecimal(longVal, _scale) + } + } + // In both cases, we will check whether our precision is okay below + } + + if (decimalVal.ne(null)) { + // We get here if either we started with a BigDecimal, or we switched to one because we would + // have overflowed our Long; in either case we must rescale decimalVal to the new scale. 
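A sketch of the changePrecision contract above (HALF_UP rounding per the companion's ROUNDING_MODE, overflow reported as false rather than thrown); values are illustrative and use the companion factories defined below:

// Sketch only: narrowing the scale rounds half-up; an impossible precision returns false.
val d = Decimal("123.456")     // precision 6, scale 3
d.changePrecision(5, 2)        // true; d now holds 123.46
d.toString                     // "123.46"

val tooBig = Decimal("999.99")
tooBig.changePrecision(4, 2)   // false: 999.99 needs precision 5, the value is left unchanged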
+ val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + if (newVal.precision > precision) { + return false + } + decimalVal = newVal + } else { + // We're still using Longs, but we should check whether we match the new precision + val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + if (longVal <= -p || longVal >= p) { + // Note that we shouldn't have been able to fix this by switching to BigDecimal + return false + } + } + + _precision = precision + _scale = scale + true + } + + override def clone(): Decimal = new Decimal().set(this) + + override def compare(other: Decimal): Int = { + if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { + if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 + } else { + toBigDecimal.compare(other.toBigDecimal) + } + } + + override def equals(other: Any) = other match { + case d: Decimal => + compare(d) == 0 + case _ => + false + } + + override def hashCode(): Int = toBigDecimal.hashCode() + + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + + def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + + def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + + def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + + def % (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + + def remainder(that: Decimal): Decimal = this % that + + def unary_- : Decimal = { + if (decimalVal.ne(null)) { + Decimal(-decimalVal) + } else { + Decimal(-longVal, precision, scale) + } + } +} + +object Decimal { + private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + + /** Maximum number of decimal digits a Long can represent */ + val MAX_LONG_DIGITS = 18 + + private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + + private val BIG_DEC_ZERO = BigDecimal(0) + + def apply(value: Double): Decimal = new Decimal().set(value) + + def apply(value: Long): Decimal = new Decimal().set(value) + + def apply(value: Int): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = + new Decimal().set(value, precision, scale) + + def apply(unscaled: Long, precision: Int, scale: Int): Decimal = + new Decimal().set(unscaled, precision, scale) + + def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two + // parameters inheriting from a common trait since both traits define mkNumericOps. + // See scala.math's Numeric.scala for examples for Scala's built-in types. 
+ + /** Common methods for Decimal evidence parameters */ + trait DecimalIsConflicted extends Numeric[Decimal] { + override def plus(x: Decimal, y: Decimal): Decimal = x + y + override def times(x: Decimal, y: Decimal): Decimal = x * y + override def minus(x: Decimal, y: Decimal): Decimal = x - y + override def negate(x: Decimal): Decimal = -x + override def toDouble(x: Decimal): Double = x.toDouble + override def toFloat(x: Decimal): Float = x.toFloat + override def toInt(x: Decimal): Int = x.toInt + override def toLong(x: Decimal): Long = x.toLong + override def fromInt(x: Int): Decimal = new Decimal().set(x) + override def compare(x: Decimal, y: Decimal): Int = x.compare(y) + } + + /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ + object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + override def div(x: Decimal, y: Decimal): Decimal = x / y + } + + /** A [[scala.math.Integral]] evidence parameter for Decimals. */ + object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + override def quot(x: Decimal, y: Decimal): Decimal = x / y + override def rem(x: Decimal, y: Decimal): Decimal = x % y + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala new file mode 100644 index 0000000000000..2f2082fa3c863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ +sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { + + /** Gets a Long. */ + def getLong(key: String): Long = get(key) + + /** Gets a Double. */ + def getDouble(key: String): Double = get(key) + + /** Gets a Boolean. */ + def getBoolean(key: String): Boolean = get(key) + + /** Gets a String. */ + def getString(key: String): String = get(key) + + /** Gets a Metadata. */ + def getMetadata(key: String): Metadata = get(key) + + /** Gets a Long array. */ + def getLongArray(key: String): Array[Long] = get(key) + + /** Gets a Double array. 
*/ + def getDoubleArray(key: String): Array[Double] = get(key) + + /** Gets a Boolean array. */ + def getBooleanArray(key: String): Array[Boolean] = get(key) + + /** Gets a String array. */ + def getStringArray(key: String): Array[String] = get(key) + + /** Gets a Metadata array. */ + def getMetadataArray(key: String): Array[Metadata] = get(key) + + /** Converts to its JSON representation. */ + def json: String = compact(render(jsonValue)) + + override def toString: String = json + + override def equals(obj: Any): Boolean = { + obj match { + case that: Metadata => + if (map.keySet == that.map.keySet) { + map.keys.forall { k => + (map(k), that.map(k)) match { + case (v0: Array[_], v1: Array[_]) => + v0.view == v1.view + case (v0, v1) => + v0 == v1 + } + } + } else { + false + } + case other => + false + } + } + + override def hashCode: Int = Metadata.hash(this) + + private def get[T](key: String): T = { + map(key).asInstanceOf[T] + } + + private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) +} + +object Metadata { + + /** Returns an empty Metadata. */ + def empty: Metadata = new Metadata(Map.empty) + + /** Creates a Metadata instance from JSON. */ + def fromJson(json: String): Metadata = { + fromJObject(parse(json).asInstanceOf[JObject]) + } + + /** Creates a Metadata instance from JSON AST. */ + private[sql] def fromJObject(jObj: JObject): Metadata = { + val builder = new MetadataBuilder + jObj.obj.foreach { + case (key, JInt(value)) => + builder.putLong(key, value.toLong) + case (key, JDouble(value)) => + builder.putDouble(key, value) + case (key, JBool(value)) => + builder.putBoolean(key, value) + case (key, JString(value)) => + builder.putString(key, value) + case (key, o: JObject) => + builder.putMetadata(key, fromJObject(o)) + case (key, JArray(value)) => + if (value.isEmpty) { + // If it is an empty array, we cannot infer its element type. We put an empty Array[Long]. + builder.putLongArray(key, Array.empty) + } else { + value.head match { + case _: JInt => + builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JDouble => + builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) + case _: JBool => + builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray) + case _: JString => + builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) + case _: JObject => + builder.putMetadataArray( + key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) + case other => + throw new RuntimeException(s"Do not support array of type ${other.getClass}.") + } + } + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + builder.build() + } + + /** Converts to JSON AST. */ + private def toJsonValue(obj: Any): JValue = { + obj match { + case map: Map[_, _] => + val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) } + JObject(fields) + case arr: Array[_] => + val values = arr.toList.map(toJsonValue) + JArray(values) + case x: Long => + JInt(x) + case x: Double => + JDouble(x) + case x: Boolean => + JBool(x) + case x: String => + JString(x) + case x: Metadata => + toJsonValue(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } + + /** Computes the hash code for the types we support. */ + private def hash(obj: Any): Int = { + obj match { + case map: Map[_, _] => + map.mapValues(hash).## + case arr: Array[_] => + // Seq.empty[T] has the same hashCode regardless of T. 
+ arr.toSeq.map(hash).## + case x: Long => + x.## + case x: Double => + x.## + case x: Boolean => + x.## + case x: String => + x.## + case x: Metadata => + hash(x.map) + case other => + throw new RuntimeException(s"Do not support type ${other.getClass}.") + } + } +} + +/** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +class MetadataBuilder { + + private val map: mutable.Map[String, Any] = mutable.Map.empty + + /** Returns the immutable version of this map. Used for java interop. */ + protected def getMap = map.toMap + + /** Include the content of an existing [[Metadata]] instance. */ + def withMetadata(metadata: Metadata): this.type = { + map ++= metadata.map + this + } + + /** Puts a Long. */ + def putLong(key: String, value: Long): this.type = put(key, value) + + /** Puts a Double. */ + def putDouble(key: String, value: Double): this.type = put(key, value) + + /** Puts a Boolean. */ + def putBoolean(key: String, value: Boolean): this.type = put(key, value) + + /** Puts a String. */ + def putString(key: String, value: String): this.type = put(key, value) + + /** Puts a [[Metadata]]. */ + def putMetadata(key: String, value: Metadata): this.type = put(key, value) + + /** Puts a Long array. */ + def putLongArray(key: String, value: Array[Long]): this.type = put(key, value) + + /** Puts a Double array. */ + def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value) + + /** Puts a Boolean array. */ + def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value) + + /** Puts a String array. */ + def putStringArray(key: String, value: Array[String]): this.type = put(key, value) + + /** Puts a [[Metadata]] array. */ + def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value) + + /** Builds the [[Metadata]] instance. 
*/ + def build(): Metadata = { + new Metadata(map.toMap) + } + + private def put(key: String, value: Any): this.type = { + map.put(key, value) + this + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 488e373854bb3..ddc3d44869c98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -43,6 +44,7 @@ case class NullableData( booleanField: java.lang.Boolean, stringField: String, decimalField: BigDecimal, + dateField: Date, timestampField: Timestamp, binaryField: Array[Byte]) @@ -95,7 +97,8 @@ class ScalaReflectionSuite extends FunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType, nullable = true), + StructField("decimalField", DecimalType.Unlimited, nullable = true), + StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), nullable = true)) @@ -197,28 +200,31 @@ class ScalaReflectionSuite extends FunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318"))) + + // DateType + assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) // TimestampType - assert(TimestampType === typeOfObject(java.sql.Timestamp.valueOf("2014-07-25 10:26:00"))) + assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) // NullType assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigDecimal => DecimalType.Unlimited case _ => StringType } - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) @@ -234,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) - assert(convertToCatalyst(data) === convertedData) + val dataType = schemaFor[PrimitiveData].dataType + assert(convertToCatalyst(data, 
dataType) === convertedData) } test("convert Option[Product] to catalyst") { val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) - assert(convertToCatalyst(data) === convertedData) + val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(primitiveData)) + val dataType = schemaFor[OptionalData].dataType + val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, + Row(1, 1, 1, 1, 1, 1, true)) + assert(convertToCatalyst(data, dataType) === convertedData) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7b45738c4fc95..33a3cba3d4c0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) before { @@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) val expr0 = 'a / 2 @@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType) + assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala new file mode 100644 index 0000000000000..d5b7d2789a103 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.types._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { + val catalog = new SimpleCatalog(false) + val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false) + + val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("f", FloatType)() + ) + + val i: Expression = UnresolvedAttribute("i") + val d1: Expression = UnresolvedAttribute("d1") + val d2: Expression = UnresolvedAttribute("d2") + val u: Expression = UnresolvedAttribute("u") + val f: Expression = UnresolvedAttribute("f") + + before { + catalog.registerTable(None, "table", relation) + } + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Seq(Alias(expression, "c")()), relation) + assert(analyzer(plan).schema.fields(0).dataType === expectedType) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + checkType(Remainder(d1, d2), DecimalType(3, 2)) + checkType(Remainder(d2, d1), DecimalType(3, 2)) + checkType(Sum(d1), DecimalType(12, 1)) + checkType(Average(d1), DecimalType(6, 5)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("unlimited decimals make everything else cast up") { + for (expr <- Seq(d1, d2, i, f, u)) { + checkType(Add(expr, u), DecimalType.Unlimited) + checkType(Subtract(expr, u), DecimalType.Unlimited) + checkType(Multiply(expr, u), DecimalType.Unlimited) + checkType(Divide(expr, u), DecimalType.Unlimited) + checkType(Remainder(expr, u), DecimalType.Unlimited) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index baeb9b0cf5964..dfa2d958c0faf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) + // Casting up to unlimited-precision decimal + widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + 
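// Aside, not part of the patch: a minimal sketch of the fixed-precision arithmetic rules that
// DecimalPrecisionSuite above appears to exercise. The object and method names here are
// hypothetical; only the resulting (precision, scale) pairs come from the test expectations.
object DecimalPrecisionSketch {
  // Add/Subtract: scale = max(s1, s2); precision = max(p1 - s1, p2 - s2) + max(s1, s2) + 1
  def addResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    (math.max(p1 - s1, p2 - s2) + math.max(s1, s2) + 1, math.max(s1, s2))

  // Multiply: precision = p1 + p2 + 1; scale = s1 + s2
  def multiplyResult(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    (p1 + p2 + 1, s1 + s2)

  def main(args: Array[String]): Unit = {
    // d1 = DecimalType(2, 1), d2 = DecimalType(5, 2), as in DecimalPrecisionSuite
    val add = addResult(2, 1, 5, 2)
    assert(add._1 == 6 && add._2 == 2)   // checkType(Add(d1, d2), DecimalType(6, 2))
    val mul = multiplyResult(2, 1, 5, 2)
    assert(mul._1 == 8 && mul._2 == 3)   // checkType(Multiply(d1, d2), DecimalType(8, 3))
  }
}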
widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) + + // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) + widenTest(DecimalType(2, 1), DecimalType(3, 2), None) + widenTest(DecimalType(2, 1), DoubleType, None) + widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DoubleType, DecimalType(2, 1), None) + widenTest(IntegerType, DecimalType(2, 1), None) + // StringType widenTest(NullType, StringType, Some(StringType)) widenTest(StringType, StringType, Some(StringType)) @@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite { def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == - Project(Seq(Alias(transformed, "a")()), testRelation)) + Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 6dc5942023f9e..6bfa0dbd65ba7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,9 +21,10 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.scalautils.TripleEqualsSupport.Spread +import org.scalactic.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ @@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - actual.asInstanceOf[Double] shouldBe expected + actual.asInstanceOf[Double] shouldBe expected } test("IN") { @@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) } - + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -191,6 +192,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("abc" like "a%", true) checkEvaluation("abc" like "b%", false) checkEvaluation("abc" like "bc%", false) + checkEvaluation("a\nb" like "a_b", true) + checkEvaluation("ab" like "a%b", true) + checkEvaluation("a\nb" like "a%b", true) } test("LIKE Non-literal Regular Expression") { @@ -207,6 +211,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) + checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) + checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) + checkEvaluation("a\nb" like 
regEx, true, new GenericRow(Array[Any]("a%b"))) checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) } @@ -259,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite { val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast DecimalType.Unlimited, null) checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) @@ -283,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -296,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) checkEvaluation("23" cast ByteType, 23.toByte) checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) @@ -305,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) @@ -319,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast IntegerType).nullable === true) assert(("abcdef" cast ShortType).nullable === true) assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DecimalType.Unlimited).nullable === true) + assert(("abcdef" cast DecimalType(4, 2)).nullable === true) assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) @@ -332,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite { 
checkEvaluation(Literal(d1) < Literal(d2), true) } + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + + assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + + checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) + checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) + + checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) + + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(Decimal(-9.95)), 
DecimalType(1, 0)), null) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -368,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite { millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) @@ -667,11 +733,43 @@ class ExpressionEvaluationSuite extends FunSuite { val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) val d = 'a.double.at(0) - + for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null))) } + + test("Bitwise operations") { + val row = new GenericRow(Array[Any](1, 2, 3, null)) + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.int.at(3) + + checkEvaluation(BitwiseAnd(c1, c4), null, row) + checkEvaluation(BitwiseAnd(c1, c2), 0, row) + checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(BitwiseOr(c1, c4), null, row) + checkEvaluation(BitwiseOr(c1, c2), 3, row) + checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(BitwiseXor(c1, c4), null, row) + checkEvaluation(BitwiseXor(c1, c2), 3, row) + checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row) + checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(BitwiseNot(c4), null, row) + checkEvaluation(BitwiseNot(c1), -2, row) + checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row) + + checkEvaluation(c1 & c2, 0, row) + checkEvaluation(c1 | c2, 3, row) + checkEvaluation(c1 ^ c2, 3, row) + checkEvaluation(~c1, -2, row) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 7e9f47ef21df8..c4a1f899d8a13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -33,7 +33,8 @@ class PlanTest extends FunSuite { * we must normalize them to check if two different queries are identical. 
*/ protected def normalizeExprIds(plan: LogicalPlan) = { - val minId = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)).min + val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) + val minId = if (list.isEmpty) 0 else list.min plan transformAllExpressions { case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala new file mode 100644 index 0000000000000..5aa263484d5ed --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.scalatest.{PrivateMethodTester, FunSuite} + +import scala.language.postfixOps + +class DecimalSuite extends FunSuite with PrivateMethodTester { + test("creating decimals") { + /** Check that a Decimal has the given string representation, precision and scale */ + def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + + checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) + checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) + checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) + checkDecimal(Decimal("10.030"), "10.030", 5, 3) + checkDecimal(Decimal(10.03), "10.03", 4, 2) + checkDecimal(Decimal(17L), "17", 20, 0) + checkDecimal(Decimal(17), "17", 10, 0) + checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1) + checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) + checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) + checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) + checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) + intercept[IllegalArgumentException](Decimal(170L, 2, 1)) + intercept[IllegalArgumentException](Decimal(170L, 2, 0)) + intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + } + + test("double and long values") { + /** Check that a Decimal converts to the given double and long values */ + def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { + assert(d.toDouble === doubleValue) + assert(d.toLong === longValue) + } + + checkValues(new Decimal(), 0.0, 0L) + 
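+    // toLong truncates toward zero (e.g. 1.7 below maps to 1L), while constructing with an
+    // explicit lower scale rounds HALF_UP (-9.95 at scale 1 becomes -10.0).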
checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L) + checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L) + checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L) + checkValues(Decimal(10.03), 10.03, 10L) + checkValues(Decimal(17L), 17.0, 17L) + checkValues(Decimal(17), 17.0, 17L) + checkValues(Decimal(17L, 2, 1), 1.7, 1L) + checkValues(Decimal(170L, 4, 2), 1.7, 1L) + checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong) + checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong) + checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong) + checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) + checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) + checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + } + + // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs + private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) + + /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ + private def checkCompact(d: Decimal, expected: Boolean): Unit = { + val isCompact = d.invokePrivate(decimalVal()).eq(null) + assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") + } + + test("small decimals represented as unscaled long") { + checkCompact(new Decimal(), true) + checkCompact(Decimal(BigDecimal(10.03)), false) + checkCompact(Decimal(BigDecimal(1e20)), false) + checkCompact(Decimal(17L), true) + checkCompact(Decimal(17), true) + checkCompact(Decimal(17L, 2, 1), true) + checkCompact(Decimal(170L, 4, 2), true) + checkCompact(Decimal(17L, 24, 1), true) + checkCompact(Decimal(1e16.toLong), true) + checkCompact(Decimal(1e17.toLong), true) + checkCompact(Decimal(1e18.toLong - 1), true) + checkCompact(Decimal(- 1e18.toLong + 1), true) + checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) + checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) + checkCompact(Decimal(1e18.toLong), false) + checkCompact(Decimal(-1e18.toLong), false) + checkCompact(Decimal(1e18.toLong, 30, 10), false) + checkCompact(Decimal(-1e18.toLong, 30, 10), false) + checkCompact(Decimal(Long.MaxValue), false) + checkCompact(Decimal(Long.MinValue), false) + } + + test("hash code") { + assert(Decimal(123).hashCode() === (123).##) + assert(Decimal(-123).hashCode() === (-123).##) + assert(Decimal(123.312).hashCode() === (123.312).##) + assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##) + assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##) + assert(Decimal(BigDecimal(123)).hashCode() === (123).##) + + val reallyBig = BigDecimal("123182312312313232112312312123.1231231231") + assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode) + } + + test("equals") { + // The decimals on the left are stored compactly, while the ones on the right aren't + checkCompact(Decimal(123), true) + checkCompact(Decimal(BigDecimal(123)), false) + checkCompact(Decimal("123"), false) + assert(Decimal(123) === Decimal(BigDecimal(123))) + assert(Decimal(123) === Decimal(BigDecimal("123.00"))) + assert(Decimal(-123) === Decimal(BigDecimal(-123))) + assert(Decimal(-123) === Decimal(BigDecimal("-123.00"))) + } + + test("isZero") { + assert(Decimal(0).isZero) + assert(Decimal(0, 4, 2).isZero) + assert(Decimal("0").isZero) + assert(Decimal("0.000").isZero) + assert(!Decimal(1).isZero) + assert(!Decimal(1, 4, 2).isZero) + assert(!Decimal("1").isZero) + 
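// Aside, not part of the patch: the Fractional evidence for Decimal introduced earlier in this
// patch (DecimalIsFractional) lets Decimal plug into generic numeric code. A minimal,
// self-contained sketch; the evidence value named in the usage comment is assumed to be
// importable from wherever the patch defines it.
def mean[T](xs: Seq[T])(implicit frac: scala.math.Fractional[T]): T =
  frac.div(xs.reduce(frac.plus), frac.fromInt(xs.size))
// e.g. mean(Seq(Decimal(2), Decimal(4)))(DecimalIsFractional) would give Decimal(3)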
assert(!Decimal("0.001").isZero) + } + + test("arithmetic") { + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) * Decimal(-100) === Decimal(-10000)) + assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26)) + assert(Decimal(100) / Decimal(-100) === Decimal(-1)) + assert(Decimal(100) / Decimal(0) === null) + assert(Decimal(100) % Decimal(-100) === Decimal(0)) + assert(Decimal(100) % Decimal(3) === Decimal(1)) + assert(Decimal(-100) % Decimal(3) === Decimal(-1)) + assert(Decimal(100) % Decimal(0) === null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala new file mode 100644 index 0000000000000..0063d31666c85 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.json4s.jackson.JsonMethods.parse +import org.scalatest.FunSuite + +class MetadataSuite extends FunSuite { + + val baseMetadata = new MetadataBuilder() + .putString("purpose", "ml") + .putBoolean("isBase", true) + .build() + + val summary = new MetadataBuilder() + .putLong("numFeatures", 10L) + .build() + + val age = new MetadataBuilder() + .putString("name", "age") + .putLong("index", 1L) + .putBoolean("categorical", false) + .putDouble("average", 45.0) + .build() + + val gender = new MetadataBuilder() + .putString("name", "gender") + .putLong("index", 5) + .putBoolean("categorical", true) + .putStringArray("categories", Array("male", "female")) + .build() + + val metadata = new MetadataBuilder() + .withMetadata(baseMetadata) + .putBoolean("isBase", false) // overwrite an existing key + .putMetadata("summary", summary) + .putLongArray("long[]", Array(0L, 1L)) + .putDoubleArray("double[]", Array(3.0, 4.0)) + .putBooleanArray("boolean[]", Array(true, false)) + .putMetadataArray("features", Array(age, gender)) + .build() + + test("metadata builder and getters") { + assert(age.getLong("index") === 1L) + assert(age.getDouble("average") === 45.0) + assert(age.getBoolean("categorical") === false) + assert(age.getString("name") === "age") + assert(metadata.getString("purpose") === "ml") + assert(metadata.getBoolean("isBase") === false) + assert(metadata.getMetadata("summary") === summary) + assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) + assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) + assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) + assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) + 
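+    // Array-valued entries are compared via toSeq so the assertions check elements rather
+    // than array object identity.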
assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) + } + + test("metadata json conversion") { + val json = metadata.json + withClue("toJson must produce a valid JSON string") { + parse(json) + } + val parsed = Metadata.fromJson(json) + assert(parsed === metadata) + assert(parsed.## === metadata.##) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37e88d72b9172..c38354039d686 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -17,9 +17,7 @@ package org.apache.spark.sql.api.java; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; /** * The base type of all Spark SQL data types. @@ -54,11 +52,6 @@ public abstract class DataType { */ public static final TimestampType TimestampType = new TimestampType(); - /** - * Gets the DecimalType object. - */ - public static final DecimalType DecimalType = new DecimalType(); - /** * Gets the DoubleType object. */ @@ -151,15 +144,31 @@ public static MapType createMapType( * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and * whether values of this field can be null values ({@code nullable}). */ - public static StructField createStructField(String name, DataType dataType, boolean nullable) { + public static StructField createStructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { if (name == null) { throw new IllegalArgumentException("name should not be null."); } if (dataType == null) { throw new IllegalArgumentException("dataType should not be null."); } + if (metadata == null) { + throw new IllegalArgumentException("metadata should not be null."); + } - return new StructField(name, dataType, nullable); + return new StructField(name, dataType, nullable, metadata); + } + + /** + * Creates a StructField with empty metadata. + * + * @see #createStructField(String, DataType, boolean, Metadata) + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + return createStructField(name, dataType, nullable, (new MetadataBuilder()).build()); } /** @@ -191,5 +200,4 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } - } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index bc54c078d7a4e..60752451ecfc7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -19,9 +19,61 @@ /** * The data type representing java.math.BigDecimal values. - * - * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. 
*/ public class DecimalType extends DataType { - protected DecimalType() {} + private boolean hasPrecisionInfo; + private int precision; + private int scale; + + public DecimalType(int precision, int scale) { + this.hasPrecisionInfo = true; + this.precision = precision; + this.scale = scale; + } + + public DecimalType() { + this.hasPrecisionInfo = false; + this.precision = -1; + this.scale = -1; + } + + public boolean isUnlimited() { + return !hasPrecisionInfo; + } + + public boolean isFixed() { + return hasPrecisionInfo; + } + + /** Return the precision, or -1 if no precision is set */ + public int getPrecision() { + return precision; + } + + /** Return the scale, or -1 if no precision is set */ + public int getScale() { + return scale; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DecimalType that = (DecimalType) o; + + if (hasPrecisionInfo != that.hasPrecisionInfo) return false; + if (precision != that.precision) return false; + if (scale != that.scale) return false; + + return true; + } + + @Override + public int hashCode() { + int result = (hasPrecisionInfo ? 1 : 0); + result = 31 * result + precision; + result = 31 * result + scale; + return result; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java new file mode 100644 index 0000000000000..0f819fb01a76a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use [[MetadataBuilder]]. + */ +class Metadata extends org.apache.spark.sql.catalyst.util.Metadata { + Metadata(scala.collection.immutable.Map map) { + super(map); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java new file mode 100644 index 0000000000000..6e6b12f0722c5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +public class MetadataBuilder extends org.apache.spark.sql.catalyst.util.MetadataBuilder { + @Override + public Metadata build() { + return new Metadata(getMap()); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java index b48e2a2c5f953..7c60d492bcdf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java; +import java.util.Map; + /** * A StructField object represents a field in a StructType object. * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, @@ -24,20 +26,27 @@ * The field of {@code dataType} specifies the data type of a StructField. * The field of {@code nullable} specifies if values of a StructField can contain {@code null} * values. + * The field of {@code metadata} provides extra information of the StructField. * * To create a {@link StructField}, - * {@link DataType#createStructField(String, DataType, boolean)} + * {@link DataType#createStructField(String, DataType, boolean, Metadata)} * should be used. */ public class StructField { private String name; private DataType dataType; private boolean nullable; + private Metadata metadata; - protected StructField(String name, DataType dataType, boolean nullable) { + protected StructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { this.name = name; this.dataType = dataType; this.nullable = nullable; + this.metadata = metadata; } public String getName() { @@ -52,6 +61,10 @@ public boolean isNullable() { return nullable; } + public Metadata getMetadata() { + return metadata; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -62,6 +75,7 @@ public boolean equals(Object o) { if (nullable != that.nullable) return false; if (!dataType.equals(that.dataType)) return false; if (!name.equals(that.name)) return false; + if (!metadata.equals(that.metadata)) return false; return true; } @@ -71,6 +85,7 @@ public int hashCode() { int result = name.hashCode(); result = 31 * result + dataType.hashCode(); result = 31 * result + (nullable ? 1 : 0); + result = 31 * result + metadata.hashCode(); return result; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java new file mode 100644 index 0000000000000..b751847b464fd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * ::DeveloperApi:: + * The data type representing User-Defined Types (UDTs). + * UDTs may use any other DataType for an underlying representation. + */ +@DeveloperApi +public abstract class UserDefinedType extends DataType implements Serializable { + + protected UserDefinedType() { } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedType that = (UserDefinedType) o; + return this.sqlType().equals(that.sqlType()); + } + + /** Underlying storage type for this UDT */ + public abstract DataType sqlType(); + + /** Convert the user type to a SQL datum */ + public abstract Object serialize(Object obj); + + /** Convert a SQL datum to the user type */ + public abstract UserType deserialize(Object datum); + + /** Class object for the UserType */ + public abstract Class userClass(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3ced11a5e6c11..2e7abac1f1bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -103,6 +103,19 @@ private[sql] trait CacheManager { cachedData.remove(dataIndex) } + /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ + private[sql] def tryUncacheQuery( + query: SchemaRDD, + blocking: Boolean = true): Boolean = writeLock { + val planToCache = query.queryExecution.analyzed + val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) + val found = dataIndex >= 0 + if (found) { + cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData.remove(dataIndex) + } + found + } /** Optionally returns cached data for the given SchemaRDD */ private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 07e6e2eccddf4..279495aa64755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -79,13 +79,13 @@ private[sql] trait SQLConf { private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. 
*/ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "snappy") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -106,10 +106,10 @@ private[sql] trait SQLConf { * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. * - * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. + * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c4f4ef01d78df..84eaf401f240c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,10 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types.UserDefinedType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{DataSourceStrategy, BaseRelation, DDLParser, LogicalRelation} /** * :: AlphaComponent :: @@ -69,13 +70,19 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer + @transient + protected[sql] val ddlParser = new DDLParser + @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser new catalyst.SparkSQLParser(fallback(_)) } - protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = { + ddlParser(sql).getOrElse(sqlParser(sql)) + } + protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -101,8 +108,14 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, - LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) + val attributeSeq = ScalaReflection.attributesFor[A] + val schema = 
StructType.fromAttributes(attributeSeq) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + } + + implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { + logicalPlanToSparkQuery(LogicalRelation(baseRelation)) } /** @@ -266,6 +279,19 @@ class SQLContext(@transient val sparkContext: SparkContext) catalog.registerTable(None, tableName, rdd.queryExecution.logical) } + /** + * Drops the temporary table with the given table name in the catalog. If the table has been + * cached/persisted before, it's also unpersisted. + * + * @param tableName the name of the table to be unregistered. + * + * @group userf + */ + def dropTempTable(tableName: String): Unit = { + tryUncacheQuery(table(tableName)) + catalog.unregisterTable(None, tableName) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. @@ -284,6 +310,14 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + /** + * :: DeveloperApi :: + * Allows extra strategies to be injected into the query planner at runtime. Note this API + * should be consider experimental and is not intended to be stable across releases. + */ + @DeveloperApi + var extraStrategies: Seq[Strategy] = Nil + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -294,7 +328,9 @@ class SQLContext(@transient val sparkContext: SparkContext) def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = + extraStrategies ++ ( CommandStrategy(self) :: + DataSourceStrategy :: TakeOrdered :: HashAggregation :: LeftSemiJoin :: @@ -303,7 +339,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ParquetOperations :: BasicOperators :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil + BroadcastNestedLoopJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using @@ -438,59 +474,23 @@ class SQLContext(@transient val sparkContext: SparkContext) private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { - import scala.collection.JavaConversions._ def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true case ShortType => true case FloatType => true + case DateType => true case TimestampType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true + case udt: UserDefinedType[_] => needsConversion(udt.sqlType) case other => false } - // Converts value to the type specified by the data type. - // Because Python does not have data types for TimestampType, FloatType, ShortType, and - // ByteType, we need to explicitly convert values in columns of these data types to the desired - // JVM data types. 
- def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match { - // TODO: We should check nullable - case (null, _) => null - - case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => convert(e, elementType)}: Seq[Any] - - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] - - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { - case (key, value) => (convert(key, keyType), convert(value, valueType)) - }.toMap - - case (c, StructType(fields)) if c.getClass.isArray => - new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { - case (e, f) => convert(e, f.dataType) - }): Row - - case (c: java.util.Calendar, TimestampType) => - new java.sql.Timestamp(c.getTime().getTime()) - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - case (c: Long, IntegerType) => c.toInt - case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString - - case (c, _) => c - } - val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { rdd.map(m => m.zip(schema.fields).map { - case (value, field) => convert(value, field.dataType) + case (value, field) => EvaluatePython.fromJava(value, field.dataType) }) } else { rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 948122d42f0e1..fbec2f9f4b2c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,25 +17,26 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList} +import java.util.{List => JList} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.api.python.SerDeUtil import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.execution.LogicalRDD -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.storage.StorageLevel /** * :: AlphaComponent :: @@ -113,18 +114,22 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(_.copy()) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions - override protected def getDependencies: Seq[Dependency[_]] = + override protected def getDependencies: Seq[Dependency[_]] = { + schema // Force reification of the schema so it is available on executors. 
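+    // The SchemaRDD depends one-to-one on the RDD produced by executing its query plan.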
+ List(new OneToOneDependency(queryExecution.toRdd)) + } - /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - def schema: StructType = queryExecution.analyzed.schema + /** + * Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + lazy val schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL @@ -377,49 +382,13 @@ class SchemaRDD( */ def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) - /** - * Helper for converting a Row to a simple Array suitable for pyspark serialization. - */ - private def rowToJArray(row: Row, structType: StructType): Array[Any] = { - import scala.collection.Map - - def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (obj: Row, struct: StructType) => rowToJArray(obj, struct) - - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava - case (list: JList[_], array: ArrayType) => - list.map(x => toJava(x, array.elementType)).asJava - case (arr, array: ArrayType) if arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) - - case (obj: Map[_, _], mt: MapType) => obj.map { - case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type - }.asJava - - // Pyrolite can handle Timestamp - case (other, _) => other - } - - val fields = structType.fields.map(field => field.dataType) - row.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray - } - /** * Converts a JavaRDD to a PythonRDD. It is used by pyspark. */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - rowToJArray(row, rowSchema) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /** @@ -427,10 +396,10 @@ class SchemaRDD( * format as javaToPython. It is used by pyspark. 
*/ private[sql] def collectToPython: JList[Array[Byte]] = { - val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) + val fieldTypes = schema.fields.map(_.dataType) val pickle = new Pickler new java.util.ArrayList(collect().map { row => - rowToJArray(row, rowSchema) + EvaluatePython.rowToArray(row, fieldTypes) }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 25ba7d88ba538..fd5f4abcbcd65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { - @transient val sqlContext: SQLContext + @transient def sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD @@ -54,7 +54,7 @@ private[sql] trait SchemaRDDLike { @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => baseLogicalPlan @@ -123,7 +123,7 @@ private[sql] trait SchemaRDDLike { */ @Experimental def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan)).toRdd + sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd /** Returns the schema as a string in the tree format. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 595b4aa36eae3..6d4c0d82ac7af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, 
_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, 
builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index f8171c3be3207..4c0869e05b029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,11 +23,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.{SQLContext, StructType => SStructType} +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -38,6 +41,10 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) + def baseRelationToSchemaRDD(baseRelation: BaseRelation): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, LogicalRelation(baseRelation)) + } + /** * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. @@ -85,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { /** * Applies a schema to an RDD of Java Beans. 
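Stepping back to the UdfRegistration overloads above: they now derive the UDF's return type with ScalaReflection.schemaFor[T] instead of schemaFor(typeTag[T]), but registration is unchanged from the caller's side. A short usage sketch, assuming a sqlContext with a temp table "people" that has a string column name:

    // The result type (IntegerType) is inferred from the function's Scala return type.
    sqlContext.registerFunction("strLen", (s: String) => s.length)
    sqlContext.sql("SELECT strLen(name) FROM people").collect()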
+ * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. */ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { - val schema = getSchema(beanClass) + val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. @@ -97,10 +107,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) iter.map { row => - new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow + new GenericRow( + extractors.zip(attributeSeq).map { case (e, attr) => + DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + }.toArray[Any] + ): ScalaRow } } - new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext)) } /** @@ -187,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName) } - /** Returns a Catalyst Schema for the given java bean class. */ + /** + * Returns a Catalyst Schema for the given java bean class. + */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. val beanInfo = Introspector.getBeanInfo(beanClass) + // Note: The ordering of elements may differ from when the schema is inferred in Scala. + // This is because beanInfo.getPropertyDescriptors gives no guarantees about + // element ordering. val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (org.apache.spark.sql.StringType, true) case c: Class[_] if c == java.lang.Short.TYPE => @@ -226,6 +247,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { (org.apache.spark.sql.FloatType, true) case c: Class[_] if c == classOf[java.lang.Boolean] => (org.apache.spark.sql.BooleanType, true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => + (org.apache.spark.sql.DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => + (org.apache.spark.sql.DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => + (org.apache.spark.sql.TimestampType, true) } AttributeReference(property.getName, dataType, nullable)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index e7faba0c7f620..1e0ccb368a276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -193,7 +193,7 @@ class JavaSchemaRDD( * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. 
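The WARNING added to applySchema(rdd, beanClass) above follows from Introspector.getBeanInfo: property descriptors do not come back in declaration order (in practice they tend to be alphabetical). A quick standalone illustration with an invented bean:

    import java.beans.Introspector
    import scala.beans.BeanProperty

    class Person {
      @BeanProperty var name: String = _
      @BeanProperty var age: Int = _
    }

    object BeanOrdering {
      def main(args: Array[String]): Unit = {
        val props = Introspector.getBeanInfo(classOf[Person])
          .getPropertyDescriptors.filterNot(_.getName == "class")
        // Typically prints "age, name" rather than the declared "name, age",
        // which is why SELECT * column order is undefined for bean-derived schemas.
        println(props.map(_.getName).mkString(", "))
      }
    }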
*/ def subtract(other: JavaSchemaRDD): JavaSchemaRDD = this.baseSchemaRDD.subtract(other.baseSchemaRDD).toJavaSchemaRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index df01411f60a05..401798e317e96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.annotation.varargs import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions @@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { } override def hashCode(): Int = row.hashCode() + + override def toString: String = row.toString } object Row { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala new file mode 100644 index 0000000000000..a7d0f4f127ecc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.{DataType => ScalaDataType} +import org.apache.spark.sql.types.util.DataTypeConversions + +/** + * Scala wrapper for a Java UserDefinedType + */ +private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) + extends ScalaUserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): Any = javaUDT.serialize(obj) + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = javaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = javaUDT.userClass() +} + +/** + * Java wrapper for a Scala UserDefinedType + */ +private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) + extends UserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object] + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = scalaUDT.userClass +} + +private[sql] object UDTWrappers { + + def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = { + udtType match { + case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT + case _ => new JavaToScalaUDTWrapper(udtType) + } + } + + def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = { + udtType match { + case t: JavaToScalaUDTWrapper[_] => t.javaUDT + case _ => new ScalaToJavaUDTWrapper(udtType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 300cef15bf8a4..c68dceef3b142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -79,8 +79,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( + columnStats: ColumnStats, columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) + extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( @@ -91,7 +92,7 @@ private[sql] abstract class NativeColumnBuilder[T <: NativeType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) @@ -112,10 +113,11 @@ private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnS private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) +private[sql] 
class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b34ab255d084a..668efe4a3b2a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, Attri import org.apache.spark.sql.catalyst.types._ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { - val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)() - val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)() - val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() + val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() + val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)() + val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)() - val schema = Seq(lowerBound, upperBound, nullCount) + val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { @@ -45,10 +47,21 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * brings significant performance penalty. */ private[sql] sealed trait ColumnStats extends Serializable { + protected var count = 0 + protected var nullCount = 0 + protected var sizeInBytes = 0L + /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int): Unit + def gatherStats(row: Row, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nullCount += 1 + // 4 bytes for null position + sizeInBytes += 4 + } + count += 1 + } /** * Column statistics represented as a single row, currently including closed lower bound, closed @@ -57,171 +70,203 @@ private[sql] sealed trait ColumnStats extends Serializable { def collectedStatistics: Row } +/** + * A no-op ColumnStats only used for testing purposes. 
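The reworked ColumnStats hierarchy above moves the shared bookkeeping (row count, null count, byte size) into the base gatherStats and widens collectedStatistics from three fields to five. A simplified standalone sketch of that contract for an Int column (illustrative names, not the real classes):

    // Mirrors collectedStatistics = Row(lowerBound, upperBound, nullCount, count, sizeInBytes)
    class SimpleIntStats {
      private var count = 0
      private var nullCount = 0
      private var sizeInBytes = 0L
      private var lower = Int.MaxValue
      private var upper = Int.MinValue

      def gather(value: Option[Int]): Unit = {
        count += 1
        value match {
          case None =>
            nullCount += 1
            sizeInBytes += 4            // 4 bytes to record the null position
          case Some(v) =>
            if (v < lower) lower = v
            if (v > upper) upper = v
            sizeInBytes += 4            // INT.defaultSize
        }
      }

      def collected: (Int, Int, Int, Int, Long) = (lower, upper, nullCount, count, sizeInBytes)
    }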
+ */ private[sql] class NoopColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = super.gatherStats(row, ordinal) + + def collectedStatistics = Row(null, null, nullCount, count, 0L) +} - override def gatherStats(row: Row, ordinal: Int): Unit = {} +private[sql] class BooleanColumnStats extends ColumnStats { + protected var upper = false + protected var lower = true - override def collectedStatistics = Row() + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getBoolean(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + } + } + + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ByteColumnStats extends ColumnStats { - var upper = Byte.MinValue - var lower = Byte.MaxValue - var nullCount = 0 + protected var upper = Byte.MinValue + protected var lower = Byte.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += BYTE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { - var upper = Short.MinValue - var lower = Short.MaxValue - var nullCount = 0 + protected var upper = Short.MinValue + protected var lower = Short.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += SHORT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { - var upper = Long.MinValue - var lower = Long.MaxValue - var nullCount = 0 + protected var upper = Long.MinValue + protected var lower = Long.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += LONG.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { - var upper = Double.MinValue - var lower = Double.MaxValue - var nullCount = 0 + protected var upper = Double.MinValue + protected var lower = Double.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += DOUBLE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { - var upper = Float.MinValue - var lower = Float.MaxValue - var nullCount = 0 + protected var upper = Float.MinValue + protected var 
lower = Float.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += FLOAT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { - var upper = Int.MinValue - var lower = Int.MaxValue - var nullCount = 0 + protected var upper = Int.MinValue + protected var lower = Int.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += INT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { - var upper: String = null - var lower: String = null - var nullCount = 0 + protected var upper: String = null + protected var lower: String = null override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += STRING.actualSize(row, ordinal) } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends ColumnStats { - var upper: Date = null - var lower: Date = null - var nullCount = 0 + protected var upper: Date = null + protected var lower: Date = null override def gatherStats(row: Row, ordinal: Int) { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Date] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += DATE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class TimestampColumnStats extends ColumnStats { - var upper: Timestamp = null - var lower: Timestamp = null - var nullCount = 0 + protected var upper: Timestamp = null + protected var lower: Timestamp = null override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += TIMESTAMP.defaultSize + } + } + + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) +} + +private[sql] class BinaryColumnStats extends ColumnStats { + override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += BINARY.actualSize(row, ordinal) + } + } + + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) +} + +private[sql] class GenericColumnStats extends ColumnStats { + override def 
gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + sizeInBytes += GENERIC.actualSize(row, ordinal) } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 22ab0e2613f21..455b415d9d959 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel @@ -45,15 +47,51 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan) - (private var _cachedColumnBuffers: RDD[CachedBatch] = null) + child: SparkPlan)( + private var _cachedColumnBuffers: RDD[CachedBatch] = null, + private var _statistics: Statistics = null) extends LogicalPlan with MultiInstanceRelation { - override lazy val statistics = - Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + private val batchStats = + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) val partitionStatistics = new PartitionStatistics(output) + private def computeSizeInBytes = { + val sizeOfRow: Expression = + BindReferences.bindReference( + output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), + partitionStatistics.schema) + + batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum + } + + // Statistics propagation contracts: + // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data + // 2. Only propagate statistics when `_statistics` is non-null + private def statisticsToBePropagated = if (_statistics == null) { + val updatedStats = statistics + if (_statistics == null) null else updatedStats + } else { + _statistics + } + + override def statistics = if (_statistics == null) { + if (batchStats.value.isEmpty) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + } else { + // Underlying columnar RDD has been materialized, required information has also been collected + // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. + _statistics = Statistics(sizeInBytes = computeSizeInBytes) + _statistics + } + } else { + // Pre-computed statistics + _statistics + } + // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. 
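The InMemoryRelation changes above collect one statistics row per cached batch through the batchStats accumulator and use it for sizeInBytes estimation; the same per-batch bounds drive the rewritten buildFilter cases in the hunk that follows, which skip batches that provably cannot match a predicate. A compressed standalone sketch of that skipping logic (Int bounds and string predicate tags instead of Catalyst expressions):

    // Per-batch statistics, as in Row(lower, upper, nullCount, count, sizeInBytes).
    case class BatchStats(lower: Int, upper: Int, nullCount: Int, count: Int)

    // Returns false only when the statistics prove no row in the batch can satisfy the predicate.
    def mightMatch(s: BatchStats, predicate: String, l: Int): Boolean = predicate match {
      case "a = l"         => s.lower <= l && l <= s.upper
      case "a < l"         => s.lower < l            // some value may still be below the literal
      case "a > l"         => l < s.upper            // some value may still be above the literal
      case "a <= l"        => s.lower <= l
      case "a >= l"        => l <= s.upper
      case "a IS NULL"     => s.nullCount > 0
      case "a IS NOT NULL" => s.count - s.nullCount > 0
      case _               => true                   // unknown predicate: keep the batch
    }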
if (_cachedColumnBuffers == null) { @@ -91,6 +129,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) } @@ -104,7 +143,8 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers) + newOutput, useCompression, batchSize, storageLevel, child)( + _cachedColumnBuffers, statisticsToBePropagated) } override def children = Seq.empty @@ -116,10 +156,14 @@ private[sql] case class InMemoryRelation( batchSize, storageLevel, child)( - _cachedColumnBuffers).asInstanceOf[this.type] + _cachedColumnBuffers, + statisticsToBePropagated).asInstanceOf[this.type] } def cachedColumnBuffers = _cachedColumnBuffers + + override protected def otherCopyArgs: Seq[AnyRef] = + Seq(_cachedColumnBuffers, statisticsToBePropagated) } private[sql] case class InMemoryColumnarTableScan( @@ -132,6 +176,8 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. val buildFilter: PartialFunction[Expression, Expression] = { @@ -144,44 +190,24 @@ private[sql] case class InMemoryColumnarTableScan( buildFilter(lhs) || buildFilter(rhs) case EqualTo(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l && l <= aStats.upperBound - + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound case EqualTo(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l && l <= aStats.upperBound - - case LessThan(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound < l - - case LessThan(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - l < aStats.upperBound - - case LessThanOrEqual(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case LessThanOrEqual(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - l <= aStats.upperBound + case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l + case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound - case GreaterThan(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - l < aStats.upperBound + case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l + case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound - case GreaterThan(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound < l + case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound + case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l - case GreaterThanOrEqual(a: AttributeReference, l: Literal) => - val aStats = 
relation.partitionStatistics.forAttribute(a) - l <= aStats.upperBound + case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound + case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case GreaterThanOrEqual(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } val partitionFilters = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2ddf513b6fc98..ed6b95dc6d9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -17,34 +17,34 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{DataType, StructType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.Schema +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.types.UserDefinedType /** * :: DeveloperApi :: */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty } else { val bufferedIterator = iterator.buffered val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - + val schemaFields = schema.fields.toArray bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + mutableRow(i) = + ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) i += 1 } @@ -53,12 +53,6 @@ object RDDConversions { } } } - - /* - def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { - LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } - */ } case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) @@ -100,7 +94,7 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ override final def newInstance(): this.type = { SparkLogicalPlan( alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd) case _ => sys.error("Multiple instance of the same relation detected.") })(sqlContext).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b3edd5020fa8c..087b0ecbb25c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -70,16 +70,29 @@ case class GeneratedAggregate( val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } val currentCount = AttributeReference("currentCount", LongType, nullable = false)() val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val result = currentCount AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case Sum(expr) => - val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() - val initialValue = Cast(Literal(0L), expr.dataType) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", resultType, nullable = false)() + val initialValue = Cast(Literal(0L), resultType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -93,10 +106,26 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() val initialCount = Literal(0L) val initialSum = Cast(Literal(0L), expr.dataType) - val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } + + val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType)) AggregateEvaluation( currentCount :: currentSum :: Nil, @@ -142,7 +171,7 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[TreeNodeRef, Expression] = + val resultMap: Map[TreeNodeRef, Expression] = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => new TreeNodeRef(agg) -> func.result }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b1a7948b66cb6..81c60e00505c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,10 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD - - import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.trees +import 
org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = execute().map(_.copy()).collect() + def executeCollect(): Array[Row] = + execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 077e6ebc5f11e..84d96e612f0dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[LongHashSet], new LongHashSetSerializer) kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) + kryo.register(classOf[Decimal]) kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e4ddb8c4f5d..cc7e0c05ffc70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), + StructType.fromAttributes(output))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -304,6 +305,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case r: RunnableCommand => ExecutedCommand(r) :: Nil case logical.SetCommand(kv) => Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 977f3c9f32096..1b8ba3ace2a82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) override def output = child.output override def outputPartitioning = SinglePartition - val ordering = new RowOrdering(sortOrder, child.output) + val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + .map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 5859eba408ee1..f23b9c48cfb40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,10 +21,12 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.sql.{SQLConf, SQLContext} +// TODO: DELETE ME... trait Command { this: SparkPlan => @@ -44,6 +46,35 @@ trait Command { override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } +// TODO: Replace command with runnable command. +trait RunnableCommand extends logical.Command { + self: Product => + + def output: Seq[Attribute] + def run(sqlContext: SQLContext): Seq[Row] +} + +case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { + /** + * A concrete command should override this lazy field to wrap up any side effects caused by the + * command or any other computation that should be evaluated exactly once. The value of this field + * can be used as the contents of the corresponding RDD generated from the physical plan of this + * command. + * + * The `execute()` method of all the physical command classes should reference `sideEffectResult` + * so that the command can be executed eagerly right after the command query is created. 
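The RunnableCommand / ExecutedCommand pair introduced above lets a command be expressed as a logical node plus a run method, planned generically by the new CommandStrategy case. A hypothetical command following that contract (only the trait and its two members come from this patch; the command itself and its conf lookup are invented for illustration):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
    import org.apache.spark.sql.catalyst.types.StringType
    import org.apache.spark.sql.execution.RunnableCommand

    // Hypothetical: echoes a single configuration value.
    case class ShowConfCommand(key: String) extends RunnableCommand {
      override def output: Seq[Attribute] =
        Seq(AttributeReference("value", StringType, nullable = true)())

      // Wrapped by ExecutedCommand.sideEffectResult, so it runs exactly once.
      override def run(sqlContext: SQLContext): Seq[Row] =
        Seq(Row(sqlContext.getConf(key, "<undefined>")))
    }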
+ */ + protected[sql] lazy val sideEffectResult: Seq[Row] = cmd.run(sqlContext) + + override def output = cmd.output + + override def children = Nil + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) +} + /** * :: DeveloperApi :: */ @@ -53,50 +84,35 @@ case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribut extends LeafNode with Command with Logging { override protected lazy val sideEffectResult: Seq[Row] = kv match { - // Set value for the key. - case Some((key, Some(value))) => - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + // Configures the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) - } else { - context.setConf(key, value) - Seq(Row(s"$key=$value")) - } - - // Query the value bound to the key. + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + + // Configures a single property. + case Some((key, Some(value))) => + context.setConf(key, value) + Seq(Row(s"$key=$value")) + + // Queries all key-value pairs that are set in the SQLConf of the context. Notice that different + // from Hive, here "SET -v" is an alias of "SET". (In Hive, "SET" returns all changed properties + // while "SET -v" returns all properties.) + case Some(("-v", None)) | None => + context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + + // Queries the deprecated "mapred.reduce.tasks" property. + case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) + + // Queries a single property. case Some((key, None)) => - // TODO (lian) This is just a workaround to make the Simba ODBC driver work. - // Should remove this once we get the ODBC driver updated. - if (key == "-v") { - val hiveJars = Seq( - "hive-exec-0.12.0.jar", - "hive-service-0.12.0.jar", - "hive-common-0.12.0.jar", - "hive-hwi-0.12.0.jar", - "hive-0.12.0.jar").mkString(":") - - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq ++ Seq( - Row("system:java.class.path=" + hiveJars), - Row("system:sun.java.command=shark.SharkServer2")) - } else { - if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { - logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) - } else { - Seq(Row(s"$key=${context.getConf(key, "")}")) - } - } - - // Query all key-value pairs that are set in the SQLConf of the context. 
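The rewritten SetCommand above gives each SET form its own case; the observable behavior, assuming a sqlContext, is roughly:

    sqlContext.sql("SET spark.sql.shuffle.partitions=10")   // sets the key and echoes "key=value"
    sqlContext.sql("SET spark.sql.shuffle.partitions")       // echoes the current value of the key
    sqlContext.sql("SET")                                     // lists every key=value pair in SQLConf
    sqlContext.sql("SET -v")                                  // now just an alias for SET (unlike Hive)
    sqlContext.sql("SET mapred.reduce.tasks=24")              // deprecated key, redirected to spark.sql.shuffle.partitions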
- case _ => - context.getAllConfs.map { case (k, v) => - Row(s"$k=$v") - }.toSeq + Seq(Row(s"$key=${context.getConf(key, "")}")) } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8fd35880eedfe..5cf2a785adc7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -49,7 +49,8 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - val input: Array[Row] = buildPlan.executeCollect() + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index be729e5d244b0..a83cf5d441d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonRDD import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -31,8 +36,6 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types._ import org.apache.spark.{Accumulator, Logging => SparkLogging} -import scala.collection.JavaConversions._ - /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. */ @@ -108,6 +111,87 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { object EvaluatePython { def apply(udf: PythonUDF, child: LogicalPlan) = new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) + + /** + * Helper for converting a Scala object to a java suitable for pyspark serialization. 
+ */ + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (row: Seq[Any], struct: StructType) => + val fields = struct.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray + + case (seq: Seq[Any], array: ArrayType) => + seq.map(x => toJava(x, array.elementType)).asJava + case (list: JList[_], array: ArrayType) => + list.map(x => toJava(x, array.elementType)).asJava + case (arr, array: ArrayType) if arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + + case (obj: Map[_, _], mt: MapType) => obj.map { + case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type + }.asJava + + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal + + // Pyrolite can handle Timestamp + case (other, _) => other + } + + /** + * Convert Row into Java Array (for pickled into Python) + */ + def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + row.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray + } + + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. + def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + c.map { e => fromJava(e, elementType)}: Seq[Any] + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { + case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) + }.toMap + + case (c, StructType(fields)) if c.getClass.isArray => + new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + case (e, f) => fromJava(e, f.dataType) + }): Row + + case (c: java.util.Calendar, DateType) => + new java.sql.Date(c.getTime().getTime()) + + case (c: java.util.Calendar, TimestampType) => + new java.sql.Timestamp(c.getTime().getTime()) + + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + case (c: Long, IntegerType) => c.toInt + case (c: Double, FloatType) => c.toFloat + case (c, StringType) if !c.isInstanceOf[String] => c.toString + + case (c, _) => c + } } /** @@ -141,8 +225,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val parent = childResults.mapPartitions { iter => val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() + val fields = udf.children.map(_.dataType) iter.grouped(1000).map { inputRows => - val toBePickled = inputRows.map(currentRow(_).toArray).toArray + val toBePickled = inputRows.map { row => + EvaluatePython.rowToArray(currentRow(row), fields) + }.toArray pickle.dumps(toBePickled) } } @@ -165,10 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: }.mapPartitions { iter => val row = new GenericMutableRow(1) iter.map { result => - row(0) = udf.dataType match { - case 
StringType => result.toString - case other => result - } + row(0) = EvaluatePython.fromJava(result, udf.dataType) row: Row } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala new file mode 100644 index 0000000000000..fc70c183437f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.json + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources._ + +private[sql] class DefaultSource extends RelationProvider { + /** Returns a new base relation with the given parameters. */ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) + val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + JSONRelation(fileName, samplingRatio)(sqlContext) + } +} + +private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( + @transient val sqlContext: SQLContext) + extends TableScan { + + private def baseRDD = sqlContext.sparkContext.textFile(fileName) + + override val schema = + JsonRDD.inferSchema( + baseRDD, + samplingRatio, + sqlContext.columnNameOfCorruptRecord) + + override def buildScan() = + JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 61ee960aad9d2..0f2dcdcacf0ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.json +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper @@ -71,16 +73,18 @@ private[sql] object JsonRDD extends Logging { def makeStruct(values: Seq[Seq[String]], prefix: Seq[String]): StructType = { val (topLevel, structLike) = values.partition(_.size == 1) + val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { case ArrayType(elementType, _) => { def hasInnerStruct(t: DataType): Boolean = t match { - case s: StructType => false + case s: StructType => true case ArrayType(t1, _) => hasInnerStruct(t1) - case o => true + case o => false } - 
hasInnerStruct(elementType) + // Check if this array has inner struct. + !hasInnerStruct(elementType) } case struct: StructType => false case _ => true @@ -88,8 +92,11 @@ private[sql] object JsonRDD extends Logging { }.map { a => StructField(a.head, resolved.get(prefix ++ a).get, nullable = true) } + val topLevelFieldNameSet = topLevelFields.map(_.name) - val structFields: Seq[StructField] = structLike.groupBy(_(0)).map { + val structFields: Seq[StructField] = structLike.groupBy(_(0)).filter { + case (name, _) => !topLevelFieldNameSet.contains(name) + }.map { case (name, fields) => { val nestedFields = fields.map(_.tail) val structType = makeStruct(nestedFields, prefix :+ name) @@ -117,10 +124,7 @@ private[sql] object JsonRDD extends Logging { } }.flatMap(field => field).toSeq - StructType( - (topLevelFields ++ structFields).sortBy { - case StructField(name, _, _) => name - }) + StructType((topLevelFields ++ structFields).sortBy(_.name)) } makeStruct(resolved.keySet.toSeq, Nil) @@ -128,7 +132,7 @@ private[sql] object JsonRDD extends Logging { private[sql] def nullTypeToStringType(struct: StructType): StructType = { val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { + case StructField(fieldName, dataType, nullable, _) => { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) @@ -163,9 +167,7 @@ private[sql] object JsonRDD extends Logging { StructField(name, dataType, true) } } - StructType(newFields.toSeq.sortBy { - case StructField(name, _, _) => name - }) + StructType(newFields.toSeq.sortBy(_.name)) } case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) @@ -180,9 +182,9 @@ private[sql] object JsonRDD extends Logging { ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigDecimal => DecimalType.Unlimited // Unexpected data type. case _ => StringType } @@ -242,14 +244,14 @@ private[sql] object JsonRDD extends Logging { def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. 
- v.asInstanceOf[Seq[Map[String, Any]]].flatMap { + v.asInstanceOf[Seq[Map[String, Any]]].flatMap(Option(_)).flatMap { element => allKeysWithValueTypes(element) }.map { case (k, t) => (s"$key.$k", t) } } case ArrayType(t1, containsNull) => - v.asInstanceOf[Seq[Any]].flatMap { + v.asInstanceOf[Seq[Any]].flatMap(Option(_)).flatMap { element => buildKeyPathForInnerStructs(element, t1) } case other => Nil @@ -324,13 +326,13 @@ private[sql] object JsonRDD extends Logging { } } - private def toDecimal(value: Any): BigDecimal = { + private def toDecimal(value: Any): Decimal = { value match { - case value: java.lang.Integer => BigDecimal(value) - case value: java.lang.Long => BigDecimal(value) - case value: java.math.BigInteger => BigDecimal(value) - case value: java.lang.Double => BigDecimal(value) - case value: java.math.BigDecimal => BigDecimal(value) + case value: java.lang.Integer => Decimal(value) + case value: java.lang.Long => Decimal(value) + case value: java.math.BigInteger => Decimal(BigDecimal(value)) + case value: java.lang.Double => Decimal(value) + case value: java.math.BigDecimal => Decimal(BigDecimal(value)) } } @@ -357,7 +359,8 @@ private[sql] object JsonRDD extends Logging { case (key, value) => if (count > 0) builder.append(",") count += 1 - builder.append(s"""\"${key}\":${toString(value)}""") + val stringValue = if (value.isInstanceOf[String]) s"""\"$value\"""" else toString(value) + builder.append(s"""\"${key}\":${stringValue}""") } builder.append("}") @@ -372,13 +375,20 @@ private[sql] object JsonRDD extends Logging { } } + private def toDate(value: Any): Date = { + value match { + // only support string as date + case value: java.lang.String => Date.valueOf(value) + } + } + private def toTimestamp(value: Any): Timestamp = { value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => Timestamp.valueOf(value) - } - } + case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) + case value: java.lang.Long => new Timestamp(value) + case value: java.lang.String => Timestamp.valueOf(value) + } + } private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ if (value == null) { @@ -389,13 +399,14 @@ private[sql] object JsonRDD extends Logging { case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) case DoubleType => toDouble(value) - case DecimalType => toDecimal(value) + case DecimalType() => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) + case DateType => toDate(value) case TimestampType => toTimestamp(value) } } @@ -405,7 +416,7 @@ private[sql] object JsonRDD extends Logging { // TODO: Reuse the row instead of creating a new one for every record. 
val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _), i) => + case (StructField(name, dataType, _, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( enforceCorrectType(_, dataType)).orNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e98d151286818..51dad54f1a3f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.execution.SparkPlan /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -125,6 +126,9 @@ package object sql { @DeveloperApi type DataType = catalyst.types.DataType + @DeveloperApi + val DataType = catalyst.types.DataType + /** * :: DeveloperApi :: * @@ -180,6 +184,20 @@ package object sql { * * The data type representing `scala.math.BigDecimal` values. * + * TODO(matei): explain precision and scale + * + * @group dataType + */ + @DeveloperApi + type DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * TODO(matei): explain precision and scale + * * @group dataType */ @DeveloperApi @@ -414,4 +432,32 @@ package object sql { */ @DeveloperApi val StructField = catalyst.types.StructField + + /** + * Converts a logical plan into zero or more SparkPlans. + */ + @DeveloperApi + type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + + /** + * :: DeveloperApi :: + * + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. User should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ + @DeveloperApi + type Metadata = catalyst.util.Metadata + + /** + * :: DeveloperApi :: + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. 
+ */ + @DeveloperApi + type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 2fc7e1cf23ab7..1bbb66aaa19a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -75,6 +77,9 @@ private[sql] object CatalystConverter { parent: CatalystConverter): Converter = { val fieldType: DataType = field.dataType fieldType match { + case udt: UserDefinedType[_] => { + createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) + } // For native JVM types we use a converter with native arrays case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) @@ -117,6 +122,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case d: DecimalType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateDecimal(fieldIndex, value, d) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) @@ -191,6 +202,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) + } + protected[parquet] def isRootConverter: Boolean = parent == null protected[parquet] def clearBuffer(): Unit @@ -201,6 +216,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * @return */ def getCurrentRecord: Row = throw new UnsupportedOperationException + + /** + * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in + * a long (i.e. 
precision <= 18) + */ + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + val precision = ctype.precisionInfo.get.precision + val scale = ctype.precisionInfo.get.scale + val bytes = value.getBytes + require(bytes.length <= 16, "Decimal field too large to read") + var unscaled = 0L + var i = 0 + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xFF) + i += 1 + } + // Make sure unscaled has the right sign, by sign-extending the first bit + val numBits = 8 * bytes.length + unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) + dest.set(unscaled, precision, scale) + } } /** @@ -222,8 +258,8 @@ private[parquet] class CatalystGroupConverter( schema, index, parent, - current=null, - buffer=new ArrayBuffer[Row]( + current = null, + buffer = new ArrayBuffer[Row]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -268,7 +304,7 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { - assert(current!=null) // there should be no empty groups + assert(current != null) // there should be no empty groups buffer.append(new GenericRow(current.toArray)) parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) } @@ -325,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override def end(): Unit = {} - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = current.setBoolean(fieldIndex, value) @@ -352,6 +388,16 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = current.setString(fieldIndex, value.toStringUsingUTF8) + + override protected[parquet] def updateDecimal( + fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + var decimal = current(fieldIndex).asInstanceOf[Decimal] + if (decimal == null) { + decimal = new Decimal + current(fieldIndex) = decimal + } + readDecimal(decimal, value, ctype) + } } /** @@ -490,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = throw new UnsupportedOperationException - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { checkGrowBuffer() buffer(elements) = value.asInstanceOf[NativeType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 7c83f1cad7d71..517a5cf0029ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,8 +21,12 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration -import parquet.filter._ -import parquet.filter.ColumnPredicates._ +import parquet.filter2.compat.FilterCompat +import parquet.filter2.compat.FilterCompat._ +import parquet.filter2.predicate.FilterPredicate +import parquet.filter2.predicate.FilterApi +import parquet.filter2.predicate.FilterApi._ +import parquet.io.api.Binary import parquet.column.ColumnReader import com.google.common.io.BaseEncoding @@ -38,67 +42,74 @@ private[sql] object ParquetFilters { // set this to false if pushdown 
should be disabled val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" - def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = { + def createRecordFilter(filterExpressions: Seq[Expression]): Filter = { val filters: Seq[CatalystFilter] = filterExpressions.collect { case (expression: Expression) if createFilter(expression).isDefined => createFilter(expression).get } - if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null + if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } - def createFilter(expression: Expression): Option[CatalystFilter] = { + def createFilter(expression: Expression): Option[CatalystFilter] ={ def createEqualityFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate) + ComparisonFilter.createBooleanFilter( + name, + literal.value.asInstanceOf[Boolean], + predicate) case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x == literal.value.asInstanceOf[Int], + FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x == literal.value.asInstanceOf[Long], + FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x == literal.value.asInstanceOf[Double], + FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x == literal.value.asInstanceOf[Float], + FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate) + ComparisonFilter.createStringFilter( + name, + literal.value.asInstanceOf[String], + predicate) } + def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( - name, - (x: Int) => x < literal.value.asInstanceOf[Int], + new ComparisonFilter( + name, + FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x < literal.value.asInstanceOf[Long], + FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x < literal.value.asInstanceOf[Double], + FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x < literal.value.asInstanceOf[Float], + FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } def createLessThanOrEqualFilter( @@ -106,24 +117,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x <= literal.value.asInstanceOf[Int], + 
FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x <= literal.value.asInstanceOf[Long], + FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x <= literal.value.asInstanceOf[Double], + FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x <= literal.value.asInstanceOf[Float], + FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } // TODO: combine these two types somehow? @@ -132,24 +143,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x > literal.value.asInstanceOf[Int], + FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x > literal.value.asInstanceOf[Long], + FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x > literal.value.asInstanceOf[Double], + FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x > literal.value.asInstanceOf[Float], + FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } def createGreaterThanOrEqualFilter( @@ -157,23 +168,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( - name, (x: Int) => x >= literal.value.asInstanceOf[Int], + new ComparisonFilter( + name, + FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x >= literal.value.asInstanceOf[Long], + FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x >= literal.value.asInstanceOf[Double], + FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x >= literal.value.asInstanceOf[Float], + FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } @@ -209,25 +221,25 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) => Some(createEqualityFilter(left.name, right, p)) - case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThan(left: 
Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) - case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThan(left: NamedExpression, right: Literal) => Some(createLessThanFilter(left.name, right, p)) - case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThanOrEqual(left: Literal, right: NamedExpression) => Some(createLessThanOrEqualFilter(right.name, left, p)) - case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThanOrEqual(left: NamedExpression, right: Literal) => Some(createLessThanOrEqualFilter(left.name, right, p)) - case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThan(left: Literal, right: NamedExpression) => Some(createGreaterThanFilter(right.name, left, p)) - case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThan(left: NamedExpression, right: Literal) => Some(createGreaterThanFilter(left.name, right, p)) - case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) => Some(createGreaterThanOrEqualFilter(right.name, left, p)) - case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) => Some(createGreaterThanOrEqualFilter(left.name, right, p)) case _ => None } @@ -300,52 +312,54 @@ private[sql] object ParquetFilters { } abstract private[parquet] class CatalystFilter( - @transient val predicate: CatalystPredicate) extends UnboundRecordFilter + @transient val predicate: CatalystPredicate) extends FilterPredicate private[parquet] case class ComparisonFilter( val columnName: String, - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient override val predicate: CatalystPredicate) extends CatalystFilter(predicate) { - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor) } } private[parquet] case class OrFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: Or) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - OrRecordFilter.or(l, r), + FilterApi.or(l, r), l, r, Or(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] case class AndFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: And) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - AndRecordFilter.and(l, r), + FilterApi.and(l, r), l, r, And(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] object ComparisonFilter { @@ -355,13 +369,7 @@ 
private[parquet] object ComparisonFilter { predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToBoolean( - new BooleanPredicateFunction { - def functionToApply(input: Boolean): Boolean = input == value - } - )), + FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) def createStringFilter( @@ -370,72 +378,6 @@ private[parquet] object ComparisonFilter { predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToString ( - new ColumnPredicates.PredicateFunction[String] { - def functionToApply(input: String): Boolean = input == value - } - )), - predicate) - - def createIntFilter( - columnName: String, - func: Int => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToInteger( - new IntegerPredicateFunction { - def functionToApply(input: Int) = func(input) - } - )), - predicate) - - def createLongFilter( - columnName: String, - func: Long => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToLong( - new LongPredicateFunction { - def functionToApply(input: Long) = func(input) - } - )), - predicate) - - def createDoubleFilter( - columnName: String, - func: Double => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToDouble( - new DoublePredicateFunction { - def functionToApply(input: Double) = func(input) - } - )), - predicate) - - def createFloatFilter( - columnName: String, - func: Float => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToFloat( - new FloatPredicateFunction { - def functionToApply(input: Float) = func(input) - } - )), + FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 5ae768293a22e..82130b5459174 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -22,7 +22,6 @@ import java.io.IOException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction - import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType @@ -30,7 +29,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} /** * Relation that consists of data stored in a Parquet columnar format. 
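
The ParquetFilters.scala changes above port predicate pushdown from the old UnboundRecordFilter/ColumnPredicates API to parquet-mr's filter2 API (FilterApi, FilterPredicate, FilterCompat). As a rough standalone sketch of how such predicates are composed, not part of the patch, here is a filter equivalent to "age > 18 AND name = 'spark'" built with the same calls createRecordFilter and ComparisonFilter now use; the column names "age" and "name" and the constant 18 are hypothetical.

import parquet.filter2.compat.FilterCompat
import parquet.filter2.predicate.FilterApi
import parquet.filter2.predicate.FilterApi.{binaryColumn, intColumn}
import parquet.io.api.Binary

object Filter2Sketch {
  // Compose (age > 18) AND (name = "spark") with the filter2 predicate builders,
  // then wrap the predicate the same way createRecordFilter does above.
  def build(): FilterCompat.Filter = {
    val ageGt18     = FilterApi.gt(intColumn("age"), java.lang.Integer.valueOf(18))
    val nameEqSpark = FilterApi.eq(binaryColumn("name"), Binary.fromString("spark"))
    FilterCompat.get(FilterApi.and(ageGt18, nameEqSpark))
  }
}

The resulting Filter is what the new createRecordFilter returns, with the patch additionally keeping the original Catalyst predicates attached via the ComparisonFilter/AndFilter/OrFilter wrappers.
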
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5c6fa78ae3895..d00860a8bb8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -38,6 +38,7 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData +import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType @@ -76,6 +77,10 @@ case class ParquetTableScan( s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { + import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import parquet.filter2.compat.FilterCompat.Filter + import parquet.filter2.predicate.FilterPredicate + val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) @@ -106,13 +111,19 @@ case class ParquetTableScan( // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. if (columnPruningPred.length > 0 && sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { - ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) + + // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering + val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred) + if (filter != null){ + val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate() + ParquetInputFormat.setFilterPredicate(conf, filterPredicate) + } } // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( @@ -362,15 +373,17 @@ private[parquet] class FilteringParquetRowInputFormat override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + + import parquet.filter2.compat.FilterCompat.NoOpFilter + import parquet.filter2.compat.FilterCompat.Filter + val readSupport: ReadSupport[Row] = new RowReadSupport() - val filterExpressions = - ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext)) - if (filterExpressions.length > 0) { - logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}") + val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) + if (!filter.isInstanceOf[NoOpFilter]) { new ParquetRecordReader[Row]( readSupport, - ParquetFilters.createRecordFilter(filterExpressions)) + filter) } else { new ParquetRecordReader[Row](readSupport) } @@ -381,7 +394,7 @@ private[parquet] class FilteringParquetRowInputFormat if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) - val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap if (statuses.isEmpty) { @@ -423,10 
+436,8 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { - import FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) - + // Use task side strategy by default + val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -435,23 +446,67 @@ private[parquet] class FilteringParquetRowInputFormat s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Uses strict type checking by default val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - // if parquet file is empty, return empty splits. - if (globalMetaData == null) { - return splits - } + if (globalMetaData == null) { + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + return splits + } + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData(), globalMetaData.getSchema())) + + if (taskSideMetaData){ + logInfo("Using Task Side Metadata Split Strategy") + return getTaskSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } else { + logInfo("Using Client Side Metadata Split Strategy") + return getClientSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } + + } + + def getClientSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + import FilteringParquetRowInputFormat.blockLocationCache + import parquet.filter2.compat.FilterCompat; + import parquet.filter2.compat.FilterCompat.Filter; + import parquet.filter2.compat.RowGroupFilter; + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + val filter: Filter = ParquetInputFormat.getFilter(configuration) + var rowGroupsDropped: Long = 0 + var totalRowGroups: Long = 0 + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved val generateSplits = - classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get + Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( + sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) for (footer <- footers) { @@ -460,29 +515,85 @@ private[parquet] class FilteringParquetRowInputFormat val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new 
Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } + totalRowGroups = totalRowGroups + blocks.size + val filteredBlocks = RowGroupFilter.filterRowGroups( + filter, + blocks, + parquetMetaData.getFileMetaData.getSchema) + rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) + + if (!filteredBlocks.isEmpty){ + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } + splits.addAll( + generateSplits.invoke( + null, + filteredBlocks, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + } + } + + if (rowGroupsDropped > 0 && totalRowGroups > 0){ + val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt + logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " + + s"($percentDropped %) !") + } + else { + logInfo("There were no row groups that could be dropped due to filter predicates") + } + splits + + } + + def getTaskSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved + val generateSplits = + Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( + sys.error( + s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) + generateSplits.setAccessible(true) + + for (footer <- footers) { + val file = footer.getFile + val fs = file.getFileSystem(configuration) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) splits.addAll( generateSplits.invoke( - null, - blocks, - blockLocations, - status, - parquetMetaData.getFileMetaData, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + null, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) } splits - } + } + } private[parquet] object FilteringParquetRowInputFormat { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index bdf02401b21be..7bc249660053a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -30,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. 
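
The ParquetTableOperations.scala changes above split input-split generation into a task-side and a client-side metadata strategy, with the client-side path calling RowGroupFilter.filterRowGroups to drop row groups that cannot match the pushed-down predicate. A minimal configuration sketch, not part of the patch and using a hypothetical column "age", shows how a job would be wired up for the client-side path; both calls appear in the patched ParquetTableScan and FilteringParquetRowInputFormat.

import org.apache.hadoop.conf.Configuration
import parquet.filter2.predicate.FilterApi
import parquet.filter2.predicate.FilterApi.intColumn
import parquet.hadoop.ParquetInputFormat

object SplitStrategySketch {
  def configure(): Configuration = {
    val conf = new Configuration()
    // Push the predicate into the job configuration, as ParquetTableScan.execute() now does.
    ParquetInputFormat.setFilterPredicate(
      conf, FilterApi.gt(intColumn("age"), java.lang.Integer.valueOf(18)))
    // true (the new default) defers row-group filtering to each task;
    // false selects getClientSideSplits, which drops non-matching row groups
    // up front while computing the input splits.
    conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, false)
    conf
  }
}
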
@@ -173,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + case t: UserDefinedType[_] => writeValue(t.sqlType, value) case t @ ArrayType(_, _) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) @@ -204,6 +206,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -283,6 +290,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } writer.endGroup() } + + // Scratch array used to write decimals as fixed-length binary + private val scratchBytes = new Array[Byte](8) + + private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { + val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) + val unscaledLong = decimal.toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + scratchBytes(i) = (unscaledLong >> shift).toByte + i += 1 + shift -= 8 + } + writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + } + } // Optimized for non-nested rows @@ -326,6 +350,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 837ea7695dbb3..c0918a40d136f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -92,6 +92,12 @@ private[sql] object ParquetTestData { required int64 mylong; required float myfloat; required double mydouble; + optional boolean myoptboolean; + optional int32 myoptint; + optional binary myoptstring (UTF8); + optional int64 myoptlong; + optional float myoptfloat; + optional double myoptdouble; } """ @@ -255,6 +261,19 @@ private[sql] object ParquetTestData { record.add(3, i.toLong) record.add(4, i.toFloat + 0.5f) record.add(5, i.toDouble + 0.5d) + if (i % 2 == 0) { + if (i % 3 == 0) { + record.add(6, true) + } else { + record.add(6, false) + } + record.add(7, i) + record.add(8, i.toString) + record.add(9, i.toLong) + record.add(10, i.toFloat + 0.5f) + record.add(11, i.toDouble + 0.5d) + } + writer.write(record) } writer.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e6389cf77a4c9..fa37d1f2ae7e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition @@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._ // Implicits import scala.collection.JavaConversions._ +/** A class representing Parquet info fields we care about, for passing back to Parquet */ +private[parquet] case class ParquetTypeInfo( + primitiveType: ParquetPrimitiveTypeName, + originalType: Option[ParquetOriginalType] = None, + decimalMetadata: Option[DecimalMetadata] = None, + length: Option[Int] = None) + private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binayAsString: Boolean): DataType = + binaryAsString: Boolean): DataType = { + val originalType = parquetType.getOriginalType + val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || - binayAsString) => StringType + if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => + // TODO: for now, our reader only supports decimals that fit in a Long + DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) case _ => sys.error( s"Unsupported parquet datatype $parquetType") } + } /** * Converts a given Parquet `Type` into the corresponding @@ -183,23 +196,40 @@ private[parquet] object ParquetTypesConverter extends Logging { * is not primitive. 
* * @param ctype The type to convert - * @return The name of the corresponding Parquet primitive type + * @return The name of the corresponding Parquet type properties */ - def fromPrimitiveDataType(ctype: DataType): - Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match { - case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)) - case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None) - case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None) - case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None) - case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None) - case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None) + def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { + case StringType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) + case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) + case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) + case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) + case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) + case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetPrimitiveTypeName.INT32, None) - case ByteType => Some(ParquetPrimitiveTypeName.INT32, None) - case LongType => Some(ParquetPrimitiveTypeName.INT64, None) + case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case DecimalType.Fixed(precision, scale) if precision <= 18 => + // TODO: for now, our writer only supports decimals that fit in a Long + Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, + Some(ParquetOriginalType.DECIMAL), + Some(new DecimalMetadata(precision, scale)), + Some(BYTES_FOR_PRECISION(precision)))) case _ => None } + /** + * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. + */ + private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => + var length = 1 + while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { + length += 1 + } + length + } + /** * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into * the corresponding Parquet `Type`. 
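
The BYTES_FOR_PRECISION table above picks the smallest FIXED_LEN_BYTE_ARRAY width whose signed range, 2^(8 * length - 1), covers the 10^precision unscaled values of a decimal; since the reader and writer in this patch only handle decimals that fit in a Long, precision is capped at 18 (8 bytes). A small self-contained sketch of the same rule, with two spot checks, is shown below; it mirrors the helper above rather than adding anything new.

object BytesForPrecisionSketch {
  // Smallest byte length whose signed range, 2^(8 * length - 1),
  // covers 10^precision unscaled decimal values.
  def bytesForPrecision(precision: Int): Int = {
    var length = 1
    while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
      length += 1
    }
    length
  }

  def main(args: Array[String]): Unit = {
    assert(bytesForPrecision(9) == 4)   // up to 10^9 unscaled values fit in a 4-byte signed value
    assert(bytesForPrecision(18) == 8)  // precision 18 is the largest that still fits in a Long
  }
}
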
@@ -247,12 +277,22 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } - val primitiveType = fromPrimitiveDataType(ctype) - primitiveType.map { - case (primitiveType, originalType) => - new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) + val typeInfo = fromPrimitiveDataType(ctype) + typeInfo.map { + case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => + val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) + for (len <- length) { + builder.length(len) + } + for (metadata <- decimalMetadata) { + builder.precision(metadata.getPrecision).scale(metadata.getScale) + } + builder.named(name) }.getOrElse { ctype match { + case udt: UserDefinedType[_] => { + fromDataType(udt.sqlType, name, nullable, inArray) + } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala new file mode 100644 index 0000000000000..9b8c6a56b94b4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan + +/** + * A Strategy for planning scans over data sources defined using the sources API. 
+ */ +private[sql] object DataSourceStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedFilteredScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, f) => t.buildScan(a, f)) :: Nil + + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: PrunedScan)) => + pruneFilterProject( + l, + projectList, + filters, + (a, _) => t.buildScan(a)) :: Nil + + case l @ LogicalRelation(t: TableScan) => + execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + + case _ => Nil + } + + protected def pruneFilterProject( + relation: LogicalRelation, + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + scanBuilder: (Array[String], Array[Filter]) => RDD[Row]) = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = filterPredicates.reduceLeftOption(And) + + val pushedFilters = selectFilters(filterPredicates.map { _ transform { + case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes. + }}).toArray + + if (projectList.map(_.toAttribute) == projectList && + projectSet.size == projectList.size && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. + val requestedColumns = + projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. + .map(relation.attributeMap) // Match original case of attributes. + .map(_.name) + .toArray + + val scan = + execution.PhysicalRDD( + projectList.map(_.toAttribute), + scanBuilder(requestedColumns, pushedFilters)) + filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) + } else { + val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq + val columnNames = requestedColumns.map(_.name).toArray + + val scan = execution.PhysicalRDD(requestedColumns, scanBuilder(columnNames, pushedFilters)) + execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + } + } + + protected def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { + case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + + case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + + case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + GreaterThanOrEqual(a.name, v) + case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + LessThanOrEqual(a.name, v) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala new file mode 100644 index 0000000000000..82a2cf8402f8f --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.sources + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.logical.{Statistics, LeafNode, LogicalPlan} + +/** + * Used to link a [[BaseRelation]] in to a logical query plan. + */ +private[sql] case class LogicalRelation(relation: BaseRelation) + extends LeafNode + with MultiInstanceRelation { + + override val output = relation.schema.toAttributes + + // Logical Relations are distinct if they have different output for the sake of transformations. + override def equals(other: Any) = other match { + case l @ LogicalRelation(otherRelation) => relation == otherRelation && output == l.output + case _ => false + } + + override def sameResult(otherPlan: LogicalPlan) = otherPlan match { + case LogicalRelation(otherRelation) => relation == otherRelation + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Allow datasources to provide statistics as well. + sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes) + ) + + /** Used to lookup original attribute capitalization */ + val attributeMap = AttributeMap(output.map(o => (o, o))) + + def newInstance() = LogicalRelation(relation).asInstanceOf[this.type] + + override def simpleString = s"Relation[${output.mkString(",")}] $relation" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala new file mode 100644 index 0000000000000..9168ca2fc6fec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser for foreign DDL commands. + */ +private[sql] class DDLParser extends StandardTokenParsers with PackratParsers with Logging { + + def apply(input: String): Option[LogicalPlan] = { + phrase(ddl)(new lexical.Scanner(input)) match { + case Success(r, x) => Some(r) + case x => + logDebug(s"Not recognized as DDL: $x") + None + } + } + + protected case class Keyword(str: String) + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected val CREATE = Keyword("CREATE") + protected val TEMPORARY = Keyword("TEMPORARY") + protected val TABLE = Keyword("TABLE") + protected val USING = Keyword("USING") + protected val OPTIONS = Keyword("OPTIONS") + + // Use reflection to find the reserved words defined in this class. + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val ddl: Parser[LogicalPlan] = createTable + + /** + * CREATE FOREIGN TEMPORARY TABLE avroTable + * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + */ + protected lazy val createTable: Parser[LogicalPlan] = + CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ provider ~ opts => + CreateTableUsing(tableName, provider, opts) + } + + protected lazy val options: Parser[Map[String, String]] = + "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } + + protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} + + protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } +} + +private[sql] case class CreateTableUsing( + tableName: String, + provider: String, + options: Map[String, String]) extends RunnableCommand { + + def run(sqlContext: SQLContext) = { + val loader = Utils.getContextOrSparkClassLoader + val clazz: Class[_] = try loader.loadClass(provider) catch { + case cnf: java.lang.ClassNotFoundException => + try loader.loadClass(provider + ".DefaultSource") catch { + case cnf: java.lang.ClassNotFoundException => + sys.error(s"Failed to load class for data source: $provider") + } + } + val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] + val relation = dataSource.createRelation(sqlContext, options) + + sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) + Seq.empty + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala new file mode 100644 index 0000000000000..e72a2aeb8f310 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +abstract class Filter + +case class EqualTo(attribute: String, value: Any) extends Filter +case class GreaterThan(attribute: String, value: Any) extends Filter +case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter +case class LessThan(attribute: String, value: Any) extends Filter +case class LessThanOrEqual(attribute: String, value: Any) extends Filter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala new file mode 100644 index 0000000000000..ac3bf9d8e1a21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -0,0 +1,86 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.sources + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, StructType} +import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} + +/** + * Implemented by objects that produce relations for a specific kind of data source. When + * Spark SQL is given a DDL operation with a USING clause specified, this interface is used to + * pass in the parameters specified by a user. + * + * Users may specify the fully qualified class name of a given data source. When that class is + * not found Spark SQL will append the class name `DefaultSource` to the path, allowing for + * less verbose invocation. For example, 'org.apache.spark.sql.json' would resolve to the + * data source 'org.apache.spark.sql.json.DefaultSource' + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait RelationProvider { + /** Returns a new base relation with the given parameters. */ + def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation +} + +/** + * Represents a collection of tuples with a known schema. 
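To make the new sources API concrete, here is a minimal end-to-end data source written against `RelationProvider` above and the `TableScan` base class declared just below. It is a sketch, not part of the patch: the package, the class names, and the `rows` option are all illustrative, and it mirrors the shape of the test relations added elsewhere in this change.

```scala
package com.example.sources  // illustrative package name

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources._

/** A toy relation that materializes `n` (id, name) rows. */
case class DummyScan(n: Int)(@transient val sqlContext: SQLContext) extends TableScan {
  override def schema: StructType =
    StructType(
      StructField("id", IntegerType, nullable = false) ::
      StructField("name", StringType, nullable = true) :: Nil)

  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(1 to n).map(i => Row(i, s"name_$i"))
}

/** Looked up from the USING clause, by full class name or as `<package>.DefaultSource`. */
class DefaultSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    DummyScan(parameters.getOrElse("rows", "10").toInt)(sqlContext)
}

// Once this is on the classpath, the DDL support added in ddl.scala above can register it:
//
//   CREATE TEMPORARY TABLE dummy
//   USING com.example.sources
//   OPTIONS (rows "20")
//
// after which `SELECT id, name FROM dummy` plans a PhysicalRDD over buildScan().
```

The `PrunedScan` and `PrunedFilteredScan` variants declared right after this point work the same way, except that `buildScan` additionally receives the pruned column names and whatever predicates `DataSourceStrategy.selectFilters` (earlier in this patch) managed to translate into the simple `Filter` case classes above; as the scaladoc below notes, Spark re-evaluates those filters anyway, so a source may treat them as best-effort.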
Classes that extend BaseRelation must + * be able to produce the schema of their data in the form of a [[StructType]] Concrete + * implementation should inherit from one of the descendant `Scan` classes, which define various + * abstract methods for execution. + * + * BaseRelations must also define a equality function that only returns true when the two + * instances will return the same data. This equality function is used when determining when + * it is safe to substitute cached results for a given relation. + */ +@DeveloperApi +abstract class BaseRelation { + def sqlContext: SQLContext + def schema: StructType +} + +/** + * A BaseRelation that can produce all of its tuples as an RDD of Row objects. + */ +@DeveloperApi +abstract class TableScan extends BaseRelation { + def buildScan(): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns before producing an RDD + * containing all of its tuples as Row objects. + */ +@DeveloperApi +abstract class PrunedScan extends BaseRelation { + def buildScan(requiredColumns: Array[String]): RDD[Row] +} + +/** + * A BaseRelation that can eliminate unneeded columns and filter using selected + * predicates before producing an RDD containing all matching tuples as Row objects. + * + * The pushed down filters are currently purely an optimization as they will all be evaluated + * again. This means it is safe to use them with methods that produce false positives such + * as filtering partitions based on a bloom filter. + */ +@DeveloperApi +abstract class PrunedFilteredScan extends BaseRelation { + def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala new file mode 100644 index 0000000000000..8393c510f4f6d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/package.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +/** + * A set of APIs for adding data sources to Spark SQL. + */ +package object sources diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala new file mode 100644 index 0000000000000..b9569e96c0312 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types._ + +/** + * An example class to demonstrate UDT in Scala, Java, and Python. + * @param x x coordinate + * @param y y coordinate + */ +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +private[sql] class ExamplePoint(val x: Double, val y: Double) + +/** + * User-defined type for [[ExamplePoint]]. + */ +private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.tests.ExamplePointUDT" + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case p: ExamplePoint => + Seq(p.x, p.y) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: Seq[_] => + val xy = values.asInstanceOf[Seq[Double]] + assert(xy.length == 2) + new ExamplePoint(xy(0), xy(1)) + case values: util.ArrayList[_] => + val xy = values.asInstanceOf[util.ArrayList[Double]].asScala + new ExamplePoint(xy(0), xy(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index e44cb08309523..3fa4a7c6481d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.types.util -import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} - import scala.collection.JavaConverters._ +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, + MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper} +import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.types.UserDefinedType + protected[sql] object DataTypeConversions { /** @@ -31,19 +36,24 @@ protected[sql] object DataTypeConversions { JDataType.createStructField( scalaStructField.name, asJavaDataType(scalaStructField.dataType), - scalaStructField.nullable) + scalaStructField.nullable, + (new JMetaDataBuilder).withMetadata(scalaStructField.metadata).build()) } /** * Returns the equivalent DataType in Java for the given DataType in Scala. 
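For orientation, roughly how the `ExamplePoint` / `ExamplePointUDT` pair above is meant to be exercised, in the spirit of the Scala and Java UDT suites added elsewhere in this patch. This is a sketch rather than patch content: it assumes it compiles inside the `org.apache.spark.sql` package (since `ExamplePoint` is `private[sql]`) and reuses the `TestSQLContext` helpers; `LabeledPoint` and `UdtRoundTrip` are illustrative names.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.ExamplePoint

// Illustrative: a case class whose field type carries the @SQLUserDefinedType
// annotation via ExamplePoint above.
case class LabeledPoint(label: Double, point: ExamplePoint)

object UdtRoundTrip {
  def main(args: Array[String]): Unit = {
    val points = sparkContext.parallelize(
      LabeledPoint(1.0, new ExamplePoint(0.1, 1.0)) ::
      LabeledPoint(0.0, new ExamplePoint(0.2, 2.0)) :: Nil).toSchemaRDD
    points.registerTempTable("points")

    // Columns of a UDT type are stored using ExamplePointUDT.sqlType (an array of
    // doubles) and deserialized back into ExamplePoint instances on the way out.
    sql("SELECT point FROM points").collect().foreach { row =>
      assert(row(0).isInstanceOf[ExamplePoint])
    }
  }
}
```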
*/ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case udtType: UserDefinedType[_] => + UDTWrappers.wrapAsJava(udtType) + case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType - case DecimalType => JDataType.DecimalType + case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale) + case DecimalType.Unlimited => new JDecimalType() case DoubleType => JDataType.DoubleType case FloatType => JDataType.FloatType case ByteType => JDataType.ByteType @@ -68,13 +78,17 @@ protected[sql] object DataTypeConversions { StructField( javaStructField.getName, asScalaDataType(javaStructField.getDataType), - javaStructField.isNullable) + javaStructField.isNullable, + javaStructField.getMetadata) } /** * Returns the equivalent DataType in Scala for the given DataType in Java. */ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => + UDTWrappers.wrapAsScala(udtType) + case stringType: org.apache.spark.sql.api.java.StringType => StringType case binaryType: org.apache.spark.sql.api.java.BinaryType => @@ -86,7 +100,11 @@ protected[sql] object DataTypeConversions { case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => - DecimalType + if (decimalType.isFixed) { + DecimalType(decimalType.getPrecision, decimalType.getScale) + } else { + DecimalType.Unlimited + } case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType case floatType: org.apache.spark.sql.api.java.FloatType => @@ -110,4 +128,18 @@ protected[sql] object DataTypeConversions { case structType: org.apache.spark.sql.api.java.StructType => StructType(structType.getFields.map(asScalaStructField)) } + + /** Converts Java objects to catalyst rows / types */ + def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type + case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) + case (d: java.math.BigDecimal, _) => BigDecimal(d) + case (other, _) => other + } + + /** Converts Java objects to catalyst rows / types */ + def convertCatalystToJava(a: Any): Any = a match { + case d: scala.math.BigDecimal => d.underlying() + case other => other + } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 33e5020bc636a..a04b8060cd658 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -118,7 +118,7 @@ public void applySchemaToJSON() { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("bigInteger", new DecimalType(), true)); fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); fields.add(DataType.createStructField("double", DataType.DoubleType, true)); fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); @@ -156,7 +156,7 @@ public void 
applySchemaToJSON() { JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = schemaRDD2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD1.registerTempTable("jsonTable2"); + schemaRDD2.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java index 52d07b5425cc3..bc5cd66482add 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.api.java; import java.math.BigDecimal; +import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; import java.util.HashMap; @@ -39,6 +40,7 @@ public class JavaRowSuite { private boolean booleanValue; private String stringValue; private byte[] binaryValue; + private Date dateValue; private Timestamp timestampValue; @Before @@ -53,6 +55,7 @@ public void setUp() { booleanValue = true; stringValue = "this is a string"; binaryValue = stringValue.getBytes(); + dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -76,6 +79,7 @@ public void constructSimpleRow() { new Boolean(booleanValue), stringValue, // StringType binaryValue, // BinaryType + dateValue, // DateType timestampValue, // TimestampType null // null ); @@ -114,9 +118,10 @@ public void constructSimpleRow() { Assert.assertEquals(stringValue, simpleRow.getString(15)); Assert.assertEquals(stringValue, simpleRow.get(15)); Assert.assertEquals(binaryValue, simpleRow.get(16)); - Assert.assertEquals(timestampValue, simpleRow.get(17)); - Assert.assertEquals(true, simpleRow.isNullAt(18)); - Assert.assertEquals(null, simpleRow.get(18)); + Assert.assertEquals(dateValue, simpleRow.get(17)); + Assert.assertEquals(timestampValue, simpleRow.get(18)); + Assert.assertEquals(true, simpleRow.isNullAt(19)); + Assert.assertEquals(null, simpleRow.get(19)); } @Test diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index d099a48a1f4b6..8396a29c61c4c 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -39,8 +39,10 @@ public void createDataTypes() { checkDataType(DataType.StringType); checkDataType(DataType.BinaryType); checkDataType(DataType.BooleanType); + checkDataType(DataType.DateType); checkDataType(DataType.TimestampType); - checkDataType(DataType.DecimalType); + checkDataType(new DecimalType()); + checkDataType(new DecimalType(10, 4)); checkDataType(DataType.DoubleType); checkDataType(DataType.FloatType); checkDataType(DataType.ByteType); @@ -58,7 +60,7 @@ public void createDataTypes() { // Simple StructType. 
List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); @@ -127,7 +129,7 @@ public void illegalArgument() { // StructType try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(null); @@ -137,7 +139,7 @@ public void illegalArgument() { } try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); DataType.createStructType(simpleFields); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java new file mode 100644 index 0000000000000..0caa8219a63e9 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.MyDenseVector; +import org.apache.spark.sql.MyLabeledPoint; + +public class JavaUserDefinedTypeSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + @Test + public void useScalaUDT() { + List points = Arrays.asList( + new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), + new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD pointsRDD = javaCtx.parallelize(points); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class); + schemaRDD.registerTempTable("points"); + + List actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect(); + List actualLabels = new LinkedList(); + for (Row r : actualLabelRows) { + actualLabels.add(r.getDouble(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualLabels.contains(lp.label())); + } + + List actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect(); + List actualFeatures = new LinkedList(); + for (Row r : actualFeatureRows) { + actualFeatures.add((MyDenseVector)r.get(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualFeatures.contains(lp.features())); + } + + List actual = javaSqlCtx.sql("SELECT label, features FROM points").collect(); + List actualPoints = + new LinkedList(); + for (Row r : actual) { + actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1))); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualPoints.contains(lp)); + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index da5a358df3b1d..765fa82776341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -27,18 +27,6 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. 
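The CachedTableSuite and JoinSuite changes below revolve around SQL-level cache management, so a short sketch of the surface they exercise may help. It is illustrative only: `CacheDemo` is not part of the patch, and it assumes the shared test tables registered by `TestData` plus the `TestSQLContext` helpers.

```scala
import org.apache.spark.sql.TestData
import org.apache.spark.sql.test.TestSQLContext._

object CacheDemo {
  def main(args: Array[String]): Unit = {
    TestData // Forces the shared test tables (testData, testData2, ...) to register.

    // Caching materializes the table as an InMemoryRelation, which is what the
    // assertCached helper (moved into QueryTest later in this patch) looks for
    // in the query plan.
    sql("CACHE TABLE testData")
    assert(isCached("testData"))

    // With testData cached, its in-memory statistics become available to the planner,
    // which is what lets the new JoinSuite test below expect a BroadcastHashJoin here.
    sql("SELECT * FROM testData JOIN testData2 ON key = a").collect()

    sql("UNCACHE TABLE testData")
    assert(!isCached("testData"))
  }
}
```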
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData - val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached - } - - assert( - cachedData.size == numCachedTables, - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } - def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan executedPlan.collect { @@ -234,4 +222,33 @@ class CachedTableSuite extends QueryTest { uncacheTable("testData") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } + + test("InMemoryRelation statistics") { + sql("CACHE TABLE testData") + table("testData").queryExecution.withCachedData.collect { + case cached: InMemoryRelation => + val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + assert(cached.statistics.sizeInBytes === actualSizeInBytes) + } + } + + test("Drops temporary table") { + testData.select('key).registerTempTable("t1") + table("t1") + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + } + + test("Drops cached temporary table") { + testData.select('key).registerTempTable("t1") + testData.select('key).registerTempTable("t2") + cacheTable("t1") + + assert(isCached("t1")) + assert(isCached("t2")) + + dropTempTable("t1") + assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + assert(!isCached("t2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 100ecb45e9e88..e9740d913cf57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType - class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -71,7 +69,7 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(DecimalType.Unlimited) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) @@ -79,8 +77,12 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(ArrayType(StringType, false)) checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() checkDataTypeJsonRepr( StructType(Seq( StructField("a", IntegerType, nullable = true), - StructField("b", ArrayType(DoubleType), nullable = false)))) + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false, metadata)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 45e58afe9d9a2..e70ad891eea36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ 
-import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.test.TestSQLContext._ class DslQuerySuite extends QueryTest { - import TestData._ + import org.apache.spark.sql.TestData._ test("table scan") { checkAnswer( @@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest { (4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) } + + test("udf") { + val foo = (a: Int, b: String) => a.toString + b + + checkAnswer( + // SELECT *, foo(key, value) FROM testData + testData.select(Star(None), foo.call('key, 'value)).limit(3), + (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 07f4d2946c1b5..8b4cf5bac0187 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -19,17 +19,13 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ class JoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. 
TestData @@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { assert(planned.size === 1) } - test("join operator selection") { - def assertJoin(sqlString: String, c: Class[_]): Any = { - val rdd = sql(sqlString) - val physical = rdd.queryExecution.sparkPlan - val operators = physical.collect { - case j: ShuffledHashJoin => j - case j: HashOuterJoin => j - case j: LeftSemiJoinHash => j - case j: BroadcastHashJoin => j - case j: LeftSemiJoinBNL => j - case j: CartesianProduct => j - case j: BroadcastNestedLoopJoin => j - } - - assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") - } + def assertJoin(sqlString: String, c: Class[_]): Any = { + val rdd = sql(sqlString) + val physical = rdd.queryExecution.sparkPlan + val operators = physical.collect { + case j: ShuffledHashJoin => j + case j: HashOuterJoin => j + case j: LeftSemiJoinHash => j + case j: BroadcastHashJoin => j + case j: LeftSemiJoinBNL => j + case j: CartesianProduct => j + case j: BroadcastNestedLoopJoin => j + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") } + } - val cases1 = Seq( - ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), - ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]), - ("SELECT * FROM testData join testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), - ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), - ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), - ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), - ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), - ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), - ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), - ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a where key=2", + test("join operator selection") { + clearCache() + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", 
classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[HashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key=2", + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[HashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), - ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) - // TODO add BroadcastNestedLoopJoin - ) - cases1.foreach { c => assertJoin(c._1, c._2) } + ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]) + // TODO add BroadcastNestedLoopJoin + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } + + test("broadcasted hash join operator selection") { + clearCache() + sql("CACHE TABLE testData") + + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) - + checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), (1, "A", null, null) :: @@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) - + checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), (1, "A", null, null) :: @@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "d") :: (5, "E", null, null) :: (6, "F", null, null) :: Nil) - + checkAnswer( upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), (1, "A", 1, "a") :: @@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) - + checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), (1, "A", null, null) :: @@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { (4, "D", 4, "D") :: (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) - + checkAnswer( left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), (1, "A", null, null) :: diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 1fd8d27b34c59..3d9f0cbf80fe7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,8 +19,29 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.columnar.InMemoryRelation class QueryTest extends PlanTest { + + /** + * Runs the plan and makes sure the answer contains all of the keywords, or the + * none of keywords are listed in the answer + * @param rdd the [[SchemaRDD]] to be executed + * @param exists true for make sure the keywords are listed in the output, otherwise + * to make sure none of the keyword are not listed in the output + * @param keywords keyword in string array + */ + def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + val outputs = rdd.collect().map(_.mkString).mkString + for (key <- keywords) { + if (exists) { + assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)") + } else { + assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") + } + } + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[SchemaRDD]] to be executed @@ -59,11 +80,31 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${convertedAnswer.size} ==" +: - prepareAnswer(convertedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + s"== Correct Answer - ${convertedAnswer.size} ==" +: + prepareAnswer(convertedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} """.stripMargin) } } + + def sqlTest(sqlString: String, expectedAnswer: Any)(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } + + /** Asserts that a given SchemaRDD will be executed using the given number of cached results. 
*/ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3959925a2e529..702714af5308d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.joins.BroadcastHashJoin -import org.apache.spark.sql.test._ -import org.scalatest.BeforeAndAfterAll -import java.util.TimeZone /* Implicits */ -import TestSQLContext._ -import TestData._ +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. @@ -697,6 +697,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("true", "false") :: Nil) } + test("metadata is propagated correctly") { + val person = sql("SELECT * FROM person") + val schema = person.schema + val docKey = "doc" + val docValue = "first name" + val metadata = new MetadataBuilder() + .putString(docKey, docValue) + .build() + val schemaWithMeta = new StructType(Seq( + schema("id"), schema("name").copy(metadata = metadata), schema("age"))) + val personWithMeta = applySchema(person, schemaWithMeta) + def validateMetadata(rdd: SchemaRDD): Unit = { + assert(rdd.schema("name").metadata.getString(docKey) == docValue) + } + personWithMeta.registerTempTable("personWithMeta") + validateMetadata(personWithMeta.select('name)) + validateMetadata(personWithMeta.select("name".attr)) + validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(sql("SELECT * FROM personWithMeta")) + validateMetadata(sql("SELECT id, name FROM personWithMeta")) + validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + } + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( @@ -738,6 +762,135 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) } + test("Test to check we can use Long.MinValue") { + checkAnswer( + sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue + ) + + checkAnswer( + sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq + ) + } + + test("Floating point number format") { + checkAnswer( + sql("SELECT 0.3"), 0.3 + ) + + checkAnswer( + sql("SELECT -0.8"), -0.8 + ) + + checkAnswer( + sql("SELECT .5"), 0.5 + ) + + checkAnswer( + sql("SELECT -.18"), -0.18 + ) + } + + test("Auto cast integer type") { + checkAnswer( + sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L + ) + + 
checkAnswer( + sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L + ) + + checkAnswer( + sql("SELECT 9223372036854775808"), BigDecimal("9223372036854775808") + ) + + checkAnswer( + sql("SELECT -9223372036854775809"), BigDecimal("-9223372036854775809") + ) + } + + test("Test to check we can apply sign to expression") { + + checkAnswer( + sql("SELECT -100"), -100 + ) + + checkAnswer( + sql("SELECT +230"), 230 + ) + + checkAnswer( + sql("SELECT -5.2"), -5.2 + ) + + checkAnswer( + sql("SELECT +6.8"), 6.8 + ) + + checkAnswer( + sql("SELECT -key FROM testData WHERE key = 2"), -2 + ) + + checkAnswer( + sql("SELECT +key FROM testData WHERE key = 3"), 3 + ) + + checkAnswer( + sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2 + ) + + checkAnswer( + sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9 + ) + + checkAnswer( + sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10 + ) + + checkAnswer( + sql("SELECT -MAX(key) FROM testData"), -100 + ) + + checkAnswer( + sql("SELECT +MAX(key) FROM testData"), 100 + ) + + checkAnswer( + sql("SELECT - (-10)"), 10 + ) + + checkAnswer( + sql("SELECT + (-key) FROM testData WHERE key = 32"), -32 + ) + + checkAnswer( + sql("SELECT - (+Max(key)) FROM testData"), -100 + ) + + checkAnswer( + sql("SELECT - - 3"), 3 + ) + + checkAnswer( + sql("SELECT - + 20"), -20 + ) + + checkAnswer( + sql("SELEcT - + 45"), -45 + ) + + checkAnswer( + sql("SELECT + + 100"), 100 + ) + + checkAnswer( + sql("SELECT - - Max(key) FROM testData"), 100 + ) + + checkAnswer( + sql("SELECT + - key FROM testData WHERE key = 33"), -33 + ) + } + test("Multiple join") { checkAnswer( sql( @@ -748,4 +901,46 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { """.stripMargin), (1 to 100).map(i => Seq(i, i, i))) } + + test("SPARK-3483 Special chars in column names") { + val data = sparkContext.parallelize(Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) + jsonRDD(data).registerTempTable("records") + sql("SELECT `key?number1` FROM records") + } + + test("SPARK-3814 Support Bitwise & operator") { + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1) + } + + test("SPARK-3814 Support Bitwise | operator") { + checkAnswer(sql("SELECT key|0 FROM testData WHERE key = 1 "), 1) + } + + test("SPARK-3814 Support Bitwise ^ operator") { + checkAnswer(sql("SELECT key^0 FROM testData WHERE key = 1 "), 1) + } + + test("SPARK-3814 Support Bitwise ~ operator") { + checkAnswer(sql("SELECT ~key FROM testData WHERE key = 1 "), -2) + } + + test("SPARK-4120 Join of multiple tables does not work in SparkSQL") { + checkAnswer( + sql( + """SELECT a.key, b.key, c.key + |FROM testData a,testData b,testData c + |where a.key = b.key and a.key = c.key + """.stripMargin), + (1 to 100).map(i => Seq(i, i, i))) + } + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), + (11 to 100).map(i => Seq(i))) + } + + test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { + checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), + (1 to 99).map(i => Seq(i))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index bfa9ea416266d..cf3a59e545905 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ @@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + assert(sql("SELECT * FROM reflectData").collect().head === + Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 10b7979df7375..ef87a230639bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -28,40 +28,40 @@ import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) object TestData { - val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) + val testData = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts: SchemaRDD = + val largeAndSmallInts = TestSQLContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil) + LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) - val testData2: SchemaRDD = + val testData2 = TestSQLContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil) + TestData2(3, 2) :: Nil).toSchemaRDD testData2.registerTempTable("testData2") case class BinaryData(a: Array[Byte], b: Int) - val binaryData: SchemaRDD = + val binaryData = TestSQLContext.sparkContext.parallelize( BinaryData("12".getBytes(), 1) :: BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil) + BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD binaryData.registerTempTable("binaryData") // TODO: There is no way to express null primitives as case classes currently... 
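The TestData.scala changes above all follow the same pattern: a case-class RDD is converted explicitly with .toSchemaRDD before being registered as a temporary table, rather than relying on the implicit conversion at the registerTempTable call site. As a minimal, standalone sketch of that flow (assuming the Spark 1.2-era SQLContext/SchemaRDD API; the object and class names here, such as TempTableSketch and TestRecord, are placeholders and not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object TempTableSketch {
      // Defined outside the method so reflection-based schema inference can see it.
      case class TestRecord(key: Int, value: String)

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("temp-table-sketch"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.createSchemaRDD   // implicit RDD[Product] => SchemaRDD

        // Convert eagerly to a SchemaRDD, mirroring the .toSchemaRDD calls added above,
        // then register it so SQL statements can refer to it by name.
        val testData = sc.parallelize((1 to 100).map(i => TestRecord(i, i.toString))).toSchemaRDD
        testData.registerTempTable("testData")

        sqlContext.sql("SELECT key, value FROM testData WHERE key <= 3").collect().foreach(println)
        sc.stop()
      }
    }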
@@ -80,7 +80,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil) + UpperCaseData(6, "F") :: Nil).toSchemaRDD upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -89,7 +89,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil) + LowerCaseData(4, "d") :: Nil).toSchemaRDD lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -99,7 +99,7 @@ object TestData { ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) arrayData.registerTempTable("arrayData") - case class MapData(data: Map[Int, String]) + case class MapData(data: scala.collection.Map[Int, String]) val mapData = TestSQLContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: @@ -166,4 +166,23 @@ object TestData { // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) withEmptyParts.registerTempTable("withEmptyParts") + + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + val person = TestSQLContext.sparkContext.parallelize( + Person(0, "mike", 30) :: + Person(1, "jim", 20) :: Nil) + person.registerTempTable("person") + val salary = TestSQLContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil) + salary.registerTempTable("salary") + + case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + val complexData = + TestSQLContext.sparkContext.parallelize( + ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) + :: Nil).toSchemaRDD + complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala new file mode 100644 index 0000000000000..1806a1dd82023 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.beans.{BeanInfo, BeanProperty} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext._ + +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } +} + +@BeanInfo +private[sql] case class MyLabeledPoint( + @BeanProperty label: Double, + @BeanProperty features: MyDenseVector) + +private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case features: MyDenseVector => + features.data.toSeq + } + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: Seq[_] => + new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + } + } + + override def userClass = classOf[MyDenseVector] +} + +class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + + + test("register user type: MyDenseVector for MyLabeledPoint") { + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 203ff847e94cc..c9012c9e47cff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.beans.BeanProperty import org.scalatest.FunSuite @@ -45,6 +47,9 @@ class AllTypesBean extends Serializable { @BeanProperty var shortField: java.lang.Short = _ @BeanProperty var byteField: java.lang.Byte = _ @BeanProperty var booleanField: java.lang.Boolean = _ + @BeanProperty var dateField: java.sql.Date = _ + @BeanProperty var timestampField: java.sql.Timestamp = _ + @BeanProperty var bigDecimalField: java.math.BigDecimal = _ } class JavaSQLSuite extends FunSuite { @@ -73,6 +78,9 @@ class JavaSQLSuite extends FunSuite { bean.setShortField(0.toShort) bean.setByteField(0.toByte) bean.setBooleanField(false) + 
bean.setDateField(java.sql.Date.valueOf("2014-10-10")) + bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0")) + bean.setBigDecimalField(new java.math.BigDecimal(0)) val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) @@ -82,10 +90,34 @@ class JavaSQLSuite extends FunSuite { javaSqlCtx.sql( """ |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField, - | booleanField + | booleanField, dateField, timestampField, bigDecimalField |FROM allTypes """.stripMargin).collect.head.row === - Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false)) + Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"), + java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), scala.math.BigDecimal(0))) + } + + test("decimal types in JavaBeans") { + val bean = new AllTypesBean + bean.setStringField("") + bean.setIntField(0) + bean.setLongField(0) + bean.setFloatField(0.0F) + bean.setDoubleField(0.0) + bean.setShortField(0.toShort) + bean.setByteField(0.toByte) + bean.setBooleanField(false) + bean.setDateField(java.sql.Date.valueOf("2014-10-10")) + bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0")) + bean.setBigDecimalField(new java.math.BigDecimal(0)) + + val rdd = javaCtx.parallelize(bean :: Nil) + val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) + schemaRDD.registerTempTable("decimalTypes") + + assert(javaSqlCtx.sql( + "select bigDecimalField + bigDecimalField from decimalTypes" + ).collect.head.row === Seq(scala.math.BigDecimal(0))) } test("all types null in JavaBeans") { @@ -98,6 +130,9 @@ class JavaSQLSuite extends FunSuite { bean.setShortField(null) bean.setByteField(null) bean.setBooleanField(null) + bean.setDateField(null) + bean.setTimestampField(null) + bean.setBigDecimalField(null) val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) @@ -107,10 +142,10 @@ class JavaSQLSuite extends FunSuite { javaSqlCtx.sql( """ |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField, - | booleanField + | booleanField, dateField, timestampField, bigDecimalField |FROM allTypes """.stripMargin).collect.head.row === - Seq.fill(8)(null)) + Seq.fill(11)(null)) } test("loads JSON datasets") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index ff1debff0f8c1..62fe59dd345d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite -import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} -import org.apache.spark.sql.{StructType => SStructType} -import DataTypeConversions._ +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField, StructType => SStructType} +import org.apache.spark.sql.types.util.DataTypeConversions._ class ScalaSideDataTypeConversionSuite extends FunSuite { @@ -38,8 +36,9 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(org.apache.spark.sql.StringType) checkDataType(org.apache.spark.sql.BinaryType) 
checkDataType(org.apache.spark.sql.BooleanType) + checkDataType(org.apache.spark.sql.DateType) checkDataType(org.apache.spark.sql.TimestampType) - checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DecimalType.Unlimited) checkDataType(org.apache.spark.sql.DoubleType) checkDataType(org.apache.spark.sql.FloatType) checkDataType(org.apache.spark.sql.ByteType) @@ -59,18 +58,22 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { // Simple StructType. val simpleScalaStructType = SStructType( - SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) :: SStructField("b", org.apache.spark.sql.BooleanType, true) :: SStructField("c", org.apache.spark.sql.LongType, true) :: SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) checkDataType(simpleScalaStructType) // Complex StructType. + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() val complexScalaStructType = SStructType( SStructField("simpleArray", simpleScalaArrayType, true) :: SStructField("simpleMap", simpleScalaMapType, true) :: SStructField("simpleStruct", simpleScalaStructType, true) :: - SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: + SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil) checkDataType(complexScalaStructType) // Complex ArrayType. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 6bdf741134e2f..a9f0851f8826c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite { assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 9775dd26b7773..15903d07df29a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { - import org.apache.spark.sql.TestData._ - import org.apache.spark.sql.test.TestSQLContext._ + // Make sure the tables are loaded. 
+ TestData test("simple columnar query") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("projection") { - val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val plan = executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) @@ -63,7 +64,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq) - TestSQLContext.cacheTable("repeatedData") + cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -75,7 +76,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) - TestSQLContext.cacheTable("nullableRepeatedData") + cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -87,7 +88,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - TestSQLContext.cacheTable("timestamps") + cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -99,10 +100,17 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) - TestSQLContext.cacheTable("withEmptyParts") + cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq) } + + test("SPARK-4182 Caching complex types") { + complexData.cache().count() + // Shouldn't throw + complexData.count() + complexData.unpersist() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index f53acc8c9f718..9ba3c210171bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,8 +22,6 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ -case class IntegerData(i: Int) - class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning @@ -31,8 +29,12 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) - 
rawData.registerTempTable("intData") + + val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val string = if (((key - 1) / 10) % 2 == 0) null else key.toString + TestData(key, string) + }, 5) + pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") @@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be } before { - cacheTable("intData") + cacheTable("pruningData") } after { - uncacheTable("intData") + uncacheTable("pruningData") } // Comparisons - checkBatchPruning("i = 1", Seq(1), 1, 1) - checkBatchPruning("1 = i", Seq(1), 1, 1) - checkBatchPruning("i < 12", 1 to 11, 1, 2) - checkBatchPruning("i <= 11", 1 to 11, 1, 2) - checkBatchPruning("i > 88", 89 to 100, 1, 2) - checkBatchPruning("i >= 89", 89 to 100, 1, 2) - checkBatchPruning("12 > i", 1 to 11, 1, 2) - checkBatchPruning("11 >= i", 1 to 11, 1, 2) - checkBatchPruning("88 < i", 89 to 100, 1, 2) - checkBatchPruning("89 <= i", 89 to 100, 1, 2) + checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE key >= 89", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE 12 > key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100) + + // IS NULL + checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) { + (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90) + } + + // IS NOT NULL + checkBatchPruning("SELECT key FROM pruningData WHERE value IS NOT NULL", 5, 5) { + (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100) + } // Conjunction and disjunction - checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) - checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) - checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) - checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) + checkBatchPruning("SELECT key FROM pruningData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100)) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) { + Seq(1) ++ (79 to 91) + } // With unsupported predicate - checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) + checkBatchPruning("SELECT key FROM pruningData WHERE NOT (key < 88)", 1, 2)(88 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11) + + { + val seq = (1 to 30).mkString(", ") + checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100) + } def checkBatchPruning( - filter: String, - expectedQueryResult: Seq[Int], + query: String, expectedReadPartitions: Int, - expectedReadBatches: Int): Unit = { + expectedReadBatches: Int)( + expectedQueryResult: => Seq[Int]): Unit = { - 
test(filter) { - val query = sql(s"SELECT * FROM intData WHERE $filter") + test(query) { + val schemaRdd = sql(query) assertResult(expectedQueryResult.toArray, "Wrong query result") { - query.collect().map(_.head).toArray + schemaRdd.collect().map(_.head).toArray } - val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f14ffca0e4d35..a5af71acfc79a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -76,4 +76,24 @@ class PlannerSuite extends FunSuite { setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) } + + test("InMemoryRelation statistics propagation") { + val origThreshold = autoBroadcastJoinThreshold + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") + + val a = testData.as('a) + val b = table("tiny").as('b) + val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 7bb08f1b513ce..cade244f7ac39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import java.sql.Timestamp +import java.sql.{Date, Timestamp} class JsonSuite extends QueryTest { import TestJsonData._ @@ -44,22 +44,28 @@ class JsonSuite extends QueryTest { checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) - checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + checkTypePromotion( + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) - checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + 
checkTypePromotion( + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) - + checkTypePromotion( + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(new Timestamp(intNumber.toLong), enforceCorrectType(intNumber.toLong, TimestampType)) - val strDate = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType)) + val strTime = "2014-09-30 12:34:56" + checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) + + val strDate = "2014-10-15" + checkTypePromotion(Date.valueOf(strDate), enforceCorrectType(strDate, DateType)) } test("Get compatible type") { @@ -77,7 +83,7 @@ class JsonSuite extends QueryTest { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType, DecimalType) + checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -88,7 +94,7 @@ class JsonSuite extends QueryTest { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType, StringType) + checkDataType(BooleanType, DecimalType.Unlimited, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -97,7 +103,7 @@ class JsonSuite extends QueryTest { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType, DecimalType) + checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -105,23 +111,23 @@ class JsonSuite extends QueryTest { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType, DecimalType) + checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType, DecimalType) + checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) // DoubleType - checkDataType(DecimalType, DecimalType, DecimalType) - checkDataType(DecimalType, StringType, StringType) - checkDataType(DecimalType, 
ArrayType(IntegerType), StringType) - checkDataType(DecimalType, StructType(Nil), StringType) + checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DecimalType.Unlimited, StringType, StringType) + checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -175,7 +181,7 @@ class JsonSuite extends QueryTest { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType, + DecimalType.Unlimited, StringType) } @@ -183,7 +189,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -208,12 +214,12 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType) + val jsonSchemaRDD = jsonRDD(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: @@ -226,8 +232,8 @@ class JsonSuite extends QueryTest { StructField("field2", StringType, true) :: StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) @@ -285,8 +291,8 @@ class JsonSuite extends QueryTest { // Access a struct and fields inside of it. checkAnswer( sql("select struct, struct.field1, struct.field2 from jsonTable"), - ( - Seq(true, BigDecimal("92233720368547758070")), + Row( + Row(true, BigDecimal("92233720368547758070")), true, BigDecimal("92233720368547758070")) :: Nil ) @@ -305,7 +311,7 @@ class JsonSuite extends QueryTest { } ignore("Complex field and type inferring (Ignored)") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType) + val jsonSchemaRDD = jsonRDD(complexFieldAndType1) jsonSchemaRDD.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. 
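The recurring JsonSuite edits above replace the old parameterless DecimalType with DecimalType.Unlimited, reserving the DecimalType(precision, scale) form for fixed-precision columns. A small illustrative snippet of the two forms side by side (DecimalSchemaSketch is a placeholder name, and this assumes the org.apache.spark.sql data types used elsewhere in this patch):

    import org.apache.spark.sql._

    object DecimalSchemaSketch {
      // Unbounded decimal: precision and scale are not constrained up front.
      val unlimited: DataType = DecimalType.Unlimited
      // Fixed-precision decimal: at most 10 digits, 2 of them after the decimal point.
      val fixed: DataType = DecimalType(10, 2)

      // Both forms can appear together in a schema definition.
      val schema = StructType(
        StructField("amount", fixed, nullable = true) ::
        StructField("bigNumber", unlimited, nullable = true) :: Nil)
    }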
@@ -328,7 +334,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType, true) :: + StructField("num_num_2", DecimalType.Unlimited, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -380,6 +386,12 @@ class JsonSuite extends QueryTest { 92233720368547758071.2 ) + // Number and String conflict: resolve the type as number in this query. + checkAnswer( + sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), + BigDecimal("92233720368547758061.2").toDouble + ) + // String and Boolean conflict: resolve the type as string. checkAnswer( sql("select * from jsonTable where str_bool = 'str1'"), @@ -415,13 +427,6 @@ class JsonSuite extends QueryTest { false ) - // Right now, we have a parsing error. - // Number and String conflict: resolve the type as number in this query. - checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - BigDecimal("92233720368547758061.2") - ) - // The plan of the following DSL is // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78] // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060) @@ -477,7 +482,8 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: StructField("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil), false), true) :: Nil) + StructField("field", LongType, true) :: Nil), false), true) :: + StructField("array3", ArrayType(StringType, false), true) :: Nil) assert(expectedSchema === jsonSchemaRDD.schema) @@ -486,12 +492,14 @@ class JsonSuite extends QueryTest { checkAnswer( sql("select * from jsonTable"), Seq(Seq("1", "1.1", "true", null, "[]", "{}", "[2,3,4]", - """{"field":str}"""), Seq(Seq(214748364700L), Seq(1))) :: Nil + """{"field":"str"}"""), Seq(Seq(214748364700L), Seq(1)), null) :: + Seq(null, null, Seq("""{"field":"str"}""", """{"field":1}""")) :: + Seq(null, null, Seq("1", "2", "3")) :: Nil ) // Treat an element as a number. 
checkAnswer( - sql("select array1[0] + 1 from jsonTable"), + sql("select array1[0] + 1 from jsonTable where array1 is not null"), 2 ) } @@ -519,7 +527,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonFile(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -543,13 +551,39 @@ class JsonSuite extends QueryTest { ) } + test("Loading a JSON dataset from a text file with SQL") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) + + sql( + s""" + |CREATE TEMPORARY TABLE jsonTableSQL + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + + checkAnswer( + sql("select * from jsonTableSQL"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } + test("Applying schemas") { val file = getTempFilePath("json") val path = file.toString primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -707,4 +741,35 @@ class JsonSuite extends QueryTest { TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } + + test("SPARK-4068: nulls in arrays") { + val jsonSchemaRDD = jsonRDD(nullsInArrays) + jsonSchemaRDD.registerTempTable("jsonTable") + + val schema = StructType( + StructField("field1", + ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) :: + StructField("field2", + ArrayType(ArrayType( + StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) :: + StructField("field3", + ArrayType(ArrayType( + StructType(StructField("Test", StringType, true) :: Nil), true), false), true) :: + StructField("field4", + ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) + + assert(schema === jsonSchemaRDD.schema) + + checkAnswer( + sql( + """ + |SELECT field1, field2, field3, field4 + |FROM jsonTable + """.stripMargin), + Seq(Seq(Seq(null), Seq(Seq(Seq("Test")))), null, null, null) :: + Seq(null, Seq(null, Seq(Seq(1))), null, null) :: + Seq(null, null, Seq(Seq(null), Seq(Seq("2"))), null) :: + Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index eaca9f0508a12..e5773a55875bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -32,22 +32,6 @@ object TestJsonData { "null":null }""" :: Nil) - val complexFieldAndType = - TestSQLContext.sparkContext.parallelize( - """{"struct":{"field1": true, "field2": 92233720368547758070}, - "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, - "arrayOfString":["str1", "str2"], - "arrayOfInteger":[1, 2147483647, -2147483648], - "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], - 
"arrayOfBigInteger":[922337203685477580700, -922337203685477580800], - "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], - "arrayOfBoolean":[true, false, true], - "arrayOfNull":[null, null, null, null], - "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], - "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], - "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] - }""" :: Nil) - val primitiveFieldValueTypeConflict = TestSQLContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, @@ -73,7 +57,9 @@ object TestJsonData { val arrayElementTypeConflict = TestSQLContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], - "array2": [{"field":214748364700}, {"field":1}]}""" :: Nil) + "array2": [{"field":214748364700}, {"field":1}]}""" :: + """{"array3": [{"field":"str"}, {"field":1}]}""" :: + """{"array3": [1, 2, 3]}""" :: Nil) val missingFields = TestSQLContext.sparkContext.parallelize( @@ -83,6 +69,22 @@ object TestJsonData { """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) + val complexFieldAndType1 = + TestSQLContext.sparkContext.parallelize( + """{"struct":{"field1": true, "field2": 92233720368547758070}, + "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, + "arrayOfString":["str1", "str2"], + "arrayOfInteger":[1, 2147483647, -2147483648], + "arrayOfLong":[21474836470, 9223372036854775807, -9223372036854775808], + "arrayOfBigInteger":[922337203685477580700, -922337203685477580800], + "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], + "arrayOfBoolean":[true, false, true], + "arrayOfNull":[null, null, null, null], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], + "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] + }""" :: Nil) + val complexFieldAndType2 = TestSQLContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], @@ -137,6 +139,13 @@ object TestJsonData { ]] }""" :: Nil) + val nullsInArrays = + TestSQLContext.sparkContext.parallelize( + """{"field1":[[null], [[["Test"]]]]}""" :: + """{"field2":[null, [{"Test":1}]]}""" :: + """{"field3":[[null], [{"Test":"2"}]]}""" :: + """{"field4":[[null, [1,2,3]]]}""" :: Nil) + val jsonArray = TestSQLContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 25e41ecf28e2e..08d9da27f1b11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType( case class BinaryData(binaryData: Array[Byte]) +case class NumericData(i: Int, d: Double) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. 
@@ -560,6 +562,63 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(stringResult.size === 1) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") assert(stringResult(0).getInt(1) === 100) + + val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") + assert( + query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val optResult = query7.collect() + assert(optResult.size === 20) + for(i <- 0 until 20) { + if (optResult(i)(7) != i * 2) { + fail(s"optional Int value in result row $i should be ${2*4*i}") + } + } + for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { + val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") + assert( + query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result8 = query8.collect() + assert(result8.size === 25) + assert(result8(0)(7) === 100) + assert(result8(24)(7) === 148) + val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") + assert( + query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result9 = query9.collect() + assert(result9.size === 25) + if (myval == "myoptint" || myval == "myoptlong") { + assert(result9(0)(7) === 152) + assert(result9(24)(7) === 200) + } else { + assert(result9(0)(7) === 150) + assert(result9(24)(7) === 198) + } + } + val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") + assert( + query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result10 = query10.collect() + assert(result10.size === 1) + assert(result10(0).getString(8) == "100", "stringvalue incorrect") + assert(result10(0).getInt(7) === 100) + val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") + assert( + query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result11 = query11.collect() + assert(result11.size === 7) + for(i <- 0 until 6) { + if (!result11(i).getBoolean(6)) { + fail(s"optional Boolean value in result row $i not true") + } + if (result11(i).getInt(7) != i * 6) { + fail(s"optional Int value in result row $i should be ${6*i}") + } + } } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { @@ -812,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(a.dataType === b.dataType) } } + + test("read/write fixed-length decimals") { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(precision, scale)) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(19, 10)) + 
data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType.Unlimited) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala new file mode 100644 index 0000000000000..9626252e742e5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfter + +abstract class DataSourceTest extends QueryTest with BeforeAndAfter { + // Case sensitivity is not configurable yet, but we want to test some edge cases. + // TODO: Remove when it is configurable + implicit val caseInsensisitiveContext = new SQLContext(TestSQLContext.sparkContext) { + @transient + override protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala new file mode 100644 index 0000000000000..8b2f1591d5bf3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -0,0 +1,176 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.sources + +import scala.language.existentials + +import org.apache.spark.sql._ + +class FilteredScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedFilteredScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + FiltersPushed.list = filters + + val filterFunctions = filters.collect { + case EqualTo("a", v) => (a: Int) => a == v + case LessThan("a", v: Int) => (a: Int) => a < v + case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v + case GreaterThan("a", v: Int) => (a: Int) => a > v + case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v + } + + def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) + + sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +// A hack for better error messages when filter pushdown fails. +object FiltersPushed { + var list: Seq[Filter] = Nil +} + +class FilteredScanSuite extends DataSourceTest { + + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenFiltered + |USING org.apache.spark.sql.sources.FilteredScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenFiltered", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenFiltered", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE a = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE A = 1", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT * FROM oneToTenFiltered WHERE b = 2", + Seq(1).map(i => Row(i, i * 2)).toSeq) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 
a >= 2", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) + + testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) + testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) + + def testPushDown(sqlString: String, expectedCount: Int): Unit = { + test(s"PushDown Returns $expectedCount: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawCount = rawPlan.execute().count() + + if (rawCount != expectedCount) { + fail( + s"Wrong # of results for pushed filter. Got $rawCount, Expected $expectedCount\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala new file mode 100644 index 0000000000000..fee2e22611cdc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -0,0 +1,137 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class PrunedScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends PrunedScan { + + override def schema = + StructType( + StructField("a", IntegerType, nullable = false) :: + StructField("b", IntegerType, nullable = false) :: Nil) + + override def buildScan(requiredColumns: Array[String]) = { + val rowBuilders = requiredColumns.map { + case "a" => (i: Int) => Seq(i) + case "b" => (i: Int) => Seq(i * 2) + } + + sqlContext.sparkContext.parallelize(from to to).map(i => + Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty))) + } +} + +class PrunedScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTenPruned + |USING org.apache.spark.sql.sources.PrunedScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT a, b FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i * 2)).toSeq) + + sqlTest( + "SELECT b, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2, i)).toSeq) + + sqlTest( + "SELECT a FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT a, a FROM oneToTenPruned", + (1 to 10).map(i => Row(i, i)).toSeq) + + sqlTest( + "SELECT b FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a * 2 FROM oneToTenPruned", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT A AS b FROM oneToTenPruned", + (1 to 10).map(i => Row(i)).toSeq) + + sqlTest( + "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (1 to 5).map(i => Row(i * 4, i)).toSeq) + + sqlTest( + "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b", + (2 to 10 by 2).map(i => Row(i, i)).toSeq) + + testPruning("SELECT * FROM oneToTenPruned", "a", "b") + testPruning("SELECT a, b FROM oneToTenPruned", "a", "b") + testPruning("SELECT b, a FROM oneToTenPruned", "b", "a") + testPruning("SELECT b, b FROM oneToTenPruned", "b") + testPruning("SELECT a FROM oneToTenPruned", "a") + testPruning("SELECT b FROM oneToTenPruned", "b") + + def testPruning(sqlString: String, expectedColumns: String*): Unit = { + test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { + val queryExecution = sql(sqlString).queryExecution + val rawPlan = queryExecution.executedPlan.collect { + case p: execution.PhysicalRDD => p + } match { + case Seq(p) => p + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + val rawColumns = rawPlan.output.map(_.name) + val rawOutput = rawPlan.execute().first() + + if (rawColumns != expectedColumns) { + fail( + s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" + + s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" + + queryExecution) + } + + if (rawOutput.size != expectedColumns.size) { + fail(s"Wrong output row. 
Got $rawOutput\n$queryExecution") + } + } + } + +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala new file mode 100644 index 0000000000000..b254b0620c779 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -0,0 +1,125 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql._ + +class DefaultSource extends SimpleScanSource + +class SimpleScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext) + } +} + +case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) + extends TableScan { + + override def schema = + StructType(StructField("i", IntegerType, nullable = false) :: Nil) + + override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_)) +} + +class TableScanSuite extends DataSourceTest { + import caseInsensisitiveContext._ + + before { + sql( + """ + |CREATE TEMPORARY TABLE oneToTen + |USING org.apache.spark.sql.sources.SimpleScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + } + + sqlTest( + "SELECT * FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen", + (1 to 10).map(Row(_)).toSeq) + + sqlTest( + "SELECT i FROM oneToTen WHERE i < 5", + (1 to 4).map(Row(_)).toSeq) + + sqlTest( + "SELECT i * 2 FROM oneToTen", + (1 to 10).map(i => Row(i * 2)).toSeq) + + sqlTest( + "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1", + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + + test("Caching") { + // Cached Query Execution + cacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen")) + checkAnswer( + sql("SELECT * FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen")) + checkAnswer( + sql("SELECT i FROM oneToTen"), + (1 to 10).map(Row(_)).toSeq) + + assertCached(sql("SELECT i FROM oneToTen WHERE i < 5")) + checkAnswer( + sql("SELECT i FROM oneToTen WHERE i < 5"), + (1 to 4).map(Row(_)).toSeq) + + assertCached(sql("SELECT i * 2 FROM oneToTen")) + checkAnswer( + sql("SELECT i * 2 FROM oneToTen"), + (1 to 10).map(i => Row(i * 2)).toSeq) + + assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2) + checkAnswer( + sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), + (2 to 10).map(i => Row(i, i - 1)).toSeq) + + // Verify uncaching + uncacheTable("oneToTen") + assertCached(sql("SELECT * FROM oneToTen"), 0) + } + + test("defaultSource") { + sql( + """ + |CREATE 
TEMPORARY TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM oneToTenDef"), + (1 to 10).map(Row(_)).toSeq) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 124fc107cb8aa..8db3010624100 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -70,6 +70,24 @@ org.scalatest scalatest-maven-plugin + + org.codehaus.mojo + build-helper-maven-plugin + + + add-default-sources + generate-sources + + add-source + + + + v${hive.version.short}/src/main/scala + + + + + org.apache.maven.plugins maven-deploy-plugin diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala similarity index 86% rename from sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala rename to sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index a5c457c677564..fcb302edbffa8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -29,11 +29,11 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) - extends Driver with Logging { +private[hive] abstract class AbstractSparkSQLDriver( + val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver with Logging { - private var tableSchema: Schema = _ - private var hiveResponse: Seq[String] = _ + private[hive] var tableSchema: Schema = _ + private[hive] var hiveResponse: Seq[String] = _ override def init(): Unit = { } @@ -74,16 +74,6 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo override def getSchema: Schema = tableSchema - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } - override def destroy() { super.destroy() hiveResponse = null diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3d468d804622c..bd4e99492b395 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} @@ -51,24 +48,12 @@ object HiveThriftServer2 extends Logging { def main(args: Array[String]) { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") - if (!optionsProcessor.process(args)) { System.exit(-1) } - val ss = new SessionState(new 
HiveConf(classOf[SessionState])) - - // Set all properties specified via command line. - val hiveConf: HiveConf = ss.getConf - hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logDebug(s"HiveConf var: $k=$v") - } - - SessionState.start(ss) - logInfo("Starting SparkContext") SparkSQLEnv.init() - SessionState.start(ss) Runtime.getRuntime.addShutdownHook( new Thread() { @@ -80,7 +65,7 @@ object HiveThriftServer2 extends Logging { try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) - server.init(hiveConf) + server.init(SparkSQLEnv.hiveContext.hiveconf) server.start() logInfo("HiveThriftServer2 started") } catch { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7ba4564602ecd..2cd02ae9269f5 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -38,6 +38,8 @@ import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging +import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" @@ -116,7 +118,7 @@ private[hive] object SparkSQLCLIDriver { } } - if (!sessionState.isRemoteMode && !ShimLoader.getHadoopShims.usesJobShell()) { + if (!sessionState.isRemoteMode) { // Hadoop-20 and above - we need to augment classpath using hiveconf // components. // See also: code in ExecDriver.java @@ -258,7 +260,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { } else { var ret = 0 val hconf = conf.asInstanceOf[HiveConf] - val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) + val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hconf) if (proc != null) { if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 42cbf363b274f..ecfb74473e921 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.thriftserver +import java.util.jar.Attributes.Name + import scala.collection.JavaConversions._ import java.io.IOException @@ -24,15 +26,17 @@ import java.util.{List => JList} import javax.security.auth.login.LoginException import org.apache.commons.logging.Log +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.ShimLoader import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory -import org.apache.hive.service.cli.CLIService +import org.apache.hive.service.cli._ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ +import org.apache.spark.util.Utils private[hive] class SparkSQLCLIService(hiveContext: HiveContext) 
extends CLIService @@ -44,19 +48,30 @@ private[hive] class SparkSQLCLIService(hiveContext: HiveContext) val sparkSqlSessionManager = new SparkSQLSessionManager(hiveContext) setSuperField(this, "sessionManager", sparkSqlSessionManager) addService(sparkSqlSessionManager) + var sparkServiceUGI: UserGroupInformation = null - try { - HiveAuthFactory.loginFromKeytab(hiveConf) - val serverUserName = ShimLoader.getHadoopShims - .getShortUserName(ShimLoader.getHadoopShims.getUGIForConf(hiveConf)) - setSuperField(this, "serverUserName", serverUserName) - } catch { - case e @ (_: IOException | _: LoginException) => - throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + if (ShimLoader.getHadoopShims().isSecurityEnabled()) { + try { + HiveAuthFactory.loginFromKeytab(hiveConf) + sparkServiceUGI = ShimLoader.getHadoopShims.getUGIForConf(hiveConf) + HiveThriftServerShim.setServerUserName(sparkServiceUGI, this) + } catch { + case e @ (_: IOException | _: LoginException) => + throw new ServiceException("Unable to login to kerberos with given principal/keytab", e) + } } initCompositeService(hiveConf) } + + override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = { + getInfoType match { + case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL") + case GetInfoType.CLI_DBMS_VER => new GetInfoValue(Utils.sparkVersion) + case _ => super.getInfo(sessionHandle, getInfoType) + } + } } private[thriftserver] trait ReflectedCompositeService { this: AbstractService => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 2136a2ea63543..89732c939b0ec 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import org.apache.hadoop.hive.ql.session.SessionState +import scala.collection.JavaConversions._ -import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.Logging -import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.scheduler.StatsReportListener +import org.apache.spark.sql.hive.{HiveShim, HiveContext} +import org.apache.spark.{Logging, SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. 
*/ private[hive] object SparkSQLEnv extends Logging { @@ -33,18 +32,18 @@ private[hive] object SparkSQLEnv extends Logging { def init() { if (hiveContext == null) { - sparkContext = new SparkContext(new SparkConf() - .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}")) + val sparkConf = new SparkConf() + .setAppName(s"SparkSQL::${java.net.InetAddress.getLocalHost.getHostName}") + .set("spark.sql.hive.version", HiveShim.version) + sparkContext = new SparkContext(sparkConf) sparkContext.addSparkListener(new StatsReportListener()) + hiveContext = new HiveContext(sparkContext) - hiveContext = new HiveContext(sparkContext) { - @transient override lazy val sessionState = { - val state = SessionState.get() - setConf(state.getConf.getAllProperties) - state + if (log.isDebugEnabled) { + hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + logDebug(s"HiveConf var: $k=$v") } - @transient override lazy val hiveconf = sessionState.getConf } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index accf61576b804..99c4f46a82b8e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,15 @@ package org.apache.spark.sql.hive.thriftserver.server -import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.mutable.Map -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map} -import scala.math.{random, round} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} -import org.apache.spark.sql.catalyst.plans.logical.SetCommand -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, ReflectionUtils} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -54,159 +45,9 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - val operation = new ExecuteStatementOperation(parentSession, statement, confOverlay) { - private var result: SchemaRDD = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. 
- logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType => - val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - try { - result = 
hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => - sessionToActivePool(parentSession) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - - val groupId = round(random * 1000000).toString - hiveContext.sparkContext.setJobGroup(groupId, statement) - sessionToActivePool.get(parentSession).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - iter = { - val resultRdd = result.queryExecution.toRdd - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - resultRdd.toLocalIterator - } else { - resultRdd.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - } - } - - handleToOperation.put(operation.getHandle, operation) - operation + val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( + hiveContext, sessionToActivePool) + handleToOperation.put(operation.getHandle, operation) + operation } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8a72e9d2aef57..e8ffbc5b954d4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.hive.thriftserver +import java.io._ + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} -import java.io._ -import java.util.concurrent.atomic.AtomicInteger - import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -53,23 +51,20 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin.split("\\s+").toSeq ++ extraArgs } - // AtomicInteger is needed because stderr and stdout of the forked process are handled in - // different threads. - val next = new AtomicInteger(0) + var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) val buffer = new ArrayBuffer[String]() + val lock = new Object - def captureOutput(source: String)(line: String) { + def captureOutput(source: String)(line: String): Unit = lock.synchronized { buffer += s"$source> $line" - // If we haven't found all expected answers... - if (next.get() < expectedAnswers.size) { - // If another expected answer is found... 
- if (line.startsWith(expectedAnswers(next.get()))) { - // If all expected answers have been found... - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) - } + // If we haven't found all expected answers and another expected answer comes up... + if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + next += 1 + // If all expected answers have been found... + if (next == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) } } } @@ -88,8 +83,8 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |======================= |Spark SQL CLI command line: ${command.mkString(" ")} | - |Executed query ${next.get()} "${queries(next.get())}", - |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout. + |Executed query $next "${queries(next)}", + |But failed to capture expected output "${expectedAnswers(next)}" within $timeout. | |${buffer.mkString("\n")} |=========================== diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index e3b4e45a3d68c..65d910a0c3ffc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -30,42 +30,95 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift._ +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket import org.scalatest.FunSuite import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.hive.HiveShim /** * Tests for the HiveThriftServer2 using JDBC. + * + * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be + * rebuilt after changing HiveThriftServer2 related code. */ class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + def randomListeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. 
+ val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } + + def withJdbcStatement(serverStartTimeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val port = randomListeningPort + + startThriftServer(port, serverStartTimeout) { + val jdbcUri = s"jdbc:hive2://${"localhost"}:$port/" + val user = System.getProperty("user.name") + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() + + try { + f(statement) + } finally { + statement.close() + connection.close() + } + } + } + + def withCLIServiceClient( + serverStartTimeout: FiniteDuration = 1.minute)( + f: ThriftCLIServiceClient => Unit) { + val port = randomListeningPort + + startThriftServer(port) { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", port) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + + try { + f(client) + } finally { + transport.close() + } + } + } + + def startThriftServer( + port: Int, + serverStartTimeout: FiniteDuration = 1.minute)( + f: => Unit) { val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) val warehousePath = getTempFilePath("warehouse") val metastorePath = getTempFilePath("metastore") val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - val listeningHost = "localhost" - val listeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - val command = s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=${"localhost"} + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port """.stripMargin.split("\\s+").toSeq val serverRunning = Promise[Unit]() @@ -92,31 +145,25 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } } - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger( + val env = Seq( + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + "SPARK_TESTING" -> "0", + // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read + // proper version information from the jar manifest. 
+ "SPARK_PREPEND_CLASSES" -> "") + + Process(command, None, env: _*).run(ProcessLogger( captureThriftServerOutput("stdout"), captureThriftServerOutput("stderr"))) - val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" - val user = System.getProperty("user.name") - try { - Await.result(serverRunning.future, timeout) - - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } + Await.result(serverRunning.future, serverStartTimeout) + f } catch { case cause: Exception => cause match { case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $timeout", cause) + logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) case _ => } logError( @@ -125,8 +172,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |HiveThriftServer2Suite failure output |===================================== |HiveThriftServer2 command line: ${command.mkString(" ")} - |JDBC URI: $jdbcUri - |User: $user + |Binding port: $port + |System user: ${System.getProperty("user.name")} | |${buffer.mkString("\n")} |========================================= @@ -146,14 +193,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("Test JDBC query execution") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - val queries = Seq( - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test", - "CACHE TABLE test") + val queries = + s"""SET spark.sql.shuffle.partitions=3; + |CREATE TABLE test(key INT, val STRING); + |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test; + |CACHE TABLE test; + """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty) queries.foreach(statement.execute) @@ -166,7 +215,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } test("SPARK-3004 regression: result set containing NULL") { - startThriftServerWithin() { statement => + withJdbcStatement() { statement => val dataFilePath = Thread.currentThread().getContextClassLoader.getResource( "data/files/small_kv_with_null.txt") @@ -189,4 +238,33 @@ class HiveThriftServer2Suite extends FunSuite with Logging { assert(!resultSet.next()) } } + + test("GetInfo Thrift API") { + withCLIServiceClient() { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("Checks Hive version") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala 
b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala new file mode 100644 index 0000000000000..8077d0ec46fd7 --- /dev/null +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.Timestamp +import java.util.{ArrayList => JArrayList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.math._ + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.plans.logical.SetCommand +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * A compatibility layer for interacting with Hive version 0.12.0. + */ +private[thriftserver] object HiveThriftServerShim { + val version = "0.12.0" + + def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { + val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) + setSuperField(sparkCliService, "serverUserName", serverUserName) + } +} + +private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) + extends AbstractSparkSQLDriver(_context) { + override def getResults(res: JArrayList[String]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.addAll(hiveResponse) + hiveResponse = null + true + } + } +} + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String])( + hiveContext: HiveContext, + sessionToActivePool: SMap[HiveSession, String]) extends ExecuteStatementOperation( + parentSession, statement, confOverlay) with Logging { + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. 
+ logDebug("CLOSING") + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + if (!iter.hasNext) { + new RowSet() + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) + + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = new Row() + var curCol = 0 + + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + rowSet += row + curRow += 1 + } + new RowSet(rowSet, 0) + } + } + + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType() => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType() => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.shortValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + + def getResultSetSchema: TableSchema = { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + def run(): Unit = { + logInfo(s"Running query '$statement'") + setState(OperationState.RUNNING) + try { + 
result = hiveContext.sql(statement) + logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + setHasResultSet(true) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + setState(OperationState.FINISHED) + } +} diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala new file mode 100644 index 0000000000000..2c1983de1d0d5 --- /dev/null +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.thriftserver + +import java.security.PrivilegedExceptionAction +import java.sql.Timestamp +import java.util.concurrent.Future +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map => SMap} +import scala.math._ + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.metadata.Hive +import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.shims.ShimLoader +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.ExecuteStatementOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.{Row => SparkRow, SchemaRDD} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ + +/** + * A compatibility layer for interacting with Hive version 0.12.0. + */ +private[thriftserver] object HiveThriftServerShim { + val version = "0.13.1" + + def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { + setSuperField(sparkCliService, "serviceUGI", sparkServiceUGI) + } +} + +private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) + extends AbstractSparkSQLDriver(_context) { + override def getResults(res: JList[_]): Boolean = { + if (hiveResponse == null) { + false + } else { + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + hiveResponse = null + true + } + } +} + +private[hive] class SparkExecuteStatementOperation( + parentSession: HiveSession, + statement: String, + confOverlay: JMap[String, String], + runInBackground: Boolean = true)( + hiveContext: HiveContext, + sessionToActivePool: SMap[HiveSession, String]) extends ExecuteStatementOperation( + parentSession, statement, confOverlay, runInBackground) with Logging { + + private var result: SchemaRDD = _ + private var iter: Iterator[SparkRow] = _ + private var dataTypes: Array[DataType] = _ + + private def runInternal(cmd: String) = { + try { + result = hiveContext.sql(cmd) + logDebug(result.queryExecution.toString()) + val groupId = round(random * 1000000).toString + hiveContext.sparkContext.setJobGroup(groupId, statement) + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } + dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + } + + def close(): Unit = { + // RDDs will be cleaned automatically upon garbage collection. 
+ logDebug("CLOSING") + } + + def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to += from.get(ordinal).asInstanceOf[String] + case IntegerType => + to += from.getInt(ordinal) + case BooleanType => + to += from.getBoolean(ordinal) + case DoubleType => + to += from.getDouble(ordinal) + case FloatType => + to += from.getFloat(ordinal) + case DecimalType() => + to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + case LongType => + to += from.getLong(ordinal) + case ByteType => + to += from.getByte(ordinal) + case ShortType => + to += from.getShort(ordinal) + case TimestampType => + to += from.get(ordinal).asInstanceOf[Timestamp] + case BinaryType => + to += from.get(ordinal).asInstanceOf[String] + case _: ArrayType => + to += from.get(ordinal).asInstanceOf[String] + case _: StructType => + to += from.get(ordinal).asInstanceOf[String] + case _: MapType => + to += from.get(ordinal).asInstanceOf[String] + } + } + + def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { + validateDefaultFetchOrientation(order) + assertState(OperationState.FINISHED) + setHasResultSet(true) + val reultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + if (!iter.hasNext) { + reultRowSet + } else { + // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int + val maxRows = maxRowsL.toInt + var curRow = 0 + while (curRow < maxRows && iter.hasNext) { + val sparkRow = iter.next() + val row = ArrayBuffer[Any]() + var curCol = 0 + while (curCol < sparkRow.length) { + if (sparkRow.isNullAt(curCol)) { + row += null + } else { + addNonNullColumnValue(sparkRow, row, curCol) + } + curCol += 1 + } + reultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + curRow += 1 + } + reultRowSet + } + } + + def getResultSetSchema: TableSchema = { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + if (result.queryExecution.analyzed.output.size == 0) { + new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + } else { + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema) + } + } + + private def getConfigForOperation: HiveConf = { + var sqlOperationConf: HiveConf = getParentSession.getHiveConf + if (!getConfOverlay.isEmpty || shouldRunAsync) { + sqlOperationConf = new HiveConf(sqlOperationConf) + import scala.collection.JavaConversions._ + for (confEntry <- getConfOverlay.entrySet) { + try { + sqlOperationConf.verifyAndSet(confEntry.getKey, confEntry.getValue) + } + catch { + case e: IllegalArgumentException => { + throw new HiveSQLException("Error applying statement specific settings", e) + } + } + } + } + return sqlOperationConf + } + + def run(): Unit = { + logInfo(s"Running query '$statement'") + val opConfig: HiveConf = getConfigForOperation + setState(OperationState.RUNNING) + setHasResultSet(true) + + if (!shouldRunAsync) { + runInternal(statement) + setState(OperationState.FINISHED) + } else { + val parentSessionState = SessionState.get + val sessionHive: Hive = Hive.get + val currentUGI: UserGroupInformation = ShimLoader.getHadoopShims.getUGIForConf(opConfig) + + val backgroundOperation: Runnable = new Runnable { + def run { + val doAsAction: PrivilegedExceptionAction[AnyRef] = + new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + Hive.set(sessionHive) + 
SessionState.setCurrentSessionState(parentSessionState) + try { + runInternal(statement) + } + catch { + case e: HiveSQLException => { + setOperationException(e) + logError("Error running hive query: ", e) + } + } + return null + } + } + try { + ShimLoader.getHadoopShims.doAs(currentUGI, doAsAction) + } + catch { + case e: Exception => { + setOperationException(new HiveSQLException(e)) + logError("Error running hive query as user : " + currentUGI.getShortUserName, e) + } + } + setState(OperationState.FINISHED) + } + } + + try { + val backgroundHandle: Future[_] = getParentSession.getSessionManager. + submitBackgroundOperation(backgroundOperation) + setBackgroundHandle(backgroundHandle) + } catch { + // Actually do need to catch Throwable as some failures don't inherit from Exception and + // HiveServer will silently swallow them. + case e: Throwable => + logError("Error executing query:",e) + throw new HiveSQLException(e.toString) + } + } + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 15cd62d3bf869..1a3c24be420e6 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -229,7 +229,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Needs constant object inspectors "udf_round", - "udf7" + "udf7", + + // Sort with Limit clause causes failure. + "ctas", + "ctas_hadoop20", + + // timestamp in array, the output format of Hive contains double quotes, while + // Spark SQL doesn't + "udf_sort_array" ) ++ HiveShim.compatibilityBlackList /** @@ -767,6 +775,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "touch", "transform_ppr1", "transform_ppr2", + "truncate_table", "type_cast_1", "type_widening", "udaf_collect_set", @@ -856,6 +865,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_minute", "udf_modulo", "udf_month", + "udf_named_struct", "udf_negative", "udf_not", "udf_notequal", @@ -889,6 +899,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_stddev_pop", "udf_stddev_samp", "udf_string", + "udf_struct", "udf_substring", "udf_subtract", "udf_sum", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index db01363b4d629..67e36a951e506 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -65,6 +65,10 @@ commons-logging commons-logging + + com.esotericsoftware.kryo + kryo + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad4091d48a89..e88afaaf001c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,6 +21,10 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -46,6 +50,7 @@ import 
org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand +import org.apache.spark.sql.sources.DataSourceStrategy /** * DEPRECATED: Use HiveContext instead. @@ -85,7 +90,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * SerDe. */ private[spark] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -95,7 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (dialect == "sql") { super.sql(sqlText) } else if (dialect == "hiveql") { - new SchemaRDD(this, HiveQl.parseSql(sqlText)) + new SchemaRDD(this, ddlParser(sqlText).getOrElse(HiveQl.parseSql(sqlText))) } else { sys.error(s"Unsupported SQL dialect: $dialect. Try 'sql' or 'hiveql'") } @@ -224,21 +229,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } /** - * SQLConf and HiveConf contracts: when the hive session is first initialized, params in - * HiveConf will get picked up by the SQLConf. Additionally, any properties set by - * set() or a SET command inside sql() will be set in the SQLConf *as well as* - * in the HiveConf. + * SQLConf and HiveConf contracts: + * + * 1. reuse existing started SessionState if any + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. */ - @transient lazy val hiveconf = new HiveConf(classOf[SessionState]) - @transient protected[hive] lazy val sessionState = { - val ss = new SessionState(hiveconf) - setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. - SessionState.start(ss) - ss.err = new PrintStream(outputBuffer, true, "UTF-8") - ss.out = new PrintStream(outputBuffer, true, "UTF-8") - - ss - } + @transient protected[hive] lazy val (hiveconf, sessionState) = + Option(SessionState.get()) + .orElse { + val newState = new SessionState(new HiveConf(classOf[SessionState])) + // Only starts newly created `SessionState` instance. Any existing `SessionState` instance + // returned by `SessionState.get()` must be the most recently started one. + SessionState.start(newState) + Some(newState) + } + .map { state => + setConf(state.getConf.getAllProperties) + if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8") + if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8") + (state.getConf, state) + } + .get override def setConf(key: String, value: String): Unit = { super.setConf(key, value) @@ -288,6 +301,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() val proc: CommandProcessor = HiveShim.getCommandProcessor(Array(tokens(0)), hiveconf) + // Makes sure the session represented by the `sessionState` field is activated. This implies + // Spark SQL Hive support uses a single `SessionState` for all Hive operations and breaks + // session isolation under multi-user scenarios (i.e. HiveThriftServer2). 
+ // TODO Fix session isolation + if (SessionState.get() != sessionState) { + SessionState.start(sessionState) + } + proc match { case driver: Driver => val results = HiveShim.createDriverResultsArray @@ -302,7 +323,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { driver.close() HiveShim.processResults(results) case _ => - sessionState.out.println(tokens(0) + " " + cmd_1) + if (sessionState.out != null) { + sessionState.out.println(tokens(0) + " " + cmd_1) + } Seq(proc.run(cmd_1).getResponseCode.toString) } } catch { @@ -325,7 +348,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hivePlanner = new SparkPlanner with HiveStrategies { val hiveContext = self - override val strategies: Seq[Strategy] = Seq( + override val strategies: Seq[Strategy] = extraStrategies ++ Seq( + DataSourceStrategy, CommandStrategy(self), HiveCommandStrategy(self), TakeOrdered, @@ -350,11 +374,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) - protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, DateType, TimestampType, BinaryType) + ShortType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -372,6 +394,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -390,6 +414,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -406,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { command.executeCollect().map(_.head.toString) case other => - val result: Seq[Seq[Any]] = toRdd.collect().toSeq + val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. 
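The HiveContext hunk above drops the eagerly constructed SessionState in favor of a reuse-or-create pattern, so a session already started elsewhere (for example by the SQL CLI driver) is picked up instead of being shadowed by a new one. A minimal standalone sketch of that pattern follows; it assumes only Hive's SessionState/HiveConf API plus a caller-supplied output buffer, and the object and method names are illustrative rather than part of the patch.

import java.io.{ByteArrayOutputStream, PrintStream}

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.session.SessionState

// Illustrative sketch, not part of the patch: reuse an already-started
// SessionState when one exists, otherwise create and start a fresh one,
// then make sure its output streams point at a buffer we control.
object SessionStateSketch {
  def currentOrNew(outputBuffer: ByteArrayOutputStream): (HiveConf, SessionState) =
    Option(SessionState.get())
      .orElse {
        // Only a newly created SessionState is started here; an existing one
        // is assumed to be the most recently started session already.
        val newState = new SessionState(new HiveConf(classOf[SessionState]))
        SessionState.start(newState)
        Some(newState)
      }
      .map { state =>
        if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8")
        if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8")
        (state.getConf, state)
      }
      .get
}

Because every HiveContext then shares whichever SessionState was started first, the runSqlHive change above re-activates it with SessionState.start(sessionState) before dispatching a command, which is also why the TODO about session isolation is recorded in that hunk.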
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index fad7373a2fa39..58815daa82276 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} @@ -26,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -36,7 +39,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -52,8 +55,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -78,88 +81,152 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.Object] => NullType } - /** Converts hive types to native catalyst types. 
*/ - def unwrap(a: Any): Any = a match { - case null => null - case i: hadoopIo.IntWritable => i.get - case t: hadoopIo.Text => t.toString - case l: hadoopIo.LongWritable => l.get - case d: hadoopIo.DoubleWritable => d.get - case d: hiveIo.DoubleWritable => d.get - case s: hiveIo.ShortWritable => s.get - case b: hadoopIo.BooleanWritable => b.get - case b: hiveIo.ByteWritable => b.get - case b: hadoopIo.FloatWritable => b.get - case b: hadoopIo.BytesWritable => { - val bytes = new Array[Byte](b.getLength) - System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) - bytes - } - case d: hiveIo.DateWritable => d.get - case t: hiveIo.TimestampWritable => t.getTimestamp - case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) - case list: java.util.List[_] => list.map(unwrap) - case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap - case array: Array[_] => array.map(unwrap).toSeq - case p: java.lang.Short => p - case p: java.lang.Long => p - case p: java.lang.Float => p - case p: java.lang.Integer => p - case p: java.lang.Double => p - case p: java.lang.Byte => p - case p: java.lang.Boolean => p - case str: String => str - case p: java.math.BigDecimal => p - case p: Array[Byte] => p - case p: java.sql.Date => p - case p: java.sql.Timestamp => p - } - - def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + /** + * Converts hive types to native catalyst types. + * @param data the data in Hive type + * @param oi the ObjectInspector associated with the Hive Type + * @return convert the data into catalyst type + */ + def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data) + // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object + // if next timestamp is null, so Timestamp object is cloned + case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) .orNull case mi: MapObjectInspector => Option(mi.getMap(data)).map( _.map { case (k,v) => - (unwrapData(k, mi.getMapKeyObjectInspector), - unwrapData(v, mi.getMapValueObjectInspector)) + (unwrap(k, mi.getMapKeyObjectInspector), + unwrap(v, mi.getMapValueObjectInspector)) }.toMap).orNull case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs new GenericRow( allRefs.map(r => - unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + + /** + * Wraps with Hive types based on object inspector. + * TODO: Consolidate all hive OI/data interface code. + */ + /** + * Wraps with Hive types based on object inspector. + * TODO: Consolidate all hive OI/data interface code. 
+ */ + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct + } + + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) + + case moi: MapObjectInspector => + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) + + case _ => + identity[Any] + } + + /** + * Converts native catalyst types to the types expected by Hive + * @param a the value to be wrapped + * @param oi This ObjectInspector associated with the value returned by this function, and + * the ObjectInspector should also be consistent with those returned from + * toInspector: DataType => ObjectInspector and + * toInspector: Expression => ObjectInspector + */ + def wrap(a: Any, oi: ObjectInspector): AnyRef = if (a == null) { + null + } else { + oi match { + case x: ConstantObjectInspector => x.getWritableConstantValue + case x: PrimitiveObjectInspector => a match { + // TODO what if x.preferWritable() == true? reuse the writable? + case s: String => s: java.lang.String + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case f: Float => f: java.lang.Float + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case b: BigDecimal => HiveShim.createDecimal(b.underlying()) + case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying()) + case b: Array[Byte] => b + case d: java.sql.Date => d + case t: java.sql.Timestamp => t + } + case x: StructObjectInspector => + val fieldRefs = x.getAllStructFieldRefs + val row = a.asInstanceOf[Seq[_]] + val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + var i = 0 + while (i < fieldRefs.length) { + result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + i += 1 + } + + result + case x: ListObjectInspector => + val list = new java.util.ArrayList[Object] + a.asInstanceOf[Seq[_]].foreach { + v => list.add(wrap(v, x.getListElementObjectInspector)) + } + list + case x: MapObjectInspector => + // Some UDFs seem to assume we pass in a HashMap. 
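wrapperFor above resolves the conversion function once per ObjectInspector and then applies the returned closure to every value, instead of re-dispatching on runtime types row by row. A simplified standalone sketch of that idea; Inspector and HiveText below are illustrative stand-ins rather than Hive classes:

    import scala.collection.JavaConverters._

    object WrapperForSketch {
      // Hypothetical stand-ins for Hive ObjectInspectors; not actual Hive classes.
      sealed trait Inspector
      case object StringInspector extends Inspector
      case class ListInspector(element: Inspector) extends Inspector
      case class MapInspector(key: Inspector, value: Inspector) extends Inspector

      // A pretend Hive-side string value, standing in for something like HiveVarchar.
      case class HiveText(value: String)

      // Resolve the converter once per inspector; the returned closure is applied per value.
      def wrapperFor(oi: Inspector): Any => Any = oi match {
        case StringInspector =>
          (o: Any) => HiveText(o.asInstanceOf[String])
        case ListInspector(element) =>
          val wrapElement = wrapperFor(element)
          (o: Any) => o.asInstanceOf[Seq[_]].map(wrapElement).asJava
        case MapInspector(keyInspector, valueInspector) =>
          val (wrapKey, wrapValue) = (wrapperFor(keyInspector), wrapperFor(valueInspector))
          // scala.collection.Map also covers mutable maps, echoing the note in the patch.
          (o: Any) => o.asInstanceOf[scala.collection.Map[_, _]].map { case (k, v) =>
            wrapKey(k) -> wrapValue(v)
          }.asJava
      }

      def main(args: Array[String]): Unit = {
        val wrap = wrapperFor(MapInspector(StringInspector, ListInspector(StringInspector)))
        println(wrap(Map("k" -> Seq("a", "b")))) // {HiveText(k)=[HiveText(a), HiveText(b)]}
      }
    }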
+ val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(a.asInstanceOf[Map[_, _]].map { + case (k, v) => + wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + }) + + hashMap + } } - /** Converts native catalyst types to the types expected by Hive */ - def wrap(a: Any): AnyRef = a match { - case s: String => s: java.lang.String - case i: Int => i: java.lang.Integer - case b: Boolean => b: java.lang.Boolean - case f: Float => f: java.lang.Float - case d: Double => d: java.lang.Double - case l: Long => l: java.lang.Long - case l: Short => l: java.lang.Short - case l: Byte => l: java.lang.Byte - case b: BigDecimal => HiveShim.createDecimal(b.underlying()) - case b: Array[Byte] => b - case d: java.sql.Date => d - case t: java.sql.Timestamp => t - case s: Seq[_] => seqAsJavaList(s.map(wrap)) - case m: Map[_,_] => - // Some UDFs seem to assume we pass in a HashMap. - val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) }) - hashMap - case null => null + def wrap( + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef]): Array[AnyRef] = { + var i = 0 + while (i < inspectors.length) { + cache(i) = wrap(row(i), inspectors(i)) + i += 1 + } + cache } def toInspector(dataType: DataType): ObjectInspector = dataType match { @@ -180,12 +247,56 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) } + def toInspector(expr: Expression): ObjectInspector = expr match { + case Literal(value: String, StringType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Int, IntegerType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Double, DoubleType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Boolean, BooleanType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Long, LongType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Float, FloatType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Short, ShortType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Byte, ByteType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Array[Byte], BinaryType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: java.sql.Date, DateType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: java.sql.Timestamp, TimestampType) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: BigDecimal, DecimalType()) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Decimal, DecimalType()) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal) + case Literal(_, NullType) => + 
HiveShim.getPrimitiveNullWritableConstantObjectInspector + case Literal(value: Seq[_], ArrayType(dt, _)) => + val listObjectInspector = toInspector(dt) + val list = new java.util.ArrayList[Object]() + value.foreach(v => list.add(wrap(v, listObjectInspector))) + ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) + case Literal(map: Map[_, _], MapType(keyType, valueType, _)) => + val value = new java.util.HashMap[Object, Object]() + val keyOI = toInspector(keyType) + val valueOI = toInspector(valueType) + map.foreach (entry => value.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI))) + ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, value) + case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") + case _ => toInspector(expr.dataType) + } + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => StructType(s.getAllStructFieldRefs.map(f => { @@ -215,8 +326,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType + case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -245,7 +356,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo + case d: DecimalType => HiveShim.decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 04c48c385966e..0baf4c9f8c7ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,17 +17,28 @@ package org.apache.spark.sql.hive +import java.io.IOException +import java.util.{List => JList} + +import scala.util.matching.Regex import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path + +import org.apache.hadoop.hive.metastore.TableType +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table, HiveException} +import org.apache.hadoop.hive.ql.plan.{TableDesc, CreateTableDesc} +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import 
org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.Catalog +import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ @@ -46,6 +57,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false + def tableExists(db: Option[String], tableName: String): Boolean = { + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) + client.getTable(databaseName, tblName, false) != null + } + def lookupRelation( db: Option[String], tableName: String, @@ -66,37 +83,164 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with table.getTTable, partitions.map(part => part.getTPartition))(hive) } + /** + * Create table with specified database, table name, table description and schema + * @param databaseName Database Name + * @param tableName Table Name + * @param schema Schema of the new table, if not specified, will use the schema + * specified in crtTbl + * @param allowExisting if true, ignore AlreadyExistsException + * @param desc CreateTableDesc object which contains the SerDe info. Currently + * we support most of the features except the bucket. + */ def createTable( databaseName: String, tableName: String, schema: Seq[Attribute], - allowExisting: Boolean = false): Unit = { + allowExisting: Boolean = false, + desc: Option[CreateTableDesc] = None) { + val hconf = hive.hiveconf + val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val table = new Table(dbName, tblName) - val hiveSchema = + val tbl = new Table(dbName, tblName) + + val crtTbl: CreateTableDesc = desc.getOrElse(null) + + // We should respect the passed in schema, unless it's not set + val hiveSchema: JList[FieldSchema] = if (schema == null || schema.isEmpty) { + crtTbl.getCols + } else { schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) - table.setFields(hiveSchema) - - val sd = new StorageDescriptor() - table.getTTable.setSd(sd) - sd.setCols(hiveSchema) - - // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs. - sd.setCompressed(false) - sd.setParameters(Map[String, String]()) - sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat") - sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat") - val serDeInfo = new SerDeInfo() - serDeInfo.setName(tblName) - serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - serDeInfo.setParameters(Map[String, String]()) - sd.setSerdeInfo(serDeInfo) + } + tbl.setFields(hiveSchema) + // Most of code are similar with the DDLTask.createTable() of Hive, + if (crtTbl != null && crtTbl.getTblProps() != null) { + tbl.getTTable().getParameters().putAll(crtTbl.getTblProps()) + } + + if (crtTbl != null && crtTbl.getPartCols() != null) { + tbl.setPartCols(crtTbl.getPartCols()) + } + + if (crtTbl != null && crtTbl.getStorageHandler() != null) { + tbl.setProperty( + org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.META_TABLE_STORAGE, + crtTbl.getStorageHandler()) + } + + /* + * We use LazySimpleSerDe by default. + * + * If the user didn't specify a SerDe, and any of the columns are not simple + * types, we will have to use DynamicSerDe instead. 
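The createTable rewrite in this hunk copies whatever the optional CreateTableDesc specifies (SerDe, formats, delimiters, EXTERNAL flag and so on) onto the Hive table and falls back to LazySimpleSerDe with the plain text input/output formats otherwise. A rough sketch of that user-settings-first-then-defaults resolution, using plain case classes instead of Hive's CreateTableDesc and Table; the serde property keys are illustrative:

    object CreateTableDefaultsSketch {
      // Illustrative stand-ins for Hive's CreateTableDesc (user input) and Table (result).
      case class TableSpec(
          serde: Option[String] = None,
          inputFormat: Option[String] = None,
          outputFormat: Option[String] = None,
          fieldDelim: Option[String] = None,
          external: Boolean = false)

      case class ResolvedTable(
          serde: String,
          inputFormat: String,
          outputFormat: String,
          serdeParams: Map[String, String],
          properties: Map[String, String])

      def resolve(spec: TableSpec): ResolvedTable = ResolvedTable(
        // Defaults mirror the patch: LazySimpleSerDe plus plain text input/output formats.
        serde = spec.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
        inputFormat = spec.inputFormat.getOrElse("org.apache.hadoop.mapred.TextInputFormat"),
        outputFormat = spec.outputFormat.getOrElse(
          "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"),
        // A field delimiter also sets the serialization format, as in the patch.
        serdeParams = spec.fieldDelim
          .map(d => Map("field.delim" -> d, "serialization.format" -> d))
          .getOrElse(Map.empty),
        properties = if (spec.external) Map("EXTERNAL" -> "TRUE") else Map.empty)

      def main(args: Array[String]): Unit =
        println(resolve(TableSpec(fieldDelim = Some("\t"), external = true)))
    }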
+ */ + if (crtTbl == null || crtTbl.getSerName() == null) { + val storageHandler = tbl.getStorageHandler() + if (storageHandler == null) { + logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") + tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) + + import org.apache.hadoop.mapred.TextInputFormat + import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.io.Text + + tbl.setInputFormatClass(classOf[TextInputFormat]) + tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) + tbl.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + } else { + val serDeClassName = storageHandler.getSerDeClass().getName() + logInfo(s"Use StorageHandler-supplied $serDeClassName for table $dbName.$tblName") + tbl.setSerializationLib(serDeClassName) + } + } else { + // let's validate that the serde exists + val serdeName = crtTbl.getSerName() + try { + val d = ReflectionUtils.newInstance(hconf.getClassByName(serdeName), hconf) + if (d != null) { + logDebug("Found class for $serdeName") + } + } catch { + case e: SerDeException => throw new HiveException("Cannot validate serde: " + serdeName, e) + } + tbl.setSerializationLib(serdeName) + } + + if (crtTbl != null && crtTbl.getFieldDelim() != null) { + tbl.setSerdeParam(serdeConstants.FIELD_DELIM, crtTbl.getFieldDelim()) + tbl.setSerdeParam(serdeConstants.SERIALIZATION_FORMAT, crtTbl.getFieldDelim()) + } + if (crtTbl != null && crtTbl.getFieldEscape() != null) { + tbl.setSerdeParam(serdeConstants.ESCAPE_CHAR, crtTbl.getFieldEscape()) + } + + if (crtTbl != null && crtTbl.getCollItemDelim() != null) { + tbl.setSerdeParam(serdeConstants.COLLECTION_DELIM, crtTbl.getCollItemDelim()) + } + if (crtTbl != null && crtTbl.getMapKeyDelim() != null) { + tbl.setSerdeParam(serdeConstants.MAPKEY_DELIM, crtTbl.getMapKeyDelim()) + } + if (crtTbl != null && crtTbl.getLineDelim() != null) { + tbl.setSerdeParam(serdeConstants.LINE_DELIM, crtTbl.getLineDelim()) + } + + if (crtTbl != null && crtTbl.getSerdeProps() != null) { + val iter = crtTbl.getSerdeProps().entrySet().iterator() + while (iter.hasNext()) { + val m = iter.next() + tbl.setSerdeParam(m.getKey(), m.getValue()) + } + } + + if (crtTbl != null && crtTbl.getComment() != null) { + tbl.setProperty("comment", crtTbl.getComment()) + } + + if (crtTbl != null && crtTbl.getLocation() != null) { + HiveShim.setLocation(tbl, crtTbl) + } + + if (crtTbl != null && crtTbl.getSkewedColNames() != null) { + tbl.setSkewedColNames(crtTbl.getSkewedColNames()) + } + if (crtTbl != null && crtTbl.getSkewedColValues() != null) { + tbl.setSkewedColValues(crtTbl.getSkewedColValues()) + } + + if (crtTbl != null) { + tbl.setStoredAsSubDirectories(crtTbl.isStoredAsSubDirectories()) + tbl.setInputFormatClass(crtTbl.getInputFormat()) + tbl.setOutputFormatClass(crtTbl.getOutputFormat()) + } + + tbl.getTTable().getSd().setInputFormat(tbl.getInputFormatClass().getName()) + tbl.getTTable().getSd().setOutputFormat(tbl.getOutputFormatClass().getName()) + + if (crtTbl != null && crtTbl.isExternal()) { + tbl.setProperty("EXTERNAL", "TRUE") + tbl.setTableType(TableType.EXTERNAL_TABLE) + } + + // set owner + try { + tbl.setOwner(hive.hiveconf.getUser) + } catch { + case e: IOException => throw new HiveException("Unable to get current user", e) + } + + // set create time + tbl.setCreateTime((System.currentTimeMillis() / 1000).asInstanceOf[Int]) + + // TODO add bucket support + // TODO set more info if Hive upgrade + + // create the table synchronized { - try 
client.createTable(table) catch { - case e: org.apache.hadoop.hive.ql.metadata.HiveException - if e.getCause.isInstanceOf[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] && - allowExisting => // Do nothing. + try client.createTable(tbl, allowExisting) catch { + case e: org.apache.hadoop.hive.metastore.api.AlreadyExistsException + if allowExisting => // Do nothing + case e: Throwable => throw e } } } @@ -110,11 +254,11 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case CreateTableAsSelect(db, tableName, child) => + case CreateTableAsSelect(db, tableName, child, allowExisting, extra) => val (dbName, tblName) = processDatabaseAndTableName(db, tableName) val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - CreateTableAsSelect(Some(databaseName), tableName, child) + CreateTableAsSelect(Some(databaseName), tableName, child, allowExisting, extra) } } @@ -184,11 +328,18 @@ object HiveMetastoreTypes extends RegexParsers { "bigint" ^^^ LongType | "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | - HiveShim.metastoreDecimal ^^^ DecimalType | + fixedDecimalType | // Hive 0.13+ decimal with precision/scale + "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType + protected lazy val fixedDecimalType: Parser[DataType] = + ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "array" ~> "<" ~> dataType <~ ">" ^^ { case tpe => ArrayType(tpe) @@ -236,7 +387,7 @@ object HiveMetastoreTypes extends RegexParsers { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case DecimalType => "decimal" + case d: DecimalType => HiveShim.decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" } @@ -304,7 +455,13 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys + + /** An attribute map that can be used to lookup original attributes based on expression id. */ + val attributeMap = AttributeMap(output.map(o => (o,o))) + + /** An attribute map for determining the ordinal for non-partition columns. */ + val columnOrdinals = AttributeMap(attributes.zipWithIndex) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 54c619722ee12..74f68d0f95317 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -124,7 +125,8 @@ private[hive] object HiveQl { // Commands that we do not need to explain. 
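The metastore type parser above gains a fixedDecimalType rule so decimal(p,s) maps to a precision/scale decimal while bare decimal keeps the unlimited form. A self-contained RegexParsers sketch of the same grammar, assuming the Scala parser-combinator library is on the classpath; FixedDecimal and UnlimitedDecimal stand in for Catalyst's DecimalType:

    import scala.util.parsing.combinator.RegexParsers

    object MetastoreDecimalSketch extends RegexParsers {
      sealed trait HiveType
      case class FixedDecimal(precision: Int, scale: Int) extends HiveType
      case object UnlimitedDecimal extends HiveType
      case object HiveString extends HiveType

      // "decimal(p,s)": precision and scale captured as digit runs.
      lazy val fixedDecimalType: Parser[HiveType] =
        ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
          case precision ~ scale => FixedDecimal(precision.toInt, scale.toInt)
        }

      // The fixed form must be tried before the bare "decimal" alternative.
      lazy val dataType: Parser[HiveType] =
        fixedDecimalType |
        "decimal" ^^^ UnlimitedDecimal |
        "string" ^^^ HiveString

      def main(args: Array[String]): Unit = {
        println(parseAll(dataType, "decimal(10,2)")) // parsed: FixedDecimal(10,2)
        println(parseAll(dataType, "decimal"))       // parsed: UnlimitedDecimal
      }
    }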
protected val noExplainCommands = Seq( "TOK_CREATETABLE", - "TOK_DESCTABLE" + "TOK_DESCTABLE", + "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. ) ++ nativeCommands protected val hqlParser = { @@ -324,7 +326,11 @@ private[hive] object HiveQl { } protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", Nil) => DecimalType + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -447,14 +453,14 @@ private[hive] object HiveQl { } case Token("TOK_CREATETABLE", children) - if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => - // TODO: Parse other clauses. + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val ( Some(tableNameParts) :: _ /* likeTable */ :: - Some(query) +: - notImplemented) = + Some(query) :: + allowExisting +: + ignores) = getClauses( Seq( "TOK_TABNAME", @@ -478,18 +484,17 @@ private[hive] object HiveQl { "TOK_TABLELOCATION", "TOK_TABLEPROPERTIES"), children) - if (notImplemented.exists(token => !token.isEmpty)) { - throw new NotImplementedError( - s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") - } - val (db, tableName) = extractDbNameTableName(tableNameParts) - CreateTableAsSelect(db, tableName, nodeToPlan(query)) + CreateTableAsSelect(db, tableName, nodeToPlan(query), allowExisting != None, Some(node)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. 
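nodeToDataType above now distinguishes TOK_DECIMAL nodes with two, one, or zero children (precision and scale, precision only, or unlimited). A toy sketch of that dispatch; Token and the *Type objects are stand-ins for Hive's ASTNode and Catalyst's types:

    object DecimalTokenSketch {
      // Token stands in for Hive's ASTNode; the *Type objects stand in for Catalyst types.
      case class Token(text: String, children: List[Token] = Nil)

      sealed trait DataType
      case class FixedDecimalType(precision: Int, scale: Int) extends DataType
      case object UnlimitedDecimalType extends DataType
      case object LongType extends DataType

      def nodeToDataType(node: Token): DataType = node match {
        case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
          FixedDecimalType(precision.text.toInt, scale.text.toInt)
        case Token("TOK_DECIMAL", precision :: Nil) =>
          FixedDecimalType(precision.text.toInt, 0)   // missing scale defaults to 0
        case Token("TOK_DECIMAL", Nil) =>
          UnlimitedDecimalType                        // bare DECIMAL keeps the old behaviour
        case Token("TOK_BIGINT", Nil) =>
          LongType
        case other =>
          sys.error(s"Unsupported type token: ${other.text}")
      }

      def main(args: Array[String]): Unit = {
        println(nodeToDataType(Token("TOK_DECIMAL", List(Token("10"), Token("2")))))
        println(nodeToDataType(Token("TOK_DECIMAL")))
      }
    }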
case Token("TOK_CREATETABLE", _) => NativePlaceholder + // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" + case Token("TOK_TRUNCATETABLE", + Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + case Token("TOK_QUERY", Token("TOK_FROM", fromClause :: Nil) :: insertClauses) => @@ -942,8 +947,12 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), BinaryType) case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType) + Cast(nodeToExpr(arg), DecimalType.Unlimited) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => @@ -951,6 +960,7 @@ private[hive] object HiveQl { /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) + case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) @@ -958,6 +968,9 @@ private[hive] object HiveQl { case Token(DIV(), left :: right:: Nil) => Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) + case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) + case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ @@ -981,15 +994,20 @@ private[hive] object HiveQl { In(nodeToExpr(value), list.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(BETWEEN(), Nil) :: - Token("KW_FALSE", Nil) :: + kw :: target :: minValue :: maxValue :: Nil) => val targetExpression = nodeToExpr(target) - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + val betweenExpr = + And( + GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), + LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + kw match { + case Token("KW_FALSE", Nil) => betweenExpr + case Token("KW_TRUE", Nil) => Not(betweenExpr) + } /* Boolean Logic */ case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) @@ -1054,7 +1072,7 @@ private[hive] object HiveQl { } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(BigDecimal(strVal)) + v = Literal(Decimal(strVal)) } else { v = Literal(ast.getText.toDouble, DoubleType) v = Literal(ast.getText.toLong, LongType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5c66322f1ed99..989740c8d43b6 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.ql.parse.ASTNode + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ @@ -29,7 +31,7 @@ import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} import scala.collection.JavaConversions._ @@ -160,17 +162,14 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - - case logical.CreateTableAsSelect(database, tableName, child) => - val query = planLater(child) + case logical.CreateTableAsSelect( + Some(database), tableName, child, allowExisting, Some(extra: ASTNode)) => CreateTableAsSelect( - database.get, + database, tableName, - query, - InsertIntoHiveTable(_: MetastoreRelation, - Map(), - query, - overwrite = true)(hiveContext)) :: Nil + child, + allowExisting, + extra) :: Nil case _ => Nil } } @@ -207,6 +206,8 @@ private[hive] trait HiveStrategies { case hive.AddJar(path) => execution.AddJar(path) :: Nil + case hive.AddFile(path) => execution.AddFile(path) :: Nil + case hive.AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil case describe: logical.DescribeCommand => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 9ff7ab5a124c1..e49f0957d188a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -291,7 +291,7 @@ private[hive] object HadoopTableReader extends HiveInspectors { case oi: DoubleObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi => - (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi) + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 3625708d03175..3d24d87bc3d38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.parse.{SemanticAnalyzer, ASTNode} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode} import org.apache.spark.sql.hive.HiveContext import 
org.apache.spark.sql.hive.MetastoreRelation @@ -30,33 +32,56 @@ import org.apache.spark.sql.hive.MetastoreRelation * Create table and insert the query result into it. * @param database the database name of the new relation * @param tableName the table name of the new relation - * @param insertIntoRelation function of creating the `InsertIntoHiveTable` - * by specifying the `MetaStoreRelation`, the data will be inserted into that table. - * TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc. + * @param query the query whose result will be insert into the new relation + * @param allowExisting allow continue working if it's already exists, otherwise + * raise exception + * @param extra the extra information for this Operator, it should be the + * ASTNode object for extracting the CreateTableDesc. + */ @Experimental case class CreateTableAsSelect( database: String, tableName: String, - query: SparkPlan, - insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) - extends LeafNode with Command { + query: LogicalPlan, + allowExisting: Boolean, + extra: ASTNode) extends LeafNode with Command { def output = Seq.empty + private[this] def sc = sqlContext.asInstanceOf[HiveContext] + // A lazy computing of the metastoreRelation private[this] lazy val metastoreRelation: MetastoreRelation = { - // Create the table - val sc = sqlContext.asInstanceOf[HiveContext] - sc.catalog.createTable(database, tableName, query.output, false) + // Get the CreateTableDesc from Hive SemanticAnalyzer + val sa = new SemanticAnalyzer(sc.hiveconf) + + sa.analyze(extra, new Context(sc.hiveconf)) + val desc = sa.getQB().getTableDesc + // Create Hive Table + sc.catalog.createTable(database, tableName, query.output, allowExisting, Some(desc)) + // Get the Metastore Relation sc.catalog.lookupRelation(Some(database), tableName, None) match { case r: MetastoreRelation => r } } - override protected lazy val sideEffectResult: Seq[Row] = { - insertIntoRelation(metastoreRelation).execute + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + // TODO ideally, we should get the output data ready first and then + // add the relation into catalog, just in case of failure occurs while data + // processing. 
+ if (sc.catalog.tableExists(Some(database), tableName)) { + if (allowExisting) { + // table already exists, will do nothing, to keep consistent with Hive + } else { + throw + new org.apache.hadoop.hive.metastore.api.AlreadyExistsException(s"$database.$tableName") + } + } else { + sc.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true)).toRdd + } + Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 85965a6ea095a..d39413a44a6cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.hive._ */ @DeveloperApi case class HiveTableScan( - attributes: Seq[Attribute], + requestedAttributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Option[Expression])( @transient val context: HiveContext) @@ -53,6 +53,9 @@ case class HiveTableScan( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + // Retrieve the original attributes based on expression ID so that capitalization matches. + val attributes = requestedAttributes.map(relation.attributeMap) + // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private[this] val boundPruningPred = partitionPruningPred.map { pred => @@ -68,6 +71,9 @@ case class HiveTableScan( @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) + // append columns ids and names before broadcast + addColumnMetadataToConf(hiveExtraConf) + @transient private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) @@ -78,9 +84,7 @@ case class HiveTableScan( private def addColumnMetadataToConf(hiveConf: HiveConf) { // Specifies needed column IDs for those non-partitioning columns. - val neededColumnIDs = - attributes.map(a => - relation.attributes.indexWhere(_.name == a.name): Integer).filter(index => index >= 0) + val neededColumnIDs = attributes.flatMap(relation.columnOrdinals.get).map(o => o: Integer) HiveShim.appendReadColumns(hiveConf, neededColumnIDs, attributes.map(_.name)) @@ -105,8 +109,6 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } - addColumnMetadataToConf(hiveExtraConf) - /** * Prunes partitions not involve the query plan. 
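The sideEffectResult above first checks whether the target table exists: with IF NOT EXISTS the statement becomes a no-op, otherwise it raises AlreadyExistsException, and only a missing table triggers the create-and-insert path. A compact sketch of that control flow against a toy in-memory catalog; the by-name query parameter keeps the select from running when nothing is inserted:

    object CtasSketch {
      final class AlreadyExistsException(msg: String) extends RuntimeException(msg)

      // A toy in-memory catalog standing in for HiveMetastoreCatalog.
      class Catalog {
        private val tables = scala.collection.mutable.Map[String, Seq[String]]()
        def exists(name: String): Boolean = tables.contains(name)
        def create(name: String): Unit = tables(name) = Seq.empty
        def insert(name: String, rows: Seq[String]): Unit = tables(name) = tables(name) ++ rows
        def rows(name: String): Seq[String] = tables(name)
      }

      def createTableAsSelect(
          catalog: Catalog,
          table: String,
          query: => Seq[String],          // by-name: the query only runs if we actually insert
          allowExisting: Boolean): Unit = {
        if (catalog.exists(table)) {
          // IF NOT EXISTS: keep the existing table untouched, matching Hive's behaviour.
          if (!allowExisting) throw new AlreadyExistsException(table)
        } else {
          catalog.create(table)
          catalog.insert(table, query)
        }
      }

      def main(args: Array[String]): Unit = {
        val catalog = new Catalog
        createTableAsSelect(catalog, "t", Seq("a", "b"), allowExisting = false)
        createTableAsSelect(catalog, "t", Seq("c"), allowExisting = true) // no-op
        println(catalog.rows("t")) // List(a, b)
      }
    }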
* diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 7db5fd804d6ef..74b4e7aaa47a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.HiveVarchar import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} @@ -51,7 +52,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode with Command { + extends UnaryNode with Command with HiveInspectors { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -67,42 +68,6 @@ case class InsertIntoHiveTable( def output = child.output - /** - * Wraps with Hive types based on object inspector. - * TODO: Consolidate all hive OI/data interface code. 
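The HiveTableScan hunk above resolves the requested attributes through the relation's new attributeMap and derives the needed column IDs from columnOrdinals, so partition keys and unresolved attributes simply drop out. A simplified standalone sketch of that bookkeeping; the real code keys AttributeMap by expression ID rather than by the whole attribute:

    object ColumnPruningSketch {
      case class Attribute(id: Long, name: String)

      case class Relation(attributes: Seq[Attribute], partitionKeys: Seq[Attribute]) {
        // Ordinals exist only for non-partition columns, mirroring columnOrdinals above.
        val columnOrdinals: Map[Attribute, Int] = attributes.zipWithIndex.toMap
      }

      def neededColumnIds(requested: Seq[Attribute], relation: Relation): Seq[Int] =
        // flatMap drops partition keys and anything else without an ordinal.
        requested.flatMap(attr => relation.columnOrdinals.get(attr))

      def main(args: Array[String]): Unit = {
        val key   = Attribute(1, "key")
        val value = Attribute(2, "value")
        val part  = Attribute(3, "ds")
        val relation = Relation(Seq(key, value), Seq(part))
        println(neededColumnIds(Seq(value, part), relation)) // List(1): only "value" is read
      }
    }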
- */ - protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { - case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) - - case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying()) - - case soi: StandardStructObjectInspector => - val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) - (o: Any) => { - val struct = soi.create() - (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { - (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) - } - struct - } - - case loi: ListObjectInspector => - val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - - case moi: MapObjectInspector => - val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) - val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) - (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => - keyWrapper(key) -> valueWrapper(value) - }) - - case _ => - identity[Any] - } - def saveAsHiveFile( rdd: RDD[Row], valueClass: Class[_], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 0fc674af31885..903075edf7e04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -76,3 +76,19 @@ case class AddJar(path: String) extends LeafNode with Command { Seq.empty[Row] } } + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class AddFile(path: String) extends LeafNode with Command { + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + override def output = Seq.empty + + override protected lazy val sideEffectResult: Seq[Row] = { + hiveContext.runSqlHive(s"ADD FILE $path") + hiveContext.sparkContext.addFile(path) + Seq.empty[Row] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 68f93f247d9bb..86f7eea5dfd69 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -21,11 +21,14 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis @@ -97,7 +100,17 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) @transient - protected lazy val arguments = children.map(c => toInspector(c.dataType)).toArray + protected lazy val arguments = 
children.map(toInspector).toArray + + @transient + protected lazy val isUDFDeterministic = { + val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() + } + + override def foldable = { + isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable) + } // Create parameter converters @transient @@ -106,24 +119,39 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ @transient lazy val dataType = javaClassToDataType(method.getReturnType) + @transient + lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( + method.getGenericReturnType(), ObjectInspectorOptions.JAVA) + + @transient + protected lazy val cached = new Array[AnyRef](children.length) + // TODO: Finish input output types. override def eval(input: Row): Any = { - val evaluatedChildren = children.map(c => wrap(c.eval(input))) + unwrap( + FunctionRegistry.invoke(method, function, conversionHelper + .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), + returnInspector) + } +} - unwrap(FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(evaluatedChildren: _*): _*)) +// Adapter from Catalyst ExpressionResult to Hive DeferredObject +private[hive] class DeferredObjectAdapter(oi: ObjectInspector) + extends DeferredObject with HiveInspectors { + private var func: () => Any = _ + def set(func: () => Any) { + this.func = func } + override def prepare(i: Int) = {} + override def get(): AnyRef = wrap(func(), oi) } private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) extends HiveUdf with HiveInspectors { - - import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ - type UDFType = GenericUDF @transient - protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector) + protected lazy val argumentInspectors = children.map(toInspector) @transient protected lazy val returnInspector = function.initialize(argumentInspectors.toArray) @@ -138,19 +166,9 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable) } - protected lazy val deferedObjects = Array.fill[DeferredObject](children.length)({ - new DeferredObjectAdapter - }) - - // Adapter from Catalyst ExpressionResult to Hive DeferredObject - class DeferredObjectAdapter extends DeferredObject { - private var func: () => Any = _ - def set(func: () => Any) { - this.func = func - } - override def prepare(i: Int) = {} - override def get(): AnyRef = wrap(func()) - } + @transient + protected lazy val deferedObjects = + argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -159,10 +177,13 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq var i = 0 while (i < children.length) { val idx = i - deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(() => {children(idx).eval(input)}) + deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + () => { + children(idx).eval(input) + }) i += 1 } - unwrap(function.evaluate(deferedObjects)) + unwrap(function.evaluate(deferedObjects), returnInspector) } } @@ -250,12 +271,14 @@ private[hive] case class HiveGenericUdtf( protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @transient - protected lazy val outputInspectors = { - val structInspector = 
function.initialize(inputInspectors.toArray) - structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) - } + protected lazy val outputInspector = function.initialize(inputInspectors.toArray) + + @transient + protected lazy val udtInput = new Array[AnyRef](children.length) - protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType) + protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { + field => inspectorToDataType(field.getFieldObjectInspector) + } override protected def makeOutput() = { // Use column names when given, otherwise c_1, c_2, ... c_n. @@ -273,14 +296,12 @@ private[hive] case class HiveGenericUdtf( } override def eval(input: Row): TraversableOnce[Row] = { - outputInspectors // Make sure initialized. + outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) - - val udtInput = inputProjection(input).map(wrap).toArray - function.process(udtInput) + function.process(wrap(inputProjection(input), inputInspectors, udtInput)) collector.collectRows() } @@ -291,7 +312,7 @@ private[hive] case class HiveGenericUdtf( // We need to clone the input here because implementations of // GenericUDTF reuse the same object. Luckily they are always an array, so // it is easy to clone. - collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap)) + collected += unwrap(input, outputInspector).asInstanceOf[Row] } def collectRows() = { @@ -332,7 +353,7 @@ private[hive] case class HiveUdafFunction( private val buffer = function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] - override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) + override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) @transient val inputProjection = new InterpretedProjection(exprs) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 981ab954da489..bf2ce9df67c58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ @@ -47,6 +48,13 @@ private[hive] class SparkHiveWriterContainer( with Serializable { private val now = new Date() + private val tableDesc: TableDesc = fileSinkConf.getTableInfo + // Add table properties from storage handler to jobConf, so any custom storage + // handler settings can be set to jobConf + if (tableDesc != null) { + PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) + Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) + } protected val conf = new SerializableWritable(jobConf) private var jobID = 0 diff --git a/sql/hive/src/test/resources/data/files/issue-4077-data.txt b/sql/hive/src/test/resources/data/files/issue-4077-data.txt new file mode 100644 index 0000000000000..18067b0a64c9c --- /dev/null +++ b/sql/hive/src/test/resources/data/files/issue-4077-data.txt @@ -0,0 +1,2 @@ +2014-12-11 
00:00:00,1 +2014-12-11astring00:00:00,2 diff --git a/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada b/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada new file mode 100644 index 0000000000000..94f18d09863a7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/constant array-0-761ef205b10ac4a10122c8b4ce10ada @@ -0,0 +1 @@ +["enterprise databases","hadoop distributed file system","hadoop map-reduce"] diff --git a/sql/hive/src/test/resources/golden/truncate_column-0-616cad77ad5e7ac74da0d7425a7869a b/sql/hive/src/test/resources/golden/truncate_column-0-616cad77ad5e7ac74da0d7425a7869a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_column-1-418ec894d08c33fd712eb358f579b7a0 b/sql/hive/src/test/resources/golden/truncate_column-1-418ec894d08c33fd712eb358f579b7a0 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/truncate_column-1-418ec894d08c33fd712eb358f579b7a0 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/truncate_column_merge-0-46e8cc1556fa8586802a26267a906acf b/sql/hive/src/test/resources/golden/truncate_column_merge-0-46e8cc1556fa8586802a26267a906acf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-0-d16efe9bac079f0c5fc6cc424a8fa3eb b/sql/hive/src/test/resources/golden/truncate_table-0-d16efe9bac079f0c5fc6cc424a8fa3eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 b/sql/hive/src/test/resources/golden/truncate_table-1-7fc255c86d7c3a9ff088f9eb29a42565 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-1-ec0e3744208003f18c33a1f2c4c1e2c6 b/sql/hive/src/test/resources/golden/truncate_table-1-ec0e3744208003f18c33a1f2c4c1e2c6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-10-9ba46fdca3f0f4da8991cb5c7b01efdb b/sql/hive/src/test/resources/golden/truncate_table-10-9ba46fdca3f0f4da8991cb5c7b01efdb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f b/sql/hive/src/test/resources/golden/truncate_table-10-c32b771845f4d5a0330e2cfa09f89a7f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-11-6e0b877ea24fa88c5461b02f7bda0746 b/sql/hive/src/test/resources/golden/truncate_table-11-6e0b877ea24fa88c5461b02f7bda0746 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/truncate_table-11-6e0b877ea24fa88c5461b02f7bda0746 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/truncate_table-12-7dee32ebe9887833a9ae2ea6e5568028 b/sql/hive/src/test/resources/golden/truncate_table-12-7dee32ebe9887833a9ae2ea6e5568028 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-13-3230cfbe1871330193c3190c77582fe b/sql/hive/src/test/resources/golden/truncate_table-13-3230cfbe1871330193c3190c77582fe new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-14-ae23925663d7e9b7e97c42b66086d835 
b/sql/hive/src/test/resources/golden/truncate_table-14-ae23925663d7e9b7e97c42b66086d835 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-15-7850dc059f9d00eb9439d477e92cb913 b/sql/hive/src/test/resources/golden/truncate_table-15-7850dc059f9d00eb9439d477e92cb913 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-16-623e41aa678d5abc8341a8cee0ac8f94 b/sql/hive/src/test/resources/golden/truncate_table-16-623e41aa678d5abc8341a8cee0ac8f94 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-17-8c71d29e7db6a8d1cb5746458c7741e6 b/sql/hive/src/test/resources/golden/truncate_table-17-8c71d29e7db6a8d1cb5746458c7741e6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-18-64d431f93d8a44fb143cb4b87d63a105 b/sql/hive/src/test/resources/golden/truncate_table-18-64d431f93d8a44fb143cb4b87d63a105 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-19-1325d566d66f21a06543271c73a95a6f b/sql/hive/src/test/resources/golden/truncate_table-19-1325d566d66f21a06543271c73a95a6f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-2-fc4118284bf8301cf0d1056c388f963a b/sql/hive/src/test/resources/golden/truncate_table-2-fc4118284bf8301cf0d1056c388f963a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-20-91f869cc79191b87d31cfd0eca2839f4 b/sql/hive/src/test/resources/golden/truncate_table-20-91f869cc79191b87d31cfd0eca2839f4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-21-f635675d59df31843e7be41af7b9e4fa b/sql/hive/src/test/resources/golden/truncate_table-21-f635675d59df31843e7be41af7b9e4fa new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-22-f121fdc101603a8220c0f18e867f581e b/sql/hive/src/test/resources/golden/truncate_table-22-f121fdc101603a8220c0f18e867f581e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-23-63988ac685a3bd645787116353f024d2 b/sql/hive/src/test/resources/golden/truncate_table-23-63988ac685a3bd645787116353f024d2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-3-ecca1d24f36175932911a6e7a78ece2d b/sql/hive/src/test/resources/golden/truncate_table-3-ecca1d24f36175932911a6e7a78ece2d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-4-88e636ed8bdf647a02ff269aa3ebfe62 b/sql/hive/src/test/resources/golden/truncate_table-4-88e636ed8bdf647a02ff269aa3ebfe62 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-5-42aeecc67917d731e60fc46bde021d49 b/sql/hive/src/test/resources/golden/truncate_table-5-42aeecc67917d731e60fc46bde021d49 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-6-5a6776344f711298f27a8f1d3b47d107 b/sql/hive/src/test/resources/golden/truncate_table-6-5a6776344f711298f27a8f1d3b47d107 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d b/sql/hive/src/test/resources/golden/truncate_table-7-1ad5d350714e3d4ea17201153772d58d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-7-65e270fb0b61886aa85255d77eb65794 b/sql/hive/src/test/resources/golden/truncate_table-7-65e270fb0b61886aa85255d77eb65794 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a b/sql/hive/src/test/resources/golden/truncate_table-8-76c754eac44c7254b45807255d4dbc3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-8-e7699db3640f3b9b1fe44d6b8c9b507e b/sql/hive/src/test/resources/golden/truncate_table-8-e7699db3640f3b9b1fe44d6b8c9b507e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-9-eedfbb9479ac6c1b955b8e9b41994da4 b/sql/hive/src/test/resources/golden/truncate_table-9-eedfbb9479ac6c1b955b8e9b41994da4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a b/sql/hive/src/test/resources/golden/truncate_table-9-f4286b5657674a6a6b6bc6680f72f89a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725 b/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-0-8f0ea83364b78634fbb3752c5a5c725 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350 b/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-1-380c9638cc6ea8ea42f187bf0cedf350 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-1-8f0ea83364b78634fbb3752c5a5c725 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17 b/sql/hive/src/test/resources/golden/udf_named_struct-2-22a79ac608b1249306f82f4bdc669b17 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 new file mode 100644 index 0000000000000..9bff96e7fa20e --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-2-380c9638cc6ea8ea42f187bf0cedf350 @@ -0,0 +1 @@ +named_struct(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 b/sql/hive/src/test/resources/golden/udf_named_struct-3-c069e28293a12a813f8e881f776bae90 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b b/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b new file mode 100644 index 0000000000000..de25f51b5b56d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-3-d7e4a555934307155784904ff9df188b @@ -0,0 +1 @@ +{"foo":1,"bar":2} 1 diff --git a/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 new file mode 100644 index 0000000000000..de25f51b5b56d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_named_struct-4-b499d4120e009f222f2fab160a9006d7 @@ -0,0 +1 @@ +{"foo":1,"bar":2} 1 diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb b/sql/hive/src/test/resources/golden/udf_sort_array-0-e86d559aeb84a4cc017a103182c22bfb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82 b/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82 new file mode 100644 index 0000000000000..d514df4191b89 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_sort_array-1-976cd8b6b50a2748bbc768aa5e11cf82 @@ -0,0 +1 @@ +sort_array(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements. 
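The empty and one-line files above are golden outputs recorded for the new Hive compatibility tests (constant array, truncate_table, udf_named_struct, udf_sort_array, udf_struct); HiveComparisonTest replays each named query and diffs its result against these files. As an illustrative sketch only (not part of the patch), the "constant array" golden above corresponds to a query of roughly this shape, assuming the TestHive-backed sql() helper used throughout these suites:

    import org.apache.spark.sql.hive.test.TestHive._

    // sort_array orders the strings by their natural (lexicographic) ordering, so the
    // expected single row matches the recorded golden output:
    //   ["enterprise databases","hadoop distributed file system","hadoop map-reduce"]
    val row = sql(
      """SELECT sort_array(
        |  array("hadoop distributed file system",
        |    "enterprise databases", "hadoop map-reduce"))
        |FROM src LIMIT 1""".stripMargin).collect().head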
diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13 b/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13 new file mode 100644 index 0000000000000..9d33cd51fef04 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_sort_array-10-9e047718e5fea6ea79124f1e899f1c13 @@ -0,0 +1 @@ +[1,2,3,4,5] [1,2,7,8,9] [4,8,16,32,64] [1,100,246,357,1000] [false,true] [1.414,1.618,2.718,3.141] [1.41421,1.61803,2.71828,3.14159] ["","aramis","athos","portos"] ["1970-01-05 13:51:04.042","1970-01-07 00:54:54.442","1970-01-16 12:50:35.242"] diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93 b/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93 new file mode 100644 index 0000000000000..43e36513de881 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_sort_array-2-c429ec85a6da60ebd4bc6f0f266e8b93 @@ -0,0 +1,4 @@ +sort_array(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements. +Example: + > SELECT sort_array(array('b', 'd', 'c', 'a')) FROM src LIMIT 1; + 'a', 'b', 'c', 'd' diff --git a/sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de b/sql/hive/src/test/resources/golden/udf_sort_array-3-55c4cdaf8438b06675d60848d68f35de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1 b/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1 new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-0-f41043b7d9f14fa5e998c90454c7bdb1 @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb b/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-1-8ccdb20153debdab789ea8ad0228e2eb @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-1-f41043b7d9f14fa5e998c90454c7bdb1 @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) 
- Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3 b/sql/hive/src/test/resources/golden/udf_struct-2-4a62774a6de7571c8d2bcb77da63f8f3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb new file mode 100644 index 0000000000000..062cb1bc683b1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-2-8ccdb20153debdab789ea8ad0228e2eb @@ -0,0 +1 @@ +struct(col1, col2, col3, ...) - Creates a struct with the given field values diff --git a/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 b/sql/hive/src/test/resources/golden/udf_struct-3-71361a92b74c4d026ac7ae6e1e6746f1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2 b/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2 new file mode 100644 index 0000000000000..ff1a28fa47f18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-3-abffdaacb0c7076ab538fbeec072daa2 @@ -0,0 +1 @@ +{"col1":1} {"col1":1,"col2":"a"} 1 a diff --git a/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 new file mode 100644 index 0000000000000..ff1a28fa47f18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_struct-4-b196b5d8849d52bbe5e2ee683f29e051 @@ -0,0 +1 @@ +{"col1":1} {"col1":1,"col2":"a"} 1 a diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 95921c3d7ae09..f89c49d292c6c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -29,7 +31,26 @@ import org.apache.spark.sql.catalyst.util._ * It is hard to have maven allow one subproject depend on another subprojects test code. * So, we duplicate this code here. */ -class QueryTest extends FunSuite { +class QueryTest extends PlanTest { + /** + * Runs the plan and makes sure the answer contains all of the keywords, or that + * none of the keywords appear in the answer. + * @param rdd the [[SchemaRDD]] to be executed + * @param exists true to assert that all of the keywords appear in the output; false to + * assert that none of the keywords appear in the output + * @param keywords the keywords to look for + */ + def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + val outputs = rdd.collect().map(_.mkString).mkString + for (key <- keywords) { + if (exists) { + assert(outputs.contains(key), s"Failed for $rdd ($key doesn't exist in result)") + } else { + assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)") + } + } + } + /** * Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[SchemaRDD]] to be executed diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala new file mode 100644 index 0000000000000..081d94b6fc020 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.scalatest.FunSuite + +/** + * *** DUPLICATED FROM sql/catalyst/plans. *** + * + * It is hard to have maven allow one subproject depend on another subprojects test code. + * So, we duplicate this code here. + */ +class PlanTest extends FunSuite { + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. 
+ */ + protected def normalizeExprIds(plan: LogicalPlan) = { + val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)) + val minId = if (list.isEmpty) 0 else list.min + plan transformAllExpressions { + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + } + } + + /** Fails the test if the two plans do not match */ + protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizeExprIds(plan1) + val normalized2 = normalizeExprIds(plan2) + if (normalized1 != normalized2) + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 7e323146f9da2..18dc937dd2b27 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive /* Implicits */ @@ -73,4 +74,21 @@ class InsertIntoHiveTableSuite extends QueryTest { createTable[TestData]("createAndInsertTest") createTable[TestData]("createAndInsertTest") } + + test("SPARK-4052: scala.collection.Map as value type of MapType") { + val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) + val rowRDD = TestHive.sparkContext.parallelize( + (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) + val schemaRDD = applySchema(rowRDD, schema) + schemaRDD.registerTempTable("tableWithMapValue") + sql("CREATE TABLE hiveTableWithMapValue(m MAP <STRING, STRING>)") + sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") + + checkAnswer( + sql("SELECT * FROM hiveTableWithMapValue"), + rowRDD.collect().toSeq + ) + + sql("DROP TABLE hiveTableWithMapValue") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 4ed58f4be1167..a68fc2a803bb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,37 +18,24 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row /** * A set of tests that validates support for Hive Explain command.
*/ class HiveExplainSuite extends QueryTest { - private def check(sqlCmd: String, exists: Boolean, keywords: String*) { - val outputs = sql(sqlCmd).collect().map(_.getString(0)).mkString - for (key <- keywords) { - if (exists) { - assert(outputs.contains(key), s"Failed for $sqlCmd ($key doens't exist in result)") - } else { - assert(!outputs.contains(key), s"Failed for $sqlCmd ($key existed in the result)") - } - } - } - test("explain extended command") { - check(" explain select * from src where key=123 ", true, - "== Physical Plan ==") - check(" explain select * from src where key=123 ", false, - "== Parsed Logical Plan ==", - "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") - check(" explain extended select * from src where key=123 ", true, - "== Parsed Logical Plan ==", - "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==", - "== Physical Plan ==", - "Code Generation", "== RDD ==") + checkExistence(sql(" explain select * from src where key=123 "), true, + "== Physical Plan ==") + checkExistence(sql(" explain select * from src where key=123 "), false, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==") + checkExistence(sql(" explain extended select * from src where key=123 "), true, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", + "== Physical Plan ==", + "Code Generation", "== RDD ==") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala new file mode 100644 index 0000000000000..c939e6e99d28a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHive + +class HivePlanTest extends QueryTest { + import TestHive._ + + test("udf constant folding") { + val optimized = sql("SELECT cos(null) FROM src").queryExecution.optimizedPlan + val correctAnswer = sql("SELECT cast(null as double) FROM src").queryExecution.optimizedPlan + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 322a25bb20837..b897dff0159ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.hive.execution +import java.io.File + import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkException +import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive @@ -34,6 +36,14 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("constant array", + """ + |SELECT sort_array( + | sort_array( + | array("hadoop distributed file system", + | "enterprise databases", "hadoop map-reduce"))) + |FROM src LIMIT 1; + """.stripMargin) createQueryTest("count distinct 0 values", """ @@ -561,7 +571,7 @@ class HiveQuerySuite extends HiveComparisonTest { |WITH serdeproperties('s1'='9') """.stripMargin) } - // Now only verify 0.12.0, and ignore other versions due to binary compatability + // Now only verify 0.12.0, and ignore other versions due to binary compatibility // current TestSerDe.jar is from 0.12.0 if (HiveShim.version == "0.12.0") { sql(s"ADD JAR $testJar") @@ -573,6 +583,17 @@ class HiveQuerySuite extends HiveComparisonTest { sql("DROP TABLE alter1") } + test("ADD FILE command") { + val testFile = TestHive.getHiveFile("data/files/v1.txt").getCanonicalFile + sql(s"ADD FILE $testFile") + + val checkAddFileRDD = sparkContext.parallelize(1 to 2, 1).mapPartitions { _ => + Iterator.single(new File(SparkFiles.get("v1.txt")).canRead) + } + + assert(checkAddFileRDD.first()) + } + case class LogEntry(filename: String, message: String) case class LogFile(name: String) @@ -748,7 +769,7 @@ class HiveQuerySuite extends HiveComparisonTest { }.toSet clear() - // "set" itself returns all config variables currently specified in SQLConf. + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
assert(sql("SET").collect().size == 0) @@ -757,44 +778,19 @@ class HiveQuerySuite extends HiveComparisonTest { } assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) - } + assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) + assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - - // "set key" - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey")) - } - - assertResult(Set(nonexistentKey -> "")) { - collectResults(sql(s"SET $nonexistentKey")) - } - - // Assert that sql() should have the same effects as sql() by repeating the above using sql(). - clear() - assert(sql("SET").collect().size == 0) - - assertResult(Set(testKey -> testVal)) { - collectResults(sql(s"SET $testKey=$testVal")) - } - - assert(hiveconf.get(testKey, "") == testVal) - assertResult(Set(testKey -> testVal)) { - collectResults(sql("SET")) - } - - sql(s"SET ${testKey + testKey}=${testVal + testVal}") - assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET")) + collectResults(sql("SET -v")) } + // "SET key" assertResult(Set(testKey -> testVal)) { collectResults(sql(s"SET $testKey")) } @@ -808,7 +804,7 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("select from thrift based table", "SELECT * from src_thrift") - + // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. 
} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index c5736723b47c0..54c0f017d4cb6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.{Row, SchemaRDD} + +import org.apache.spark.util.Utils + class HiveTableScanSuite extends HiveComparisonTest { createQueryTest("partition_based_table_scan_with_different_serde", @@ -38,4 +43,30 @@ class HiveTableScanSuite extends HiveComparisonTest { | |SELECT * from part_scan_test; """.stripMargin) + + test("Spark-4041: lowercase issue") { + TestHive.sql("CREATE TABLE tb (KEY INT, VALUE STRING) STORED AS ORC") + TestHive.sql("insert into table tb select key, value from src") + TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() + TestHive.sql("drop table tb") + } + + test("Spark-4077: timestamp query for null value") { + TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") + TestHive.sql( + """ + CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) + ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + LINES TERMINATED BY '\n' + """.stripMargin) + val location = + Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() + + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + TestHive.sql("DROP TABLE timestamp_query_null") + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index fbe6ac765c009..e9b1943ff8db7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -32,6 +32,70 @@ case class Nested3(f3: Int) * valid, but Hive currently cannot execute it. */ class SQLQuerySuite extends QueryTest { + test("CTAS with serde") { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect + sql( + """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect + sql( + """CREATE TABLE ctas3 + | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' + | STORED AS textfile AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect + + // the table schema may look like (key: integer, value: string) + sql( + """CREATE TABLE IF NOT EXISTS ctas4 AS + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect + // does nothing because the table ctas4 already exists.
+ sql( + """CREATE TABLE IF NOT EXISTS ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + + checkAnswer( + sql("SELECT k, value FROM ctas1 ORDER BY k, value"), + sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + checkAnswer( + sql("SELECT key, value FROM ctas2 ORDER BY key, value"), + sql( + """ + SELECT key, value + FROM src + ORDER BY key, value""").collect().toSeq) + checkAnswer( + sql("SELECT key, value FROM ctas3 ORDER BY key, value"), + sql( + """ + SELECT key, value + FROM src + ORDER BY key, value""").collect().toSeq) + intercept[org.apache.hadoop.hive.metastore.api.AlreadyExistsException] { + sql( + """CREATE TABLE ctas4 AS + | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect + } + checkAnswer( + sql("SELECT key, value FROM ctas4 ORDER BY key, value"), + sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) + + checkExistence(sql("DESC EXTENDED ctas2"), true, + "name:key", "type:string", "name:value", "ctas2", + "org.apache.hadoop.hive.ql.io.RCFileInputFormat", + "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", + "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" + ) + } + test("ordering not in select") { checkAnswer( sql("SELECT key FROM src ORDER BY value"), @@ -75,4 +139,33 @@ class SQLQuerySuite extends QueryTest { sql("SELECT a.key FROM (SELECT key FROM src) `a`"), sql("SELECT `key` FROM src").collect().toSeq) } + + test("SPARK-3814 Support Bitwise & operator") { + checkAnswer( + sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) + } + + test("SPARK-3814 Support Bitwise | operator") { + checkAnswer( + sql("SELECT case when 1|0=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) + } + + test("SPARK-3814 Support Bitwise ^ operator") { + checkAnswer( + sql("SELECT case when 1^0=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) + } + + test("SPARK-3814 Support Bitwise ~ operator") { + checkAnswer( + sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) + } + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + } } diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 2317d2e76341f..8e946b7e82f5d 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -26,21 +26,28 @@ import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import 
org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} +import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions +import org.apache.spark.sql.catalyst.types.DecimalType + /** * A compatibility layer for interacting with Hive version 0.12.0. */ private[hive] object HiveShim { val version = "0.12.0" - val metastoreDecimal = "decimal" def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -50,6 +57,59 @@ private[hive] object HiveShim { new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) } + def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, new hadoopIo.Text(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.INT, new hadoopIo.IntWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.DOUBLE, new hiveIo.DoubleWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.BOOLEAN, new hadoopIo.BooleanWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.LONG, new hadoopIo.LongWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.FLOAT, new hadoopIo.FloatWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.SHORT, new hiveIo.ShortWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.BYTE, new hiveIo.ByteWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.BINARY, new hadoopIo.BytesWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.DATE, new hiveIo.DateWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.TIMESTAMP, new hiveIo.TimestampWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + 
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.DECIMAL, + new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + + def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.VOID, null) + def createDriverResultsArray = new JArrayList[String] def processResults(results: JArrayList[String]) = results @@ -89,6 +149,22 @@ private[hive] object HiveShim { "udf_concat" ) + def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { + tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) + } + + def decimalMetastoreString(decimalType: DecimalType): String = "decimal" + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = + TypeInfoFactory.decimalTypeInfo + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + DecimalType.Unlimited + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + } } class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index b8d893d8c1319..0bc330cdbecb1 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -21,17 +21,24 @@ import java.util.{ArrayList => JArrayList} import java.util.Properties import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.{HiveDecimal} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} -import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer} -import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.Logging +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} +import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -40,11 +47,6 @@ import scala.language.implicitConversions */ private[hive] object HiveShim { val version = "0.13.1" - /* - * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded) - * Full support of new decimal feature need to be fixed in seperate PR. 
- */ - val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -54,6 +56,59 @@ private[hive] object HiveShim { new TableDesc(inputFormatClass, outputFormatClass, properties) } + def getPrimitiveWritableConstantObjectInspector(value: String): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, new hadoopIo.Text(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Int): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.intTypeInfo, new hadoopIo.IntWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Double): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.doubleTypeInfo, new hiveIo.DoubleWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Boolean): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.booleanTypeInfo, new hadoopIo.BooleanWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Long): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.longTypeInfo, new hadoopIo.LongWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Float): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.floatTypeInfo, new hadoopIo.FloatWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Short): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.shortTypeInfo, new hiveIo.ShortWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Byte): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.byteTypeInfo, new hiveIo.ByteWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: Array[Byte]): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.binaryTypeInfo, new hadoopIo.BytesWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: java.sql.Date): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.dateTypeInfo, new hiveIo.DateWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: java.sql.Timestamp): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.timestampTypeInfo, new hiveIo.TimestampWritable(value)) + + def getPrimitiveWritableConstantObjectInspector(value: BigDecimal): ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.decimalTypeInfo, + new hiveIo.HiveDecimalWritable(HiveShim.createDecimal(value.underlying()))) + + def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = + PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.voidTypeInfo, null) + def createDriverResultsArray = new JArrayList[Object] def processResults(results: JArrayList[Object]) = { @@ -121,6 +176,10 @@ private[hive] object HiveShim { def compatibilityBlackList = Seq() + def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { + 
tbl.setDataLocation(new Path(crtTbl.getLocation())) + } + /* * Bug introdiced in hive-0.13. FileSinkDesc is serializable, but its member path is not. * Fix it through wrapper. @@ -133,6 +192,30 @@ private[hive] object HiveShim { f.setDestTableId(w.destTableId) f } + + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + private val UNLIMITED_DECIMAL_PRECISION = 38 + private val UNLIMITED_DECIMAL_SCALE = 18 + + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" + } + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) + } + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } } /* diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 391e40924f38a..bb47d373de63d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -23,8 +23,9 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ -import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.SparkException /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -65,10 +66,10 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont if (validTime >= graph.startTime) { val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id) receivedBlockInfo(validTime) = blockInfo - val blockIds = blockInfo.map(_.blockId.asInstanceOf[BlockId]) + val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] } Some(new BlockRDD[T](ssc.sc, blockIds)) } else { - Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) + Some(new BlockRDD[T](ssc.sc, Array.empty)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala new file mode 100644 index 0000000000000..23295bf658712 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.rdd + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark._ +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader} + +/** + * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. + * It contains information about the id of the blocks having this partition's data and + * the segment of the write ahead log that backs the partition. + * @param index index of the partition + * @param blockId id of the block having the partition data + * @param segment segment of the write ahead log having the partition data + */ +private[streaming] +class WriteAheadLogBackedBlockRDDPartition( + val index: Int, + val blockId: BlockId, + val segment: WriteAheadLogFileSegment) + extends Partition + + +/** + * This class represents a special case of the BlockRDD where the data blocks in + * the block manager are also backed by segments in write ahead logs. For reading + * the data, this RDD first looks up the blocks by their ids in the block manager. + * If it does not find them, it looks up the corresponding file segment. + * + * @param sc SparkContext + * @param hadoopConfig Hadoop configuration + * @param blockIds Ids of the blocks that contain this RDD's data + * @param segments Segments in write ahead logs that contain this RDD's data + * @param storeInBlockManager Whether to store in the block manager after reading from the segment + * @param storageLevel storage level to store when storing in block manager + * (applicable when storeInBlockManager = true) + */ +private[streaming] +class WriteAheadLogBackedBlockRDD[T: ClassTag]( + @transient sc: SparkContext, + @transient hadoopConfig: Configuration, + @transient blockIds: Array[BlockId], + @transient segments: Array[WriteAheadLogFileSegment], + storeInBlockManager: Boolean, + storageLevel: StorageLevel) + extends BlockRDD[T](sc, blockIds) { + + require( + blockIds.length == segments.length, + s"Number of block ids (${blockIds.length}) must be " + + s"the same as the number of segments (${segments.length})!") + + // Hadoop configuration is not serializable, so wrap it in a SerializableWritable. + private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + + override def getPartitions: Array[Partition] = { + assertValid() + Array.tabulate(blockIds.size) { i => + new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) + } + } + + /** + * Gets the partition data by getting the corresponding block from the block manager. + * If the block does not exist, then the data is read from the corresponding segment + * in write ahead log files.
+ */ + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + assertValid() + val hadoopConf = broadcastedHadoopConf.value + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] + val blockId = partition.blockId + blockManager.get(blockId) match { + case Some(block) => // Data is in Block Manager + val iterator = block.data.asInstanceOf[Iterator[T]] + logDebug(s"Read partition data of $this from block manager, block $blockId") + iterator + case None => // Data not found in Block Manager, grab it from write ahead log file + val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) + val dataRead = reader.read(partition.segment) + reader.close() + logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}") + if (storeInBlockManager) { + blockManager.putBytes(blockId, dataRead, storageLevel) + logDebug(s"Stored partition data of $this into block manager with level $storageLevel") + dataRead.rewind() + } + blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + } + } + + /** + * Get the preferred location of the partition. This returns the locations of the block + * if it is present in the block manager, else it returns the location of the + * corresponding segment in HDFS. + */ + override def getPreferredLocations(split: Partition): Seq[String] = { + val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] + val blockLocations = getBlockIdLocations().get(partition.blockId) + def segmentLocations = HdfsUtils.getFileSegmentLocations( + partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig) + blockLocations.getOrElse(segmentLocations) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala new file mode 100644 index 0000000000000..47968afef2dbf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.language.existentials + +/** Trait representing a received block */ +private[streaming] sealed trait ReceivedBlock + +/** class representing a block received as an ArrayBuffer */ +private[streaming] case class ArrayBufferBlock(arrayBuffer: ArrayBuffer[_]) extends ReceivedBlock + +/** class representing a block received as an Iterator */ +private[streaming] case class IteratorBlock(iterator: Iterator[_]) extends ReceivedBlock + +/** class representing a block received as an ByteBuffer */ +private[streaming] case class ByteBufferBlock(byteBuffer: ByteBuffer) extends ReceivedBlock diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala new file mode 100644 index 0000000000000..fdf995320beb4 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.{existentials, postfixOps} + +import WriteAheadLogBasedBlockHandler._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage._ +import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager} +import org.apache.spark.util.Utils + +/** Trait that represents the metadata related to storage of blocks */ +private[streaming] trait ReceivedBlockStoreResult { + def blockId: StreamBlockId // Any implementation of this trait will store a block id +} + +/** Trait that represents a class that handles the storage of blocks received by receiver */ +private[streaming] trait ReceivedBlockHandler { + + /** Store a received block with the given block id and return related metadata */ + def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult + + /** Cleanup old blocks older than the given threshold time */ + def cleanupOldBlock(threshTime: Long) +} + + +/** + * Implementation of [[org.apache.spark.streaming.receiver.ReceivedBlockStoreResult]] + * that stores the metadata related to storage of blocks using + * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] + */ +private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) + extends ReceivedBlockStoreResult + + +/** + * Implementation of a [[org.apache.spark.streaming.receiver.ReceivedBlockHandler]] which + * stores the received blocks into a block manager with the specified storage level. + */ +private[streaming] class BlockManagerBasedBlockHandler( + blockManager: BlockManager, storageLevel: StorageLevel) + extends ReceivedBlockHandler with Logging { + + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + val putResult: Seq[(BlockId, BlockStatus)] = block match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + case IteratorBlock(iterator) => + blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + case ByteBufferBlock(byteBuffer) => + blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + case o => + throw new SparkException( + s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") + } + if (!putResult.map { _._1 }.contains(blockId)) { + throw new SparkException( + s"Could not store $blockId to block manager with storage level $storageLevel") + } + BlockManagerBasedStoreResult(blockId) + } + + def cleanupOldBlock(threshTime: Long) { + // this is not used as blocks inserted into the BlockManager are cleared by DStream's clearing + // of BlockRDDs. + } +} + + +/** + * Implementation of [[org.apache.spark.streaming.receiver.ReceivedBlockStoreResult]] + * that stores the metadata related to storage of blocks using + * [[org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler]] + */ +private[streaming] case class WriteAheadLogBasedStoreResult( + blockId: StreamBlockId, + segment: WriteAheadLogFileSegment + ) extends ReceivedBlockStoreResult + + +/** + * Implementation of a [[org.apache.spark.streaming.receiver.ReceivedBlockHandler]] which + * stores the received blocks in both, a write ahead log and a block manager. 
+ */ +private[streaming] class WriteAheadLogBasedBlockHandler( + blockManager: BlockManager, + streamId: Int, + storageLevel: StorageLevel, + conf: SparkConf, + hadoopConf: Configuration, + checkpointDir: String, + clock: Clock = new SystemClock + ) extends ReceivedBlockHandler with Logging { + + private val blockStoreTimeout = conf.getInt( + "spark.streaming.receiver.blockStoreTimeout", 30).seconds + private val rollingInterval = conf.getInt( + "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) + private val maxFailures = conf.getInt( + "spark.streaming.receiver.writeAheadLog.maxFailures", 3) + + // Manages rolling log files + private val logManager = new WriteAheadLogManager( + checkpointDirToLogDir(checkpointDir, streamId), + hadoopConf, rollingInterval, maxFailures, + callerName = this.getClass.getSimpleName, + clock = clock + ) + + // For processing futures used in parallel block storing into block manager and write ahead log + // # threads = 2, so that both writing to BM and WAL can proceed in parallel + implicit private val executionContext = ExecutionContext.fromExecutorService( + Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + + /** + * This implementation stores the block into the block manager as well as a write ahead log. + * It does this in parallel, using Scala Futures, and returns only after the block has + * been stored in both places. + */ + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + // Serialize the block so that it can be inserted into both + val serializedBlock = block match { + case ArrayBufferBlock(arrayBuffer) => + blockManager.dataSerialize(blockId, arrayBuffer.iterator) + case IteratorBlock(iterator) => + blockManager.dataSerialize(blockId, iterator) + case ByteBufferBlock(byteBuffer) => + byteBuffer + case _ => + throw new Exception(s"Could not push $blockId to block manager, unexpected block type") + } + + // Store the block in block manager + val storeInBlockManagerFuture = Future { + val putResult = + blockManager.putBytes(blockId, serializedBlock, storageLevel, tellMaster = true) + if (!putResult.map { _._1 }.contains(blockId)) { + throw new SparkException( + s"Could not store $blockId to block manager with storage level $storageLevel") + } + } + + // Store the block in write ahead log + val storeInWriteAheadLogFuture = Future { + logManager.writeToLog(serializedBlock) + } + + // Combine the futures, wait for both to complete, and return the write ahead log segment + val combinedFuture = for { + _ <- storeInBlockManagerFuture + fileSegment <- storeInWriteAheadLogFuture + } yield fileSegment + val segment = Await.result(combinedFuture, blockStoreTimeout) + WriteAheadLogBasedStoreResult(blockId, segment) + } + + def cleanupOldBlock(threshTime: Long) { + logManager.cleanupOldLogs(threshTime) + } + + def stop() { + logManager.stop() + } +} + +private[streaming] object WriteAheadLogBasedBlockHandler { + def checkpointDirToLogDir(checkpointDir: String, streamId: Int): String = { + new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 53a3e6200e340..5360412330d37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala 
@@ -25,16 +25,13 @@ import scala.concurrent.Await import akka.actor.{Actor, Props} import akka.pattern.ask - import com.google.common.base.Throwables - -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.hadoop.conf.Configuration +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.scheduler.DeregisterReceiver -import org.apache.spark.streaming.scheduler.AddBlock -import org.apache.spark.streaming.scheduler.RegisterReceiver +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.WriteAheadLogFileSegment +import org.apache.spark.util.{AkkaUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -44,12 +41,26 @@ import org.apache.spark.streaming.scheduler.RegisterReceiver */ private[streaming] class ReceiverSupervisorImpl( receiver: Receiver[_], - env: SparkEnv + env: SparkEnv, + hadoopConf: Configuration, + checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { - private val blockManager = env.blockManager + private val receivedBlockHandler: ReceivedBlockHandler = { + if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (checkpointDirOption.isEmpty) { + throw new SparkException( + "Cannot enable receiver write-ahead log without checkpoint directory set. " + + "Please use streamingContext.checkpoint() to set the checkpoint directory. " + + "See documentation for more details.") + } + new WriteAheadLogBasedBlockHandler(env.blockManager, receiver.streamId, + receiver.storageLevel, env.conf, hadoopConf, checkpointDirOption.get) + } else { + new BlockManagerBasedBlockHandler(env.blockManager, receiver.storageLevel) + } + } - private val storageLevel = receiver.storageLevel /** Remote Akka actor for the ReceiverTracker */ private val trackerActor = { @@ -105,47 +116,50 @@ private[streaming] class ReceiverSupervisorImpl( /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putArray(blockId, arrayBuffer.toArray[Any], storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata) + pushAndReportBlock(ArrayBufferBlock(arrayBuffer), metadataOption, blockIdOption) } /** Store a iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) + pushAndReportBlock(IteratorBlock(iterator), metadataOption, blockIdOption) } /** Store the bytes of received data as a data block into Spark's memory. 
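The handler selection above means the write ahead log only kicks in when the flag is set and a checkpoint directory exists. A hedged sketch of the driver-side setup; the host, port and checkpoint path are illustrative:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("wal-enabled-stream")
      // Ask receivers to write every received block to a write ahead log as well.
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")

    val ssc = new StreamingContext(conf, Seconds(1))
    // Required: the log files live under <checkpointDir>/receivedData/<streamId>.
    ssc.checkpoint("hdfs:///tmp/streaming-checkpoint")

    val lines = ssc.socketTextStream("localhost", 9999)
    lines.count().print()
    ssc.start()
    ssc.awaitTermination()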
*/ def pushBytes( bytes: ByteBuffer, - optionalMetadata: Option[Any], - optionalBlockId: Option[StreamBlockId] + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] ) { - val blockId = optionalBlockId.getOrElse(nextBlockId) - val time = System.currentTimeMillis - blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true) - logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, -1, optionalMetadata) + pushAndReportBlock(ByteBufferBlock(bytes), metadataOption, blockIdOption) } - /** Report pushed block */ - def reportPushedBlock(blockId: StreamBlockId, numRecords: Long, optionalMetadata: Option[Any]) { - val blockInfo = ReceivedBlockInfo(streamId, blockId, numRecords, optionalMetadata.orNull) - trackerActor ! AddBlock(blockInfo) - logDebug("Reported block " + blockId) + /** Store block and report it to driver */ + def pushAndReportBlock( + receivedBlock: ReceivedBlock, + metadataOption: Option[Any], + blockIdOption: Option[StreamBlockId] + ) { + val blockId = blockIdOption.getOrElse(nextBlockId) + val numRecords = receivedBlock match { + case ArrayBufferBlock(arrayBuffer) => arrayBuffer.size + case _ => -1 + } + + val time = System.currentTimeMillis + val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) + logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") + + val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) + val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) + Await.result(future, askTimeout) + logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index a68aecb881117..92dc113f397ca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.Time import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.Time /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index a69d74362173e..8c15a75b1b0e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable.HashSet + import org.apache.spark.streaming.Time /** Class representing a set of Jobs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala new file mode 100644 index 0000000000000..94beb590f52d6 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockInfo.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
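pushAndReportBlock above now waits for the tracker to acknowledge AddBlock (the tracker replies with sender ! true). A stripped-down sketch of that ask/acknowledge handshake with a standalone actor; the actor and message here are placeholders, not the Spark ones:

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.{Actor, ActorSystem, Props}
    import akka.pattern.ask
    import akka.util.Timeout

    case class AddBlock(description: String)

    class Tracker extends Actor {
      def receive = {
        case AddBlock(_) =>
          // record the block metadata here, then acknowledge receipt
          sender ! true
      }
    }

    val system = ActorSystem("sketch")
    val tracker = system.actorOf(Props[Tracker], "tracker")

    implicit val timeout = Timeout(30.seconds)
    // The caller blocks until the tracker has actually processed the message.
    val acked = Await.result(tracker ? AddBlock("block-0"), timeout.duration)
    system.shutdown()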
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming.receiver.ReceivedBlockStoreResult + +/** Information about blocks received by the receiver */ +private[streaming] case class ReceivedBlockInfo( + streamId: Int, + numRecords: Long, + blockStoreResult: ReceivedBlockStoreResult + ) + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 7149dbc12a365..d696563bcee83 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -21,21 +21,12 @@ import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} import scala.language.existentials import akka.actor._ -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException} import org.apache.spark.SparkContext._ -import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver} import org.apache.spark.util.AkkaUtils -/** Information about blocks received by the receiver */ -private[streaming] case class ReceivedBlockInfo( - streamId: Int, - blockId: StreamBlockId, - numRecords: Long, - metadata: Any - ) - /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate * with each other. @@ -153,7 +144,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + - receivedBlockInfo.blockId) + receivedBlockInfo.blockStoreResult.blockId) } /** Report error sent by a receiver */ @@ -188,6 +179,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { sender ! true case AddBlock(receivedBlockInfo) => addBlocks(receivedBlockInfo) + sender ! 
true case ReportError(streamId, message, error) => reportError(streamId, message, error) case DeregisterReceiver(streamId, message, error) => @@ -252,6 +244,9 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { ssc.sc.makeRDD(receivers, receivers.size) } + val checkpointDirOption = Option(ssc.checkpointDir) + val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { if (!iterator.hasNext) { @@ -259,9 +254,10 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { "Could not start receiver as object not found.") } val receiver = iterator.next() - val executor = new ReceiverSupervisorImpl(receiver, SparkEnv.get) - executor.start() - executor.awaitTermination() + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index 491f1175576e6..27a28bab83ed5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -52,12 +52,14 @@ private[streaming] object HdfsUtils { } } - def getBlockLocations(path: String, conf: Configuration): Option[Array[String]] = { + /** Get the locations of the HDFS blocks containing the given file segment. */ + def getFileSegmentLocations( + path: String, offset: Long, length: Long, conf: Configuration): Array[String] = { val dfsPath = new Path(path) val dfs = getFileSystemForPath(dfsPath, conf) val fileStatus = dfs.getFileStatus(dfsPath) - val blockLocs = Option(dfs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen)) - blockLocs.map(_.flatMap(_.getHosts)) + val blockLocs = Option(dfs.getFileBlockLocations(fileStatus, offset, length)) + blockLocs.map(_.flatMap(_.getHosts)).getOrElse(Array.empty) } def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala index 92bad7a882a65..003989092a42a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala @@ -52,4 +52,3 @@ private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configura HdfsUtils.checkState(!closed, "Stream is closed. Create a new Reader to read from the file.") } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala new file mode 100644 index 0000000000000..ad1a6f01b3a57 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
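The HdfsUtils change above narrows the location lookup from the whole file to a single segment. A sketch of the same call sequence against the plain Hadoop FileSystem API; the path and configuration are whatever the caller has at hand:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    // Which hosts hold the bytes [offset, offset + length) of the given file?
    def segmentLocations(path: String, offset: Long, length: Long, conf: Configuration): Array[String] = {
      val dfsPath = new Path(path)
      val fs = dfsPath.getFileSystem(conf)
      val status = fs.getFileStatus(dfsPath)
      Option(fs.getFileBlockLocations(status, offset, length))
        .map(_.flatMap(_.getHosts))
        .getOrElse(Array.empty)
    }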
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Props} +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage._ +import org.apache.spark.streaming.receiver._ +import org.apache.spark.streaming.util._ +import org.apache.spark.util.AkkaUtils +import WriteAheadLogBasedBlockHandler._ +import WriteAheadLogSuite._ + +class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + + val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val hadoopConf = new Configuration() + val storageLevel = StorageLevel.MEMORY_ONLY_SER + val streamId = 1 + val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) + val serializer = new KryoSerializer(conf) + val manualClock = new ManualClock + val blockManagerSize = 10000000 + + var actorSystem: ActorSystem = null + var blockManagerMaster: BlockManagerMaster = null + var blockManager: BlockManager = null + var tempDirectory: File = null + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) + this.actorSystem = actorSystem + conf.set("spark.driver.port", boundPort.toString) + + blockManagerMaster = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf, true) + + blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManagerSize, conf, mapOutputTracker, shuffleManager, + new NioBlockTransferService(conf, securityMgr)) + + tempDirectory = Files.createTempDir() + manualClock.setTime(0) + } + + after { + if (blockManager != null) { + blockManager.stop() + blockManager = null + } + if (blockManagerMaster != null) { + blockManagerMaster.stop() + blockManagerMaster = null + } + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + } + + test("BlockManagerBasedBlockHandler - store blocks") { + withBlockManagerBasedBlockHandler { handler => + testBlockStoring(handler) { case (data, blockIds, storeResults) => + // Verify the data in block manager 
is correct + val storedData = blockIds.flatMap { blockId => + blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + }.toList + storedData shouldEqual data + + // Verify that the store results are instances of BlockManagerBasedStoreResult + assert( + storeResults.forall { _.isInstanceOf[BlockManagerBasedStoreResult] }, + "Unexpected store result type" + ) + } + } + } + + test("BlockManagerBasedBlockHandler - handle errors in storing block") { + withBlockManagerBasedBlockHandler { handler => + testErrorHandling(handler) + } + } + + test("WriteAheadLogBasedBlockHandler - store blocks") { + withWriteAheadLogBasedBlockHandler { handler => + testBlockStoring(handler) { case (data, blockIds, storeResults) => + // Verify the data in block manager is correct + val storedData = blockIds.flatMap { blockId => + blockManager.getLocal(blockId).map { _.data.map {_.toString}.toList }.getOrElse(List.empty) + }.toList + storedData shouldEqual data + + // Verify that the store results are instances of WriteAheadLogBasedStoreResult + assert( + storeResults.forall { _.isInstanceOf[WriteAheadLogBasedStoreResult] }, + "Unexpected store result type" + ) + // Verify the data in write ahead log files is correct + val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment} + val loggedData = fileSegments.flatMap { segment => + val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf) + val bytes = reader.read(segment) + reader.close() + blockManager.dataDeserialize(generateBlockId(), bytes).toList + } + loggedData shouldEqual data + } + } + } + + test("WriteAheadLogBasedBlockHandler - handle errors in storing block") { + withWriteAheadLogBasedBlockHandler { handler => + testErrorHandling(handler) + } + } + + test("WriteAheadLogBasedBlockHandler - cleanup old blocks") { + withWriteAheadLogBasedBlockHandler { handler => + val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) } + storeBlocks(handler, blocks) + + val preCleanupLogFiles = getWriteAheadLogFiles() + preCleanupLogFiles.size should be > 1 + + // this depends on the number of blocks inserted using generateAndStoreData() + manualClock.currentTime() shouldEqual 5000L + + val cleanupThreshTime = 3000L + handler.cleanupOldBlock(cleanupThreshTime) + eventually(timeout(10000 millis), interval(10 millis)) { + getWriteAheadLogFiles().size should be < preCleanupLogFiles.size + } + } + } + + /** + * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded + * using the given verification function + */ + private def testBlockStoring(receivedBlockHandler: ReceivedBlockHandler) + (verifyFunc: (Seq[String], Seq[StreamBlockId], Seq[ReceivedBlockStoreResult]) => Unit) { + val data = Seq.tabulate(100) { _.toString } + + def storeAndVerify(blocks: Seq[ReceivedBlock]) { + blocks should not be empty + val (blockIds, storeResults) = storeBlocks(receivedBlockHandler, blocks) + withClue(s"Testing with ${blocks.head.getClass.getSimpleName}s:") { + // Verify returns store results have correct block ids + (storeResults.map { _.blockId }) shouldEqual blockIds + + // Call handler-specific verification function + verifyFunc(data, blockIds, storeResults) + } + } + + def dataToByteBuffer(b: Seq[String]) = blockManager.dataSerialize(generateBlockId, b.iterator) + + val blocks = data.grouped(10).toSeq + + storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) }) + storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) }) + 
storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b)) }) + } + + /** Test error handling when blocks that cannot be stored */ + private def testErrorHandling(receivedBlockHandler: ReceivedBlockHandler) { + // Handle error in iterator (e.g. divide-by-zero error) + intercept[Exception] { + val iterator = (10 to (-10, -1)).toIterator.map { _ / 0 } + receivedBlockHandler.storeBlock(StreamBlockId(1, 1), IteratorBlock(iterator)) + } + + // Handler error in block manager storing (e.g. too big block) + intercept[SparkException] { + val byteBuffer = ByteBuffer.wrap(new Array[Byte](blockManagerSize + 1)) + receivedBlockHandler.storeBlock(StreamBlockId(1, 1), ByteBufferBlock(byteBuffer)) + } + } + + /** Instantiate a BlockManagerBasedBlockHandler and run a code with it */ + private def withBlockManagerBasedBlockHandler(body: BlockManagerBasedBlockHandler => Unit) { + body(new BlockManagerBasedBlockHandler(blockManager, storageLevel)) + } + + /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ + private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { + val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, + storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) + try { + body(receivedBlockHandler) + } finally { + receivedBlockHandler.stop() + } + } + + /** Store blocks using a handler */ + private def storeBlocks( + receivedBlockHandler: ReceivedBlockHandler, + blocks: Seq[ReceivedBlock] + ): (Seq[StreamBlockId], Seq[ReceivedBlockStoreResult]) = { + val blockIds = Seq.fill(blocks.size)(generateBlockId()) + val storeResults = blocks.zip(blockIds).map { + case (block, id) => + manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf + logDebug("Inserting block " + id) + receivedBlockHandler.storeBlock(id, block) + }.toList + logDebug("Done inserting") + (blockIds, storeResults) + } + + private def getWriteAheadLogFiles(): Seq[String] = { + getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) + } + + private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala similarity index 86% rename from streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala rename to streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index eb6e88cf5520d..0f6a9489dbe0d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -31,9 +31,9 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ /** Testsuite for testing the network receiver behavior */ -class NetworkReceiverSuite extends FunSuite with Timeouts { +class ReceiverSuite extends FunSuite with Timeouts { - test("network receiver life cycle") { + test("receiver life cycle") { val receiver = new FakeReceiver val executor = new FakeReceiverSupervisor(receiver) @@ -152,8 +152,8 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { test("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener - val blockInterval = 50 - val maxRate = 200 + val blockInterval = 100 + val maxRate = 100 val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). 
set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) @@ -175,19 +175,35 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { } blockGenerator.stop() - val recordedData = blockGeneratorListener.arrayBuffers - assert(blockGeneratorListener.arrayBuffers.size > 0) - assert(recordedData.flatten.toSet === generatedData.toSet) + val recordedBlocks = blockGeneratorListener.arrayBuffers + val recordedData = recordedBlocks.flatten + assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received") + assert(recordedData.toSet === generatedData.toSet, "Received data not same") + // recordedData size should be close to the expected rate - assert(recordedData.flatten.size >= expectedMessages * 0.9 && - recordedData.flatten.size <= expectedMessages * 1.1 ) - // the first and last block may be incomplete, so we slice them out - recordedData.slice(1, recordedData.size - 1).foreach { block => - assert(block.size >= expectedMessagesPerBlock * 0.8 && - block.size <= expectedMessagesPerBlock * 1.2 ) - } + val minExpectedMessages = expectedMessages - 3 + val maxExpectedMessages = expectedMessages + 1 + val numMessages = recordedData.size + assert( + numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages, + s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages" + ) + + val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 + val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 + val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") + println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes) + assert( + // the first and last block may be incomplete, so we slice them out + recordedBlocks.drop(1).dropRight(1).forall { block => + block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock + }, + s"# records in received blocks = [$receivedBlockSizes], not between " + + s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock" + ) } + /** * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 655cec1573f58..f47772947d67c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -46,6 +46,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w after { if (ssc != null) { ssc.stop() + if (ssc.sc != null) { + // Calling ssc.stop() does not always stop the associated SparkContext. + ssc.sc.stop() + } ssc = null } if (sc != null) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala new file mode 100644 index 0000000000000..10160244bcc91 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.rdd + +import java.io.File + +import scala.util.Random + +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} + +class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll { + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName(this.getClass.getSimpleName) + val hadoopConf = new Configuration() + + var sparkContext: SparkContext = null + var blockManager: BlockManager = null + var dir: File = null + + override def beforeAll(): Unit = { + sparkContext = new SparkContext(conf) + blockManager = sparkContext.env.blockManager + dir = Files.createTempDir() + } + + override def afterAll(): Unit = { + // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests. + sparkContext.stop() + dir.delete() + System.clearProperty("spark.driver.port") + } + + test("Read data available in block manager and write ahead log") { + testRDD(5, 5) + } + + test("Read data available only in block manager, not in write ahead log") { + testRDD(5, 0) + } + + test("Read data available only in write ahead log, not in block manager") { + testRDD(0, 5) + } + + test("Read data available only in write ahead log, and test storing in block manager") { + testRDD(0, 5, testStoreInBM = true) + } + + test("Read data with partially available in block manager, and rest in write ahead log") { + testRDD(3, 2) + } + + /** + * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager + * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * It can also test if the partitions that were read from the log were again stored in + * block manager. 
+ * @param numPartitionsInBM Number of partitions to write to the Block Manager + * @param numPartitionsInWAL Number of partitions to write to the Write Ahead Log + * @param testStoreInBM Test whether blocks read from log are stored back into block manager + */ + private def testRDD(numPartitionsInBM: Int, numPartitionsInWAL: Int, testStoreInBM: Boolean = false) { + val numBlocks = numPartitionsInBM + numPartitionsInWAL + val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50)) + + // Put the necessary blocks in the block manager + val blockIds = Array.fill(numBlocks)(StreamBlockId(Random.nextInt(), Random.nextInt())) + data.zip(blockIds).take(numPartitionsInBM).foreach { case(block, blockId) => + blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER) + } + + // Generate write ahead log segments + val segments = generateFakeSegments(numPartitionsInBM) ++ + writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL)) + + // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + require( + blockIds.takeRight(numPartitionsInWAL).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the right `numPartitionsInWAL` blocks are in write ahead logs, and other are not + require( + segments.takeRight(numPartitionsInWAL).forall(s => + new File(s.path.stripPrefix("file://")).exists()), + "Expected blocks not in write ahead log" + ) + require( + segments.take(numPartitionsInBM).forall(s => + !new File(s.path.stripPrefix("file://")).exists()), + "Unexpected blocks in write ahead log" + ) + + // Create the RDD and verify whether the returned data is correct + val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) + assert(rdd.collect() === data.flatten) + + if (testStoreInBM) { + val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray, + segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) + assert(rdd2.collect() === data.flatten) + assert( + blockIds.forall(blockManager.get(_).nonEmpty), + "All blocks not found in block manager" + ) + } + } + + private def writeLogSegments( + blockData: Seq[Seq[String]], + blockIds: Seq[BlockId] + ): Seq[WriteAheadLogFileSegment] = { + require(blockData.size === blockIds.size) + val writer = new WriteAheadLogWriter(new File(dir, Random.nextString(10)).toString, hadoopConf) + val segments = blockData.zip(blockIds).map { case (data, id) => + writer.write(blockManager.dataSerialize(id, data.iterator)) + } + writer.close() + segments + } + + private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { + Array.fill(count)(new WriteAheadLogFileSegment("random", 0l, 0)) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 5eba93c208c50..1956a4f1db90a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -58,7 +58,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { test("WriteAheadLogWriter - writing data") { val dataToWrite = 
generateRandomData() val segments = writeDataUsingWriter(testFile, dataToWrite) - val writtenData = readDataManually(testFile, segments) + val writtenData = readDataManually(segments) assert(writtenData === dataToWrite) } @@ -67,7 +67,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { val writer = new WriteAheadLogWriter(testFile, hadoopConf) dataToWrite.foreach { data => val segment = writer.write(stringToByteBuffer(data)) - val dataRead = readDataManually(testFile, Seq(segment)).head + val dataRead = readDataManually(Seq(segment)).head assert(data === dataRead) } writer.close() @@ -281,14 +281,20 @@ object WriteAheadLogSuite { } /** Read data from a segments of a log file directly and return the list of byte buffers.*/ - def readDataManually(file: String, segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { - val reader = HdfsUtils.getInputStream(file, hadoopConf) - segments.map { x => - reader.seek(x.offset) - val data = new Array[Byte](x.length) - reader.readInt() - reader.readFully(data) - Utils.deserialize[String](data) + def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { + segments.map { segment => + val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) + try { + reader.seek(segment.offset) + val bytes = new Array[Byte](segment.length) + reader.readInt() + reader.readFully(bytes) + val data = Utils.deserialize[String](bytes) + reader.close() + data + } finally { + reader.close() + } } } @@ -335,9 +341,11 @@ object WriteAheadLogSuite { val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { - fileSystem.listStatus(logDirectoryPath).map { - _.getPath.toString.stripPrefix("file:") - }.sorted + fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { + _.getName().split("-")(1).toLong + }.map { + _.toString.stripPrefix("file:") + } } else { Seq.empty } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 9c66c785848a5..73b705ba50051 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -35,6 +35,7 @@ import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's alpha API. */ +@deprecated("use yarn/stable", "1.2.0") private[spark] class Client( val args: ClientArguments, val hadoopConf: Configuration, @@ -131,6 +132,7 @@ object Client { println("WARNING: This client is deprecated and will be removed in a " + "future version of Spark. Use ./bin/spark-submit with \"--master yarn\"") } + println("WARNING: Support for YARN-alpha API's will be removed in Spark 1.3 (see SPARK-3445)") // Set an env variable indicating we are running in YARN mode. 
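The ordering fix in getLogFilesInDirectory above sorts by the numeric start time embedded in the file name instead of by the raw string. A small sketch of why that matters, assuming log files are named like log-<startTime>-<stopTime>:

    // Lexicographic order breaks as soon as the timestamps differ in digit count
    // ("log-9-10" sorts after "log-10-11"), so parse the start time and sort on that.
    val fileNames = Seq("log-1000-2000", "log-200-300", "log-30-40")
    val byName  = fileNames.sorted                          // List(log-1000-2000, log-200-300, log-30-40)
    val byStart = fileNames.sortBy(_.split("-")(1).toLong)  // List(log-30-40, log-200-300, log-1000-2000)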
// Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 229b7a09f456b..7ee4b5c842df1 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SecurityManager, SparkConf, Logging} - +@deprecated("use yarn/stable", "1.2.0") class ExecutorRunnable( container: Container, conf: Configuration, diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 7faf55bc63372..e342cc82f454e 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import scala.collection.{Map, Set} -import java.net.URI; +import java.net.URI import org.apache.hadoop.net.NetUtils import org.apache.hadoop.yarn.api._ @@ -109,7 +109,9 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appMasterRequest.setHost(Utils.localHostName()) appMasterRequest.setRpcPort(0) // remove the scheme from the url if it exists since Hadoop does not expect scheme - appMasterRequest.setTrackingUrl(new URI(uiAddress).getAuthority()) + val uri = new URI(uiAddress) + val authority = if (uri.getScheme == null) uiAddress else uri.getAuthority + appMasterRequest.setTrackingUrl(authority) resourceManager.registerApplicationMaster(appMasterRequest) } diff --git a/yarn/common/src/main/resources/log4j-spark-container.properties b/yarn/common/src/main/resources/log4j-spark-container.properties deleted file mode 100644 index a1e37a0be27dd..0000000000000 --- a/yarn/common/src/main/resources/log4j-spark-container.properties +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. See accompanying LICENSE file. 
- -# Set everything to be logged to the console -log4j.rootCategory=INFO, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.eclipse.jetty=WARN -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e6fe0265d8811..e90672c004d4b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -36,8 +36,8 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, Spar import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter +import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} /** @@ -385,8 +385,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, SparkEnv.driverActorSystemName, driverHost, driverPort.toString, - CoarseGrainedSchedulerBackend.ACTOR_NAME) - actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + YarnSchedulerBackend.ACTOR_NAME) + actorSystem.actorOf(Props(new AMActor(driverUrl)), name = "YarnAM") } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -479,9 +479,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, userThread } - // Actor used to monitor the driver when running in client deploy mode. - private class MonitorActor(driverUrl: String) extends Actor { - + /** + * Actor that communicates with the driver in client deploy mode. + */ + private class AMActor(driverUrl: String) extends Actor { var driver: ActorSelection = _ override def preStart() = { @@ -490,6 +491,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // Send a hello message to establish the connection, after which // we can monitor Lifecycle Events. driver ! "Hello" + driver ! RegisterClusterManager context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } @@ -497,11 +499,27 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver ! x - } + case RequestExecutors(requestedTotal) => + logInfo(s"Driver requested a total number of $requestedTotal executor(s).") + Option(allocator) match { + case Some(a) => a.requestTotalExecutors(requestedTotal) + case None => logWarning("Container allocator is not ready to request executors yet.") + } + sender ! 
true + + case KillExecutors(executorIds) => + logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.") + Option(allocator) match { + case Some(a) => executorIds.foreach(a.killExecutor) + case None => logWarning("Container allocator is not ready to kill executors yet.") + } + sender ! true + } } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 3e6b96fb63cea..8b32c76d14037 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.util.IntParam +import org.apache.spark.util.{MemoryParam, IntParam} +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import collection.mutable.ArrayBuffer class ApplicationMasterArguments(val args: Array[String]) { @@ -26,7 +27,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 var executorCores = 1 - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + var numExecutors = DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) @@ -55,7 +56,7 @@ class ApplicationMasterArguments(val args: Array[String]) { numExecutors = value args = tail - case ("--worker-memory" | "--executor-memory") :: IntParam(value) :: tail => + case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => executorMemory = value args = tail @@ -81,7 +82,7 @@ class ApplicationMasterArguments(val args: Array[String]) { | --jar JAR_PATH Path to your application's JAR file | --class CLASS_NAME Name of your application's main class | --args ARGS Arguments to be passed to your application's main class. - | Mutliple invocations are possible, each will be passed in order. + | Multiple invocations are possible, each will be passed in order. | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index a12f82d2fbe70..4d859450efc63 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf -import org.apache.spark.util.{Utils, IntParam, MemoryParam} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! 
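With the ClientArguments change in the hunks that follow, a YARN application that turns on dynamic allocation must also supply the upper bound, and must not fix the executor count explicitly. A hedged sketch of the relevant configuration; the numbers are illustrative:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.dynamicAllocation.enabled", "true")
      // Required when dynamic allocation is on: the application starts at this many executors.
      .set("spark.dynamicAllocation.maxExecutors", "20")
    // Passing --num-executors together with spark.dynamicAllocation.enabled is rejected
    // with an IllegalArgumentException.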
private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -33,23 +33,25 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 // MB var executorCores = 1 - var numExecutors = 2 + var numExecutors = DEFAULT_NUMBER_EXECUTORS var amQueue = sparkConf.get("spark.yarn.queue", "default") var amMemory: Int = 512 // MB var appName: String = "Spark" var priority = 0 - parseArgs(args.toList) - loadEnvironmentArgs() - // Additional memory to allocate to containers // For now, use driver's memory overhead as our AM container's memory overhead - val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) - val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + private val isDynamicAllocationEnabled = + sparkConf.getBoolean("spark.dynamicAllocation.enabled", false) + + parseArgs(args.toList) + loadEnvironmentArgs() validateArgs() /** Load any default arguments provided through environment variables and Spark properties. */ @@ -64,6 +66,15 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) .orNull + // If dynamic allocation is enabled, start at the max number of executors + if (isDynamicAllocationEnabled) { + val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors" + if (!sparkConf.contains(maxExecutorsConf)) { + throw new IllegalArgumentException( + s"$maxExecutorsConf must be set if dynamic allocation is enabled!") + } + numExecutors = sparkConf.get(maxExecutorsConf).toInt + } } /** @@ -113,6 +124,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) if (args(0) == "--num-workers") { println("--num-workers is deprecated. Use --num-executors instead.") } + // Dynamic allocation is not compatible with this option + if (isDynamicAllocationEnabled) { + throw new IllegalArgumentException("Explicitly setting the number " + + "of executors is not compatible with spark.dynamicAllocation.enabled!") + } numExecutors = value args = tail diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index fb0e34bf5985e..f95d72379171c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -39,6 +39,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.util.Utils /** * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. 
@@ -55,6 +56,7 @@ private[spark] trait ClientBase extends Logging { protected val amMemoryOverhead = args.amMemoryOverhead // MB protected val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() + private val isLaunchingDriver = args.userClass != null /** * Fail fast if we have requested more resources per container than is available in the cluster. @@ -267,7 +269,6 @@ private[spark] trait ClientBase extends Logging { // Note that to warn the user about the deprecation in cluster mode, some code from // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition // described above). - val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = @@ -312,6 +313,10 @@ private[spark] trait ClientBase extends Logging { val javaOpts = ListBuffer[String]() + // Set the environment variable through a command prefix + // to append to the existing value of the variable + var prefixEnv: Option[String] = None + // Add Xmx for AM memory javaOpts += "-Xmx" + args.amMemory + "m" @@ -344,20 +349,22 @@ private[spark] trait ClientBase extends Logging { } // Include driver-specific java options if we are launching a driver - val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) .foreach(opts => javaOpts += opts) - sparkConf.getOption("spark.driver.libraryPath") - .foreach(p => javaOpts += s"-Djava.library.path=$p") + val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), + sys.props.get("spark.driver.libraryPath")).flatten + if (libraryPaths.nonEmpty) { + prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + } } // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) val userClass = - if (args.userClass != null) { + if (isLaunchingDriver) { Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass)) } else { Nil @@ -380,12 +387,12 @@ private[spark] trait ClientBase extends Logging { val amArgs = Seq(amClass) ++ userClass ++ userJar ++ userArgs ++ Seq( - "--executor-memory", args.executorMemory.toString, + "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, "--num-executors ", args.numExecutors.toString) // Command for the ApplicationMaster - val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ + val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ javaOpts ++ amArgs ++ Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 5cb4753de2e84..88dad0febd03f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils trait ExecutorRunnableUtil extends Logging { @@ -47,6 +48,11 @@ trait ExecutorRunnableUtil extends Logging { localResources: HashMap[String, LocalResource]): List[String] = { // Extra 
options for the JVM val javaOpts = ListBuffer[String]() + + // Set the environment variable through a command prefix + // to append to the existing value of the variable + var prefixEnv: Option[String] = None + // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " @@ -58,6 +64,9 @@ trait ExecutorRunnableUtil extends Logging { sys.env.get("SPARK_JAVA_OPTS").foreach { opts => javaOpts += opts } + sys.props.get("spark.executor.extraLibraryPath").foreach { p => + prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + } javaOpts += "-Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) @@ -101,7 +110,7 @@ trait ExecutorRunnableUtil extends Logging { // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) - val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", + val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index e1af8d5a74cb1..b32e15738f28b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger +import java.util.regex.Pattern import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -63,6 +64,8 @@ private[yarn] abstract class YarnAllocator( securityMgr: SecurityManager) extends Logging { + import YarnAllocator._ + // These three are locked on allocatedHostToContainersMap. Complementary data structures // allocatedHostToContainersMap : containers which are running : host, Set // allocatedContainerToHostMap: container to host mapping. @@ -88,7 +91,10 @@ private[yarn] abstract class YarnAllocator( private val executorIdCounter = new AtomicInteger() private val numExecutorsFailed = new AtomicInteger() - private val maxExecutors = args.numExecutors + private var maxExecutors = args.numExecutors + + // Keep track of which container is running which executor to remove the executors later + private val executorIdToContainer = new HashMap[String, Container] protected val executorMemory = args.executorMemory protected val executorCores = args.executorCores @@ -111,7 +117,48 @@ private[yarn] abstract class YarnAllocator( def getNumExecutorsFailed: Int = numExecutorsFailed.intValue - def allocateResources() = { + /** + * Request as many executors from the ResourceManager as needed to reach the desired total. + * This takes into account executors already running or pending. 
+ */ + def requestTotalExecutors(requestedTotal: Int): Unit = synchronized { + val currentTotal = numPendingAllocate.get + numExecutorsRunning.get + if (requestedTotal > currentTotal) { + maxExecutors += (requestedTotal - currentTotal) + // We need to call `allocateResources` here to avoid the following race condition: + // If we request executors twice before `allocateResources` is called, then we will end up + // double counting the number requested because `numPendingAllocate` is not updated yet. + allocateResources() + } else { + logInfo(s"Not allocating more executors because there are already $currentTotal " + + s"(application requested $requestedTotal total)") + } + } + + /** + * Request that the ResourceManager release the container running the specified executor. + */ + def killExecutor(executorId: String): Unit = synchronized { + if (executorIdToContainer.contains(executorId)) { + val container = executorIdToContainer.remove(executorId).get + internalReleaseContainer(container) + numExecutorsRunning.decrementAndGet() + maxExecutors -= 1 + assert(maxExecutors >= 0, "Allocator killed more executors than are allocated!") + } else { + logWarning(s"Attempted to kill unknown executor $executorId!") + } + } + + /** + * Allocate missing containers based on the number of executors currently pending and running. + * + * This method prioritizes the allocated container responses from the RM based on node and + * rack locality. Additionally, it releases any extra containers allocated for this application + * but are not needed. This must be synchronized because variables read in this block are + * mutated by other methods. + */ + def allocateResources(): Unit = synchronized { val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() // this is needed by alpha, do it here since we add numPending right after this @@ -119,7 +166,7 @@ private[yarn] abstract class YarnAllocator( if (missing > 0) { val totalExecutorMemory = executorMemory + memoryOverhead numPendingAllocate.addAndGet(missing) - logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + + logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + s"memory including $memoryOverhead MB overhead") } else { logDebug("Empty allocation request ...") @@ -269,6 +316,7 @@ private[yarn] abstract class YarnAllocator( CoarseGrainedSchedulerBackend.ACTOR_NAME) logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + executorIdToContainer(executorId) = container // To be safe, remove the container from `releasedContainers`. 
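The bookkeeping in requestTotalExecutors above is easiest to follow with concrete numbers; a plain-integer sketch of the same accounting, with no YARN involved:

    // Mirrors requestTotalExecutors: only the shortfall beyond pending + running is added.
    var maxExecutors   = 10
    val numPending     = 4
    val numRunning     = 6
    val requestedTotal = 15

    val currentTotal = numPending + numRunning            // 10 already requested or live
    if (requestedTotal > currentTotal) {
      maxExecutors += (requestedTotal - currentTotal)     // 10 + 5 = 15
    }                                                     // otherwise: nothing more to allocate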
releasedContainers.remove(containerId) @@ -330,12 +378,22 @@ private[yarn] abstract class YarnAllocator( logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, - completedContainer.getExitStatus())) + completedContainer.getExitStatus)) // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) + if (completedContainer.getExitStatus == -103) { // vmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded + logWarning(memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + } else if (completedContainer.getExitStatus != 0) { + logInfo("Container marked as failed: " + containerId + + ". Exit status: " + completedContainer.getExitStatus + + ". Diagnostics: " + completedContainer.getDiagnostics) numExecutorsFailed.incrementAndGet() } } @@ -463,3 +521,18 @@ private[yarn] abstract class YarnAllocator( } } + +private object YarnAllocator { + val MEM_REGEX = "[0-9.]+ [KMG]B" + val PMEM_EXCEEDED_PATTERN = + Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") + val VMEM_EXCEEDED_PATTERN = + Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + + def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { + val matcher = pattern.matcher(diagnostics) + val diag = if (matcher.find()) " " + matcher.group() + "." else "" + ("Container killed by YARN for exceeding memory limits." 
+ diag + + " Consider boosting spark.yarn.executor.memoryOverhead.") + } +} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e1e0144f46fe9..7d453ecb7983c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -93,6 +93,8 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" + val DEFAULT_NUMBER_EXECUTORS = 2 + // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = 1 diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d948a2aeedd45..f6f6dc52433e5 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -17,27 +17,23 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} + import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.scheduler.TaskSchedulerImpl -import scala.collection.mutable.ArrayBuffer - private[spark] class YarnClientSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends YarnSchedulerBackend(scheduler, sc) with Logging { - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { - minRegisteredRatio = 0.8 - } - private var client: Client = null private var appId: ApplicationId = null private var stopping: Boolean = false - private var totalExpectedExecutors = 0 /** * Create a Yarn client to submit an application to the ResourceManager. 
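The executor-count bookkeeping that this patch adds to YarnAllocator is easiest to follow in isolation. The standalone sketch below is not code from the patch: the object name and the plain var counters are hypothetical stand-ins for the AtomicInteger fields, and no YARN types are involved. It only mirrors the arithmetic that requestTotalExecutors, killExecutor and allocateResources appear to maintain in the hunks above: the target (maxExecutors) grows by the shortfall between the requested total and pending-plus-running, shrinks by one for every killed executor, and allocateResources requests whatever is still missing.

object AllocatorArithmeticSketch {
  // Plain counters standing in for maxExecutors, numPendingAllocate and numExecutorsRunning.
  var maxExecutors = 2
  var pending = 0
  var running = 0

  def requestTotalExecutors(requestedTotal: Int): Unit = {
    val currentTotal = pending + running
    if (requestedTotal > currentTotal) {
      // Grow the target by the shortfall; this path never lowers it.
      maxExecutors += requestedTotal - currentTotal
      allocate()
    }
  }

  def killExecutor(): Unit = {
    // One container released: one fewer running executor, and a smaller target so the
    // next allocate() round does not simply re-request the executor that was just killed.
    running -= 1
    maxExecutors -= 1
  }

  def allocate(): Unit = {
    val missing = maxExecutors - pending - running
    if (missing > 0) pending += missing // stands in for numPendingAllocate.addAndGet(missing)
  }

  def main(args: Array[String]): Unit = {
    allocate()                 // initial round: 2 containers requested for the initial target
    requestTotalExecutors(5)   // shortfall of 3: target grows 2 -> 5, 3 more requested
    println(s"max=$maxExecutors pending=$pending running=$running") // max=5 pending=5 running=0
    pending -= 1; running += 1 // pretend one granted container registered as an executor
    killExecutor()             // release it again: the target drops to 4
    println(s"max=$maxExecutors pending=$pending running=$running") // max=4 pending=4 running=0
  }
}

Under those assumptions, a fresh allocator asked for a total of five executors ends up with a target of five and five outstanding container requests, and killing one of them leaves a target of four rather than triggering an immediate replacement.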
@@ -48,7 +44,7 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIHostPort) } + sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIAddress) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) @@ -151,14 +147,11 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio - } - - override def applicationId(): String = + override def applicationId(): String = { Option(appId).map(_.toString).getOrElse { logWarning("Application ID is not initialized yet.") super.applicationId } + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 3a186cfeb4eeb..b1de81e6a8b0f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -18,24 +18,18 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.SparkContext -import org.apache.spark.deploy.yarn.ApplicationMasterArguments +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.IntParam private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { - - var totalExpectedExecutors = 0 - - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { - minRegisteredRatio = 0.8 - } + extends YarnSchedulerBackend(scheduler, sc) { override def start() { super.start() - totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) .getOrElse(totalExpectedExecutors) @@ -44,10 +38,6 @@ private[spark] class YarnClusterSchedulerBackend( totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) } - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio - } - override def applicationId(): String = // In YARN Cluster mode, spark.yarn.app.id is expect to be set // before user application is launched. diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala new file mode 100644 index 0000000000000..8d184a09d64cc --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.scalatest.FunSuite + +class YarnAllocatorSuite extends FunSuite { + test("memory exceeded diagnostic regexes") { + val diagnostics = + "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + + "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " + + "5.8 GB of 4.2 GB virtual memory used. Killing container." + val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) + val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) + assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) + assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 8a7035c85e9f1..2885e6607ec24 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -73,6 +73,28 @@ yarn-alpha + + + + maven-antrun-plugin + + + validate + + run + + + + ******************************************************************************************* + ***WARNING***: Support for YARN-alpha API's will be removed in Spark 1.3 (see SPARK-3445).* + ******************************************************************************************* + + + + + + + alpha
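Taken together, the completed-container handling and the new YarnAllocator companion object turn a memory-related container kill into an actionable log line, and the YarnAllocatorSuite above exercises the two regexes against a realistic diagnostics string. The self-contained sketch below restates that behaviour for illustration: the ContainerExitSketch object and its describe helper are hypothetical, while the -103/-104 exit codes, the patterns, and memLimitExceededLogMessage are taken directly from the hunks above.

import java.util.regex.Pattern

object ContainerExitSketch {
  val MEM_REGEX = "[0-9.]+ [KMG]B"
  val PMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used")
  val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used")

  // Same message construction as YarnAllocator.memLimitExceededLogMessage in this patch.
  def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
    val matcher = pattern.matcher(diagnostics)
    val diag = if (matcher.find()) " " + matcher.group() + "." else ""
    "Container killed by YARN for exceeding memory limits." + diag +
      " Consider boosting spark.yarn.executor.memoryOverhead."
  }

  // Mirrors the exit-status branches added to the completed-container loop above.
  def describe(exitStatus: Int, diagnostics: String): String = exitStatus match {
    case -103 => memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN) // vmem exceeded
    case -104 => memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN) // pmem exceeded
    case 0    => "Container exited normally"
    case s    => s"Container marked as failed. Exit status: $s. Diagnostics: $diagnostics"
  }

  def main(args: Array[String]): Unit = {
    val diag = "Current usage: 2.1 MB of 2 GB physical memory used; " +
      "5.8 GB of 4.2 GB virtual memory used. Killing container."
    println(describe(-104, diag))
    // Container killed by YARN for exceeding memory limits. 2.1 MB of 2 GB physical memory
    // used. Consider boosting spark.yarn.executor.memoryOverhead.
  }
}

Running the sketch prints the same "Consider boosting spark.yarn.executor.memoryOverhead" hint that the allocator now logs when YARN kills a container for exceeding its physical memory limit.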