diff --git a/LICENSE b/LICENSE index 7950dd6ceb6db..c21032a1fd274 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/check-cran.sh b/R/check-cran.sh index a188b1448a67b..22cc9c6b601fc 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -20,18 +20,18 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null -. $FWDIR/find-r.sh +. "$FWDIR/find-r.sh" # Install the package (this is required for code in vignettes to run when building it later) # Build the latest docs, but not vignettes, which is built with the package next -. $FWDIR/install-dev.sh +. "$FWDIR/install-dev.sh" # Build source package with vignettes SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" if [ -f "${SPARK_HOME}/RELEASE" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else @@ -40,16 +40,16 @@ fi if [ -d "$SPARK_JARS_DIR" ]; then # Build a zip file containing the source package with vignettes - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD build "$FWDIR/pkg" find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete else - echo "Error Spark JARs not found in $SPARK_HOME" + echo "Error Spark JARs not found in '$SPARK_HOME'" exit 1 fi # Run check as-cran. -VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` +VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` CRAN_CHECK_OPTIONS="--as-cran" @@ -67,10 +67,10 @@ echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] then - "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" else # This will run tests and/or build vignettes, and require SPARK_HOME - SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 6bef7e75e3bd8..310dbc5fb50a3 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -33,15 +33,15 @@ export FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" export SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/..; pwd)" # Required for setting SPARK_SCALA_VERSION -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" echo "Using Scala $SPARK_SCALA_VERSION" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Install the package (this will also generate the Rd files) -. $FWDIR/install-dev.sh +. 
"$FWDIR/install-dev.sh" # Now create HTML files @@ -49,7 +49,7 @@ pushd $FWDIR > /dev/null mkdir -p pkg/html pushd pkg/html -"$R_SCRIPT_PATH/"Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' +"$R_SCRIPT_PATH/Rscript" -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd diff --git a/R/create-rd.sh b/R/create-rd.sh index d17e1617397d1..ff622a41a46c0 100755 --- a/R/create-rd.sh +++ b/R/create-rd.sh @@ -29,9 +29,9 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Generate Rd files if devtools is installed -"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +"$R_SCRIPT_PATH/Rscript" -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/install-dev.sh b/R/install-dev.sh index 45e6411705814..d613552718307 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -29,21 +29,21 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" -mkdir -p $LIB_DIR +mkdir -p "$LIB_DIR" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" -. $FWDIR/create-rd.sh +. "$FWDIR/create-rd.sh" # Install SparkR to $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/R" CMD INSTALL --library="$LIB_DIR" "$FWDIR/pkg/" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -cd $LIB_DIR +cd "$LIB_DIR" jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/install-source-package.sh b/R/install-source-package.sh index c6e443c04e628..8de3569d1d482 100755 --- a/R/install-source-package.sh +++ b/R/install-source-package.sh @@ -29,28 +29,28 @@ set -o pipefail set -e -FWDIR="$(cd `dirname "${BASH_SOURCE[0]}"`; pwd)" -pushd $FWDIR > /dev/null -. $FWDIR/find-r.sh +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" if [ -z "$VERSION" ]; then - VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` + VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` fi -if [ ! -f "$FWDIR"/SparkR_"$VERSION".tar.gz ]; then - echo -e "R source package file $FWDIR/SparkR_$VERSION.tar.gz is not found." +if [ ! -f "$FWDIR/SparkR_$VERSION.tar.gz" ]; then + echo -e "R source package file '$FWDIR/SparkR_$VERSION.tar.gz' is not found." 
echo -e "Please build R source package with check-cran.sh" exit -1; fi echo "Removing lib path and installing from source package" LIB_DIR="$FWDIR/lib" -rm -rf $LIB_DIR -mkdir -p $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR +rm -rf "$LIB_DIR" +mkdir -p "$LIB_DIR" +"$R_SCRIPT_PATH/R" CMD INSTALL "SparkR_$VERSION.tar.gz" --library="$LIB_DIR" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -pushd $LIB_DIR > /dev/null +pushd "$LIB_DIR" > /dev/null jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e6..ae50b28ec6166 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index cc471edc376b3..879c1f80f2c5d 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -35,6 +35,7 @@ Collate: 'WindowSpec.R' 'backend.R' 'broadcast.R' + 'catalog.R' 'client.R' 'context.R' 'deserialize.R' @@ -43,6 +44,7 @@ Collate: 'jvm.R' 'mllib_classification.R' 'mllib_clustering.R' + 'mllib_fpm.R' 'mllib_recommendation.R' 'mllib_regression.R' 'mllib_stat.R' @@ -51,6 +53,7 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'streaming.R' 'types.R' 'utils.R' 'window.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 871f8e41a0f23..ba0fe7708bcc3 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -66,7 +66,10 @@ exportMethods("glm", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", - "spark.svmLinear") + "spark.svmLinear", + "spark.fpGrowth", + "spark.freqItemsets", + "spark.associationRules") # Job group lifecycle management methods export("setJobGroup", @@ -82,6 +85,7 @@ exportMethods("arrange", "as.data.frame", "attach", "cache", + "checkpoint", "coalesce", "collect", "colnames", @@ -97,6 +101,7 @@ exportMethods("arrange", "createOrReplaceTempView", "crossJoin", "crosstab", + "cube", "dapply", "dapplyCollect", "describe", @@ -118,9 +123,11 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", + "isStreaming", "join", "limit", "merge", @@ -138,6 +145,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "rollup", "sample", "sample_frac", "sampleBy", @@ -169,12 +177,14 @@ exportMethods("arrange", "write.json", "write.orc", "write.parquet", + "write.stream", "write.text", "write.ml") exportClasses("Column") -exportMethods("%in%", +exportMethods("%<=>%", + "%in%", "abs", "acos", "add_months", @@ -197,6 +207,8 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "collect_list", + "collect_set", "column", "concat", "concat_ws", @@ -207,6 +219,8 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "create_array", + "create_map", "hash", "cume_dist", "date_add", @@ -222,6 +236,7 @@ exportMethods("%in%", "endsWith", "exp", "explode", + "explode_outer", "expm1", "expr", "factorial", @@ -235,12 +250,15 @@ exportMethods("%in%", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", "hypot", "ifelse", "initcap", + "input_file_name", "instr", "isNaN", "isNotNull", @@ -278,18 
+296,21 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "not", "ntile", "otherwise", "over", "percent_rank", "pmod", "posexplode", + "posexplode_outer", "quarter", "rand", "randn", "rank", "regexp_extract", "regexp_replace", + "repeat_string", "reverse", "rint", "rlike", @@ -313,6 +334,7 @@ exportMethods("%in%", "sort_array", "soundex", "spark_partition_id", + "split_string", "stddev", "stddev_pop", "stddev_samp", @@ -355,9 +377,15 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "createTable", + "currentDatabase", "dropTempTable", "dropTempView", "jsonFile", + "listColumns", + "listDatabases", + "listFunctions", + "listTables", "loadDF", "parquetFile", "read.df", @@ -365,7 +393,13 @@ export("as.DataFrame", "read.json", "read.orc", "read.parquet", + "read.stream", "read.text", + "recoverPartitions", + "refreshByPath", + "refreshTable", + "setCheckpointDir", + "setCurrentDatabase", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", @@ -402,6 +436,16 @@ export("partitionBy", export("windowPartitionBy", "windowOrderBy") +exportClasses("StreamingQuery") + +export("awaitTermination", + "isActive", + "lastProgress", + "queryName", + "status", + "stopQuery") + + S3method(print, jobj) S3method(print, structField) S3method(print, structType) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e33d0d8e29d49..1c8869202f677 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -133,9 +133,6 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x a SparkDataFrame. -#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. -#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases explain,SparkDataFrame-method #' @rdname explain @@ -197,6 +194,7 @@ setMethod("isLocal", #' 20 characters will be truncated. However, if set greater than zero, #' truncates strings longer than \code{truncate} characters and all cells #' will be aligned right. +#' @param vertical whether print output rows vertically (one line per column value). #' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases showDF,SparkDataFrame-method @@ -213,12 +211,13 @@ setMethod("isLocal", #' @note showDF since 1.4.0 setMethod("showDF", signature(x = "SparkDataFrame"), - function(x, numRows = 20, truncate = TRUE) { + function(x, numRows = 20, truncate = TRUE, vertical = FALSE) { if (is.logical(truncate) && truncate) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20), vertical) } else { truncate2 <- as.numeric(truncate) - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2), + vertical) } cat(s) }) @@ -560,7 +559,7 @@ setMethod("insertInto", jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - callJMethod(write, "insertInto", tableName) + invisible(callJMethod(write, "insertInto", tableName)) }) #' Cache @@ -1324,7 +1323,7 @@ setMethod("toRDD", #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' #' @param x a SparkDataFrame. -#' @param ... variable(s) (character names(s) or Column(s)) to group on. +#' @param ... 
character name(s) or Column(s) to group on. #' @return A GroupedData. #' @family SparkDataFrame functions #' @aliases groupBy,SparkDataFrame-method @@ -1340,6 +1339,7 @@ setMethod("toRDD", #' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") #' } #' @note groupBy since 1.4.0 +#' @seealso \link{agg}, \link{cube}, \link{rollup} setMethod("groupBy", signature(x = "SparkDataFrame"), function(x, ...) { @@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. +#' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2685,7 +2686,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' -#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. +#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method +#' requires that the input SparkDataFrames have the same column names. #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2709,6 +2711,10 @@ setMethod("unionAll", setMethod("rbind", signature(... = "SparkDataFrame"), function(x, ..., deparse.level = 1) { + nm <- lapply(list(x, ...), names) + if (length(unique(nm)) != 1) { + stop("Names of input data frames are different.") + } if (nargs() == 3) { union(x, ...) } else { @@ -2815,14 +2821,14 @@ setMethod("write.df", signature(df = "SparkDataFrame"), function(df, path = NULL, source = NULL, mode = "error", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", "in 'spark.sql.sources.default' configuration by default.") } if (!is.character(mode)) { - stop("mode should be charactor or omitted. It is 'error' by default.") + stop("mode should be character or omitted. It is 'error' by default.") } if (is.null(source)) { source <- getDefaultSqlSource() @@ -2891,7 +2897,7 @@ setMethod("saveAsTable", write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - callJMethod(write, "saveAsTable", tableName) + invisible(callJMethod(write, "saveAsTable", tableName)) }) #' summary @@ -3037,7 +3043,7 @@ setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { - stop("value should be an integer, numeric, charactor or named list.") + stop("value should be an integer, numeric, character or named list.") } if (class(value) == "list") { @@ -3049,7 +3055,7 @@ setMethod("fillna", # Check each item in the named list is of valid type lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { - stop("Each item in value should be an integer, numeric or charactor.") + stop("Each item in value should be an integer, numeric or character.") } }) @@ -3509,3 +3515,233 @@ setMethod("getNumPartitions", function(x) { callJMethod(callJMethod(x@sdf, "rdd"), "getNumPartitions") }) + +#' isStreaming +#' +#' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data +#' as it arrives. 
+#' +#' @param x A SparkDataFrame +#' @return TRUE if this SparkDataFrame is from a streaming source +#' @family SparkDataFrame functions +#' @aliases isStreaming,SparkDataFrame-method +#' @rdname isStreaming +#' @name isStreaming +#' @seealso \link{read.stream} \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' } +#' @note isStreaming since 2.2.0 +#' @note experimental +setMethod("isStreaming", + signature(x = "SparkDataFrame"), + function(x) { + callJMethod(x@sdf, "isStreaming") + }) + +#' Write the streaming SparkDataFrame to a data source. +#' +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, \code{outputMode} specifies how data of a streaming SparkDataFrame is written to a +#' output data source. There are three modes: +#' \itemize{ +#' \item append: Only the new rows in the streaming SparkDataFrame will be written out. This +#' output mode can be only be used in queries that do not contain any aggregation. +#' \item complete: All the rows in the streaming SparkDataFrame will be written out every time +#' there are some updates. This output mode can only be used in queries that +#' contain aggregations. +#' \item update: Only the rows that were updated in the streaming SparkDataFrame will be written +#' out every time there are some updates. If the query doesn't contain aggregations, +#' it will be equivalent to \code{append} mode. +#' } +#' +#' @param df a streaming SparkDataFrame. +#' @param source a name for external data source. +#' @param outputMode one of 'append', 'complete', 'update'. +#' @param ... additional argument(s) passed to the method. +#' +#' @family SparkDataFrame functions +#' @seealso \link{read.stream} +#' @aliases write.stream,SparkDataFrame-method +#' @rdname write.stream +#' @name write.stream +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' wordCounts <- count(group_by(df, "value")) +#' +#' # console +#' q <- write.stream(wordCounts, "console", outputMode = "complete") +#' # text stream +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' # memory stream +#' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") +#' head(sql("SELECT * from outs")) +#' queryName(q) +#' +#' stopQuery(q) +#' } +#' @note write.stream since 2.2.0 +#' @note experimental +setMethod("write.stream", + signature(df = "SparkDataFrame"), + function(df, source = NULL, outputMode = NULL, ...) { + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.null(outputMode) && !is.character(outputMode)) { + stop("outputMode should be character or omitted.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) 
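+            # Note: the remaining named arguments captured in ... (e.g. path,
+            # checkpointLocation or queryName, as in the examples above) have been
+            # collected into a string environment and are forwarded to the streaming
+            # writer through options() below.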
+ write <- handledCallJMethod(df@sdf, "writeStream") + write <- callJMethod(write, "format", source) + if (!is.null(outputMode)) { + write <- callJMethod(write, "outputMode", outputMode) + } + write <- callJMethod(write, "options", options) + ssq <- handledCallJMethod(write, "start") + streamingQuery(ssq) + }) + +#' checkpoint +#' +#' Returns a checkpointed version of this SparkDataFrame. Checkpointing can be used to truncate the +#' logical plan, which is especially useful in iterative algorithms where the plan may grow +#' exponentially. It will be saved to files inside the checkpoint directory set with +#' \code{setCheckpointDir} +#' +#' @param x A SparkDataFrame +#' @param eager whether to checkpoint this SparkDataFrame immediately +#' @return a new checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases checkpoint,SparkDataFrame-method +#' @rdname checkpoint +#' @name checkpoint +#' @seealso \link{setCheckpointDir} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#' df <- checkpoint(df) +#' } +#' @note checkpoint since 2.2.0 +setMethod("checkpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) + dataFrame(df) + }) + +#' cube +#' +#' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases cube,SparkDataFrame-method +#' @rdname cube +#' @name cube +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(cube(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(cube(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note cube since 2.3.0 +#' @seealso \link{agg}, \link{groupBy}, \link{rollup} +setMethod("cube", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "cube", jcol) + groupedData(sgd) + }) + +#' rollup +#' +#' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases rollup,SparkDataFrame-method +#' @rdname rollup +#' @name rollup +#' @export +#' @examples +#'\dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(rollup(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(rollup(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note rollup since 2.3.0 +#' @seealso \link{agg}, \link{cube}, \link{groupBy} +setMethod("rollup", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "rollup", jcol) + groupedData(sgd) + }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... 
optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 5667b9d788821..7ad3993e9ecbc 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -291,7 +291,7 @@ setMethod("unpersistRDD", #' @rdname checkpoint-methods #' @aliases checkpoint,RDD-method #' @noRd -setMethod("checkpoint", +setMethod("checkpointRDD", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 8354f705f6dea..f5c3a749fe0a1 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -544,12 +544,15 @@ sql <- function(x, ...) { dispatchFunc("sql(sqlQuery)", x, ...) } -#' Create a SparkDataFrame from a SparkSQL Table +#' Create a SparkDataFrame from a SparkSQL table or view #' -#' Returns the specified Table as a SparkDataFrame. The Table must have already been registered -#' in the SparkSession. +#' Returns the specified table or view as a SparkDataFrame. The table or view must already exist or +#' have already been registered in the SparkSession. #' -#' @param tableName The SparkSQL Table to convert to a SparkDataFrame. +#' @param tableName the qualified or unqualified name that designates a table or view. If a database +#' is specified, it identifies the table/view from the database. +#' Otherwise, it first attempts to find a temporary view with the given name +#' and then match the table/view from the current database. #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF @@ -569,200 +572,6 @@ tableToDF <- function(tableName) { dataFrame(sdf) } -#' Tables -#' -#' Returns a SparkDataFrame containing names of tables in the given database. -#' -#' @param databaseName name of the database -#' @return a SparkDataFrame -#' @rdname tables -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tables("hive") -#' } -#' @name tables -#' @method tables default -#' @note tables since 1.4.0 -tables.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) - dataFrame(jdf) -} - -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} - -#' Table Names -#' -#' Returns the names of tables in the given database as an array. -#' -#' @param databaseName name of the database -#' @return a list of table names -#' @rdname tableNames -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tableNames("hive") -#' } -#' @name tableNames -#' @method tableNames default -#' @note tableNames since 1.4.0 -tableNames.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getTableNames", - sparkSession, - databaseName) -} - -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) 
-} - -#' Cache Table -#' -#' Caches the specified table in-memory. -#' -#' @param tableName The name of the table being cached -#' @return SparkDataFrame -#' @rdname cacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' cacheTable("table") -#' } -#' @name cacheTable -#' @method cacheTable default -#' @note cacheTable since 1.4.0 -cacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "cacheTable", tableName)) -} - -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) -} - -#' Uncache Table -#' -#' Removes the specified table from the in-memory cache. -#' -#' @param tableName The name of the table being uncached -#' @return SparkDataFrame -#' @rdname uncacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' uncacheTable("table") -#' } -#' @name uncacheTable -#' @method uncacheTable default -#' @note uncacheTable since 1.4.0 -uncacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "uncacheTable", tableName)) -} - -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} - -#' Clear Cache -#' -#' Removes all cached tables from the in-memory cache. -#' -#' @rdname clearCache -#' @export -#' @examples -#' \dontrun{ -#' clearCache() -#' } -#' @name clearCache -#' @method clearCache default -#' @note clearCache since 1.4.0 -clearCache.default <- function() { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - invisible(callJMethod(catalog, "clearCache")) -} - -clearCache <- function() { - dispatchFunc("clearCache()") -} - -#' (Deprecated) Drop Temporary Table -#' -#' Drops the temporary table with the given table name in the catalog. -#' If the table has been cached/persisted before, it's also unpersisted. -#' -#' @param tableName The name of the SparkSQL table to be dropped. -#' @seealso \link{dropTempView} -#' @rdname dropTempTable-deprecated -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempTable("table") -#' } -#' @name dropTempTable -#' @method dropTempTable default -#' @note dropTempTable since 1.4.0 -dropTempTable.default <- function(tableName) { - if (class(tableName) != "character") { - stop("tableName must be a string.") - } - dropTempView(tableName) -} - -dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") - dispatchFunc("dropTempView(viewName)", x, ...) -} - -#' Drops the temporary view with the given view name in the catalog. -#' -#' Drops the temporary view with the given view name in the catalog. -#' If the view has been cached before, then it will also be uncached. -#' -#' @param viewName the name of the view to be dropped. -#' @return TRUE if the view is dropped successfully, FALSE otherwise. 
-#' @rdname dropTempView -#' @name dropTempView -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempView("table") -#' } -#' @note since 2.0.0 - -dropTempView <- function(viewName) { - sparkSession <- getSparkSession() - if (class(viewName) != "character") { - stop("viewName must be a string.") - } - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "dropTempView", viewName) -} - #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame @@ -797,7 +606,7 @@ dropTempView <- function(viewName) { #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", @@ -841,45 +650,6 @@ loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -#' Create an external table -#' -#' Creates an external table based on the dataset in a data source, -#' Returns a SparkDataFrame associated with the external table. -#' -#' The data source is specified by the \code{source} and a set of options(...). -#' If \code{source} is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. -#' -#' @param tableName a name of the table. -#' @param path the path of files to load. -#' @param source the name of external data source. -#' @param ... additional argument(s) passed to the method. -#' @return A SparkDataFrame. -#' @rdname createExternalTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' df <- createExternalTable("myjson", path="path/to/json", source="json") -#' } -#' @name createExternalTable -#' @method createExternalTable default -#' @note createExternalTable since 1.4.0 -createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { - sparkSession <- getSparkSession() - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - catalog <- callJMethod(sparkSession, "catalog") - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) - dataFrame(sdf) -} - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) @@ -937,3 +707,53 @@ read.jdbc <- function(url, tableName, } dataFrame(sdf) } + +#' Load a streaming SparkDataFrame +#' +#' Returns the dataset in a data source as a SparkDataFrame +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param source The name of external data source +#' @param schema The data schema defined in structType, this is required for file-based streaming +#' data source +#' @param ... 
additional external data source specific named options, for instance \code{path} for +#' file-based streaming data source +#' @return SparkDataFrame +#' @rdname read.stream +#' @name read.stream +#' @seealso \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' } +#' @name read.stream +#' @note read.stream since 2.2.0 +#' @note experimental +read.stream <- function(source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + read <- callJMethod(sparkSession, "readStream") + read <- callJMethod(read, "format", source) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + read <- callJMethod(read, "schema", schema$jobj) + } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") + dataFrame(callJMethod(sdf, "toDF")) +} diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R new file mode 100644 index 0000000000000..e59a7024333ac --- /dev/null +++ b/R/pkg/R/catalog.R @@ -0,0 +1,526 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# catalog.R: SparkSession catalog functions + +#' (Deprecated) Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns a SparkDataFrame associated with the external table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. +#' @param schema the schema of the data required for some data sources. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. +#' @rdname createExternalTable-deprecated +#' @seealso \link{createTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createExternalTable("myjson", path="path/to/json", source="json", schema) +#' } +#' @name createExternalTable +#' @method createExternalTable default +#' @note createExternalTable since 1.4.0 +createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) 
{ + .Deprecated("createTable", old = "createExternalTable") + createTable(tableName, path, source, schema, ...) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Creates a table based on the dataset in a data source +#' +#' Creates a table based on the dataset in a data source. Returns a SparkDataFrame associated with +#' the table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. When a \code{path} is specified, an external table is +#' created from the data at the given path. Otherwise a managed table is created. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @param path (optional) the path of files to load. +#' @param source (optional) the name of the data source. +#' @param schema (optional) the schema of the data required for some data sources. +#' @param ... additional named parameters as options for the data source. +#' @return A SparkDataFrame. +#' @rdname createTable +#' @seealso \link{createExternalTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createTable("myjson", path="path/to/json", source="json", schema) +#' +#' createTable("people", source = "json", schema = schema) +#' insertInto(df, "people") +#' } +#' @name createTable +#' @note createTable since 2.2.0 +createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + catalog <- callJMethod(sparkSession, "catalog") + if (is.null(schema)) { + sdf <- callJMethod(catalog, "createTable", tableName, source, options) + } else if (class(schema) == "structType") { + sdf <- callJMethod(catalog, "createTable", tableName, source, schema$jobj, options) + } else { + stop("schema must be a structType.") + } + dataFrame(sdf) +} + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @return SparkDataFrame +#' @rdname cacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' cacheTable("table") +#' } +#' @name cacheTable +#' @method cacheTable default +#' @note cacheTable since 1.4.0 +cacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "cacheTable", tableName)) +} + +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. 
+#' @return SparkDataFrame +#' @rdname uncacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' uncacheTable("table") +#' } +#' @name uncacheTable +#' @method uncacheTable default +#' @note uncacheTable since 1.4.0 +uncacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "uncacheTable", tableName)) +} + +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @rdname clearCache +#' @export +#' @examples +#' \dontrun{ +#' clearCache() +#' } +#' @name clearCache +#' @method clearCache default +#' @note clearCache since 1.4.0 +clearCache.default <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(callJMethod(catalog, "clearCache")) +} + +clearCache <- function() { + dispatchFunc("clearCache()") +} + +#' (Deprecated) Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param tableName The name of the SparkSQL table to be dropped. +#' @seealso \link{dropTempView} +#' @rdname dropTempTable-deprecated +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempTable("table") +#' } +#' @name dropTempTable +#' @method dropTempTable default +#' @note dropTempTable since 1.4.0 +dropTempTable.default <- function(tableName) { + .Deprecated("dropTempView", old = "dropTempTable") + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + dropTempView(tableName) +} + +dropTempTable <- function(x, ...) { + dispatchFunc("dropTempView(viewName)", x, ...) +} + +#' Drops the temporary view with the given view name in the catalog. +#' +#' Drops the temporary view with the given view name in the catalog. +#' If the view has been cached before, then it will also be uncached. +#' +#' @param viewName the name of the temporary view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. +#' @rdname dropTempView +#' @name dropTempView +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempView("table") +#' } +#' @note since 2.0.0 +dropTempView <- function(viewName) { + sparkSession <- getSparkSession() + if (class(viewName) != "character") { + stop("viewName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", viewName) +} + +#' Tables +#' +#' Returns a SparkDataFrame containing names of tables in the given database. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame +#' @rdname tables +#' @seealso \link{listTables} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tables("hive") +#' } +#' @name tables +#' @method tables default +#' @note tables since 1.4.0 +tables.default <- function(databaseName = NULL) { + # rename column to match previous output schema + withColumnRenamed(listTables(databaseName), "name", "tableName") +} + +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) 
+} + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param databaseName (optional) name of the database +#' @return a list of table names +#' @rdname tableNames +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tableNames("hive") +#' } +#' @name tableNames +#' @method tableNames default +#' @note tableNames since 1.4.0 +tableNames.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) +} + +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} + +#' Returns the current default database +#' +#' Returns the current default database. +#' +#' @return name of the current default database. +#' @rdname currentDatabase +#' @name currentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' currentDatabase() +#' } +#' @note since 2.2.0 +currentDatabase <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "currentDatabase") +} + +#' Sets the current default database +#' +#' Sets the current default database. +#' +#' @param databaseName name of the database +#' @rdname setCurrentDatabase +#' @name setCurrentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' setCurrentDatabase("default") +#' } +#' @note since 2.2.0 +setCurrentDatabase <- function(databaseName) { + sparkSession <- getSparkSession() + if (class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "setCurrentDatabase", databaseName)) +} + +#' Returns a list of databases available +#' +#' Returns a list of databases available. +#' +#' @return a SparkDataFrame of the list of databases. +#' @rdname listDatabases +#' @name listDatabases +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listDatabases() +#' } +#' @note since 2.2.0 +listDatabases <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) +} + +#' Returns a list of tables or views in the specified database +#' +#' Returns a list of tables or views in the specified database. +#' This includes all temporary views. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of tables. +#' @rdname listTables +#' @name listTables +#' @seealso \link{tables} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listTables() +#' listTables("default") +#' } +#' @note since 2.2.0 +listTables <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listTables") + } else { + handledCallJMethod(catalog, "listTables", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of columns for the given table/view in the specified database +#' +#' Returns a list of columns for the given table/view in the specified database. +#' +#' @param tableName the qualified or unqualified name that designates a table/view. If no database +#' identifier is provided, it refers to a table/view in the current database. 
+#' If \code{databaseName} parameter is specified, this must be an unqualified name. +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of column descriptions. +#' @rdname listColumns +#' @name listColumns +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listColumns("mytable") +#' } +#' @note since 2.2.0 +listColumns <- function(tableName, databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + handledCallJMethod(catalog, "listColumns", tableName) + } else { + handledCallJMethod(catalog, "listColumns", databaseName, tableName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of functions registered in the specified database +#' +#' Returns a list of functions registered in the specified database. +#' This includes all temporary functions. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of function descriptions. +#' @rdname listFunctions +#' @name listFunctions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listFunctions() +#' } +#' @note since 2.2.0 +listFunctions <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listFunctions") + } else { + handledCallJMethod(catalog, "listFunctions", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Recovers all the partitions in the directory of a table and update the catalog +#' +#' Recovers all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a view. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @rdname recoverPartitions +#' @name recoverPartitions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' recoverPartitions("myTable") +#' } +#' @note since 2.2.0 +recoverPartitions <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) +} + +#' Invalidates and refreshes all the cached data and metadata of the given table +#' +#' Invalidates and refreshes all the cached data and metadata of the given table. For performance +#' reasons, Spark SQL or the external data source library it uses might cache certain metadata about +#' a table, such as the location of blocks. When those change outside of Spark SQL, users should +#' call this function to invalidate the cache. +#' +#' If this table is cached as an InMemoryRelation, drop the original cached version and make the +#' new version cached lazily. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. 
+#' @rdname refreshTable +#' @name refreshTable +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshTable("myTable") +#' } +#' @note since 2.2.0 +refreshTable <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshTable", tableName)) +} + +#' Invalidates and refreshes all the cached data and metadata for SparkDataFrame containing path +#' +#' Invalidates and refreshes all the cached data (and the associated metadata) for any +#' SparkDataFrame that contains the given data source path. Path matching is by prefix, i.e. "/" +#' would invalidate everything that is cached. +#' +#' @param path the path of the data source. +#' @rdname refreshByPath +#' @name refreshByPath +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshByPath("/path") +#' } +#' @note since 2.2.0 +refreshByPath <- function(path) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshByPath", path)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 539d91b0f8797..147ee4b6887b9 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -67,8 +67,7 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or", #, "!" = "unary_$bang" - "^" = "pow" + "&" = "and", "|" = "or", "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") @@ -302,3 +301,55 @@ setMethod("otherwise", jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) + +#' \%<=>\% +#' +#' Equality test that is safe for null values. +#' +#' Can be used, unlike standard equality operator, to perform null-safe joins. +#' Equivalent to Scala \code{Column.<=>} and \code{Column.eqNullSafe}. +#' +#' @param x a Column +#' @param value a value to compare +#' @rdname eq_null_safe +#' @name %<=>% +#' @aliases %<=>%,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df1 <- createDataFrame(data.frame( +#' x = c(1, NA, 3, NA), y = c(2, 6, 3, NA) +#' )) +#' +#' head(select(df1, df1$x == df1$y, df1$x %<=>% df1$y)) +#' +#' df2 <- createDataFrame(data.frame(y = c(3, NA))) +#' count(join(df1, df2, df1$y == df2$y)) +#' +#' count(join(df1, df2, df1$y %<=>% df2$y)) +#' } +#' @note \%<=>\% since 2.3.0 +setMethod("%<=>%", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- if (class(value) == "Column") { value@jc } else { value } + jc <- callJMethod(x@jc, "eqNullSafe", value) + column(jc) + }) + +#' ! +#' +#' Inversion of boolean expression. +#' +#' @rdname not +#' @name not +#' @aliases !,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) +#' +#' head(select(df, !column("x") > 0)) +#' } +#' @note ! 
since 2.3.0 +setMethod("!", signature(x = "Column"), function(x) not(x)) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 1a0dd65f450b9..50856e3d9856c 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -291,7 +291,7 @@ broadcast <- function(sc, object) { #' rdd <- parallelize(sc, 1:2, 2L) #' checkpoint(rdd) #'} -setCheckpointDir <- function(sc, dirName) { +setCheckpointDirSC <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } @@ -330,7 +330,13 @@ spark.addFile <- function(path, recursive = FALSE) { #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 spark.getSparkFilesRootDirectory <- function() { - callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + } else { + # Running on worker. + Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR") + } } #' Get the absolute path of a file added through spark.addFile. @@ -345,7 +351,13 @@ spark.getSparkFilesRootDirectory <- function() { #'} #' @note spark.getSparkFiles since 2.1.0 spark.getSparkFiles <- function(fileName) { - callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + } else { + # Running on worker. + file.path(spark.getSparkFilesRootDirectory(), as.character(fileName)) + } } #' Run a function over a list of elements, distributing the computations with Spark @@ -410,3 +422,22 @@ setLogLevel <- function(level) { sc <- getSparkContext() invisible(callJMethod(sc, "setLogLevel", level)) } + +#' Set checkpoint directory +#' +#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be +#' a HDFS path if running on a cluster. +#' +#' @rdname setCheckpointDir +#' @param directory Directory path to checkpoint to +#' @seealso \link{checkpoint} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#'} +#' @note setCheckpointDir since 2.2.0 +setCheckpointDir <- function(directory) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) +} diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index edf2bcf8fdb3c..5f9d11475c94b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1795,10 +1795,10 @@ setMethod("to_date", #' to_json #' -#' Converts a column containing a \code{structType} into a Column of JSON string. -#' Resolving the Column can fail if an unsupported type is encountered. +#' Converts a column containing a \code{structType} or array of \code{structType} into a Column +#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. #' -#' @param x Column containing the struct +#' @param x Column containing the struct or array of the structs #' @param ... additional named properties to control how it is converted, accepts the same options #' as the JSON data source. 
#' @@ -1809,8 +1809,13 @@ setMethod("to_date", #' @export #' @examples #' \dontrun{ -#' to_json(df$t, dateFormat = 'dd/MM/yyyy') -#' select(df, to_json(df$t)) +#' # Converts a struct into a JSON object +#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' +#' # Converts an array of structs into a JSON array +#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' select(df, to_json(df$people)) #'} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), @@ -2433,10 +2438,12 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_json #' #' Parses a column containing a JSON string into a Column of \code{structType} with the specified -#' \code{schema}. If the string is unparseable, the Column will contains the value NA. +#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. +#' If the string is unparseable, the Column will contains the value NA. #' #' @param x Column containing the JSON string. #' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @param ... additional named properties to control how the json is parsed, accepts the same #' options as the JSON data source. #' @@ -2452,11 +2459,18 @@ setMethod("date_format", signature(y = "Column", x = "character"), #'} #' @note from_json since 2.2.0 setMethod("from_json", signature(x = "Column", schema = "structType"), - function(x, schema, ...) { + function(x, schema, as.json.array = FALSE, ...) { + if (as.json.array) { + jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", + "createArrayType", + schema$jobj) + } else { + jschema <- schema$jobj + } options <- varargsToStrEnv(...) jc <- callJStatic("org.apache.spark.sql.functions", "from_json", - x@jc, schema$jobj, options) + x@jc, jschema, options) column(jc) }) @@ -2618,8 +2632,8 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, -#' and returns the result as a string column. +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places +#' with HALF_EVEN round mode, and returns the result as a string column. #' #' If x is 0, the result has no decimal point or fractional part. #' If x < 0, the result will be null. @@ -3534,7 +3548,7 @@ setMethod("row_number", #' array_contains #' -#' Returns true if the array contain the value. +#' Returns null if the array is null, true if the array contains the value, and false otherwise. #' #' @param x A Column #' @param value A value to be checked if contained in the column @@ -3638,3 +3652,347 @@ setMethod("posexplode", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc) column(jc) }) + +#' create_array +#' +#' Creates a new array column. The input columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_array +#' @name create_array +#' @aliases create_array,Column-method +#' @export +#' @examples \dontrun{create_array(df$x, df$y, df$z)} +#' @note create_array since 2.3.0 +setMethod("create_array", + signature(x = "Column"), + function(x, ...) 
{ + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "array", jcols) + column(jc) + }) + +#' create_map +#' +#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' e.g. (key1, value1, key2, value2, ...). +#' The key columns must all have the same data type, and can't be null. +#' The value columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family normal_funcs +#' @rdname create_map +#' @name create_map +#' @aliases create_map,Column-method +#' @export +#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} +#' @note create_map since 2.3.0 +setMethod("create_map", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) + column(jc) + }) + +#' collect_list +#' +#' Creates a list of objects with duplicates. +#' +#' @param x Column to compute on +#' +#' @rdname collect_list +#' @name collect_list +#' @family agg_funcs +#' @aliases collect_list,Column-method +#' @export +#' @examples \dontrun{collect_list(df$x)} +#' @note collect_list since 2.3.0 +setMethod("collect_list", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc) + column(jc) + }) + +#' collect_set +#' +#' Creates a list of objects with duplicate elements eliminated. +#' +#' @param x Column to compute on +#' +#' @rdname collect_set +#' @name collect_set +#' @family agg_funcs +#' @aliases collect_set,Column-method +#' @export +#' @examples \dontrun{collect_set(df$x)} +#' @note collect_set since 2.3.0 +setMethod("collect_set", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) + column(jc) + }) + +#' split_string +#' +#' Splits string on regular expression. +#' +#' Equivalent to \code{split} SQL function +#' +#' @param x Column to compute on +#' @param pattern Java regular expression +#' +#' @rdname split_string +#' @family string_funcs +#' @aliases split_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, split_string(df$value, "\\s+"))) +#' +#' # This is equivalent to the following SQL expression +#' head(selectExpr(df, "split(value, '\\\\s+')")) +#' } +#' @note split_string 2.3.0 +setMethod("split_string", + signature(x = "Column", pattern = "character"), + function(x, pattern) { + jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + column(jc) + }) + +#' repeat_string +#' +#' Repeats string n times. 
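Stepping back to the to_json/from_json changes earlier in this file, the new array handling round-trips as follows. The SQL literal and schema below are illustrative assumptions, not taken from the patch, and a running SparkR session is assumed.

  library(SparkR)
  sparkR.session()

  # to_json now also accepts an array of structs and produces a JSON array string
  df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people")
  json_df <- select(df, alias(to_json(df$people), "json"))

  # from_json parses it back when as.json.array = TRUE
  schema <- structType(structField("name", "string"))
  head(select(json_df, from_json(json_df$json, schema, as.json.array = TRUE)))
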
+#' +#' Equivalent to \code{repeat} SQL function +#' +#' @param x Column to compute on +#' @param n Number of repetitions +#' +#' @rdname repeat_string +#' @family string_funcs +#' @aliases repeat_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' first(select(df, repeat_string(df$value, 3))) +#' +#' # This is equivalent to the following SQL expression +#' first(selectExpr(df, "repeat(value, 3)")) +#' } +#' @note repeat_string since 2.3.0 +setMethod("repeat_string", + signature(x = "Column", n = "numeric"), + function(x, n) { + jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) + column(jc) + }) + +#' explode_outer +#' +#' Creates a new row for each element in the given array or map column. +#' Unlike \code{explode}, if the array/map is \code{null} or empty +#' then \code{null} is produced. +#' +#' @param x Column to compute on +#' +#' @rdname explode_outer +#' @name explode_outer +#' @family collection_funcs +#' @aliases explode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) +#' } +#' @note explode_outer since 2.3.0 +setMethod("explode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode_outer", x@jc) + column(jc) + }) + +#' posexplode_outer +#' +#' Creates a new row for each element with position in the given array or map column. +#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' then the row (\code{null}, \code{null}) is produced. +#' +#' @param x Column to compute on +#' +#' @rdname posexplode_outer +#' @name posexplode_outer +#' @family collection_funcs +#' @aliases posexplode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) +#' } +#' @note posexplode_outer since 2.3.0 +setMethod("posexplode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) + column(jc) + }) + +#' not +#' +#' Inversion of boolean expression. +#' +#' \code{not} and \code{!} cannot be applied directly to numerical column. +#' To achieve R-like truthiness column has to be casted to \code{BooleanType}. +#' +#' @param x Column to compute on +#' @rdname not +#' @name not +#' @aliases not,Column-method +#' @family normal_funcs +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' is_true = c(TRUE, FALSE, NA), +#' flag = c(1, 0, 1) +#' )) +#' +#' head(select(df, not(df$is_true))) +#' +#' # Explicit cast is required when working with numeric column +#' head(select(df, not(cast(df$flag, "boolean")))) +#' } +#' @note not since 2.3.0 +setMethod("not", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) + column(jc) + }) + +#' grouping_bit +#' +#' Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' +#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. 
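The value of explode_outer and posexplode_outer over explode shows up on rows whose array is null or empty, so a small sketch may help; the data frame is made up, mirroring the roxygen examples above, and a running SparkR session is assumed.

  library(SparkR)
  sparkR.session()

  df <- createDataFrame(data.frame(id = c(1, 2, 3), text = c("a,b,c", NA, "d,e")))
  parts <- split_string(df$text, ",")

  # explode() drops the row whose array is NULL; explode_outer() keeps it with a NULL element
  nrow(select(df, df$id, explode(parts)))        # 5 rows
  nrow(select(df, df$id, explode_outer(parts)))  # 6 rows

  # posexplode_outer() also emits the element position, producing (NULL, NULL) for that row
  head(select(df, df$id, posexplode_outer(parts)))
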
+#' +#' @param x Column to compute on +#' +#' @rdname grouping_bit +#' @name grouping_bit +#' @family agg_funcs +#' @aliases grouping_bit,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' } +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' grouping_id +#' +#' Returns the level of grouping. +#' +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' } +#' +#' @param x Column to compute on +#' @param ... additional Column(s) (optional). +#' +#' @rdname grouping_id +#' @name grouping_id +#' @family agg_funcs +#' @aliases grouping_id,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' } +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) + +#' input_file_name +#' +#' Creates a string column with the input file name for a given row +#' +#' @rdname input_file_name +#' @name input_file_name +#' @family normal_funcs +#' @aliases input_file_name,missing-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, input_file_name())) +#' } +#' @note input_file_name since 2.3.0 +setMethod("input_file_name", signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 45bc12746511c..e835ef3e4f40d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -32,7 +32,7 @@ setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coa # @rdname checkpoint-methods # @export -setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) +setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) @@ -406,6 +406,10 @@ setGeneric("attach") #' @export setGeneric("cache", function(x) { standardGeneric("cache") }) +#' @rdname checkpoint +#' @export +setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) + #' @rdname coalesce #' @param x a Column or a SparkDataFrame. #' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally @@ -479,6 +483,10 @@ setGeneric("createOrReplaceTempView", # @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) +#' @rdname cube +#' @export +setGeneric("cube", function(x, ...) 
{ standardGeneric("cube") }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) @@ -539,6 +547,9 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain #' @export +#' @param x a SparkDataFrame or a StreamingQuery. +#' @param extended Logical. If extended is FALSE, prints only the physical plan. +#' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except @@ -565,6 +576,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) @@ -577,6 +592,10 @@ setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) +#' @rdname isStreaming +#' @export +setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) + #' @rdname limit #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) @@ -620,6 +639,10 @@ setGeneric("sample", standardGeneric("sample") }) +#' @rdname rollup +#' @export +setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) + #' @rdname sample #' @export setGeneric("sample_frac", @@ -682,6 +705,12 @@ setGeneric("write.parquet", function(x, path, ...) { #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.stream +#' @export +setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { + standardGeneric("write.stream") +}) + #' @rdname write.text #' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) @@ -831,6 +860,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @export setGeneric("over", function(x, window) { standardGeneric("over") }) +#' @rdname eq_null_safe +#' @export +setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) + ###################### WindowSpec Methods ########################## #' @rdname partitionBy @@ -901,6 +934,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname collect_list +#' @export +setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) + +#' @rdname collect_set +#' @export +setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) + #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -925,6 +966,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname create_array +#' @export +setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) + +#' @rdname create_map +#' @export +setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) + #' @rdname hash #' @export setGeneric("hash", function(x, ...) 
{ standardGeneric("hash") }) @@ -975,6 +1024,10 @@ setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) +#' @rdname explode_outer +#' @export +setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) + #' @rdname expr #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) @@ -1003,6 +1056,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) +#' @rdname grouping_bit +#' @export +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname grouping_id +#' @export +setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) + #' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) @@ -1019,6 +1080,12 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @export setGeneric("initcap", function(x) { standardGeneric("initcap") }) +#' @param x empty. Should be used with no argument. +#' @rdname input_file_name +#' @export +setGeneric("input_file_name", + function(x = "missing") { standardGeneric("input_file_name") }) + #' @rdname instr #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) @@ -1109,6 +1176,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) +#' @rdname not +#' @export +setGeneric("not", function(x) { standardGeneric("not") }) + #' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) @@ -1134,6 +1205,10 @@ setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @export setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) +#' @rdname posexplode_outer +#' @export +setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) + #' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) @@ -1159,6 +1234,10 @@ setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) +#' @rdname repeat_string +#' @export +setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) + #' @rdname reverse #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1224,6 +1303,10 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) +#' @rdname split_string +#' @export +setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1333,6 +1416,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) + ###################### Spark.ML Methods ########################## #' @rdname fitted @@ -1428,6 +1512,17 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.fpGrowth", function(data, ...) 
{ standardGeneric("spark.fpGrowth") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1435,3 +1530,30 @@ setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.p #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) + + +###################### Streaming Methods ########################## + +#' @rdname awaitTermination +#' @export +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) + +#' @rdname isActive +#' @export +setGeneric("isActive", function(x) { standardGeneric("isActive") }) + +#' @rdname lastProgress +#' @export +setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) + +#' @rdname queryName +#' @export +setGeneric("queryName", function(x) { standardGeneric("queryName") }) + +#' @rdname status +#' @export +setGeneric("status", function(x) { standardGeneric("status") }) + +#' @rdname stopQuery +#' @export +setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 0ebdb5a273088..97c9fa1b45840 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -498,11 +498,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @export #' @examples #' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") +#' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") #' model <- spark.lda(data = text, optimizer = "em") #' #' # get a summary of the model diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R new file mode 100644 index 0000000000000..dfcb45a1b66c9 --- /dev/null +++ b/R/pkg/R/mllib_fpm.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration + +#' S4 class that represents a FPGrowthModel +#' +#' @param jobj a Java object reference to the backing Scala FPGrowthModel +#' @export +#' @note FPGrowthModel since 2.2.0 +setClass("FPGrowthModel", slots = list(jobj = "jobj")) + +#' FP-growth +#' +#' A parallel FP-growth algorithm to mine frequent itemsets. +#' \code{spark.fpGrowth} fits a FP-growth model on a SparkDataFrame. 
Users can +#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get +#' association rules, \code{predict} to make predictions on new data based on generated association +#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ +#' FP-growth}. +#' +#' @param data A SparkDataFrame for training. +#' @param minSupport Minimal support level. +#' @param minConfidence Minimal confidence level. +#' @param itemsCol Features column name. +#' @param numPartitions Number of partitions used for fitting. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model. +#' @rdname spark.fpGrowth +#' @name spark.fpGrowth +#' @aliases spark.fpGrowth,SparkDataFrame-method +#' @export +#' @examples +#' \dontrun{ +#' raw_data <- read.df( +#' "data/mllib/sample_fpgrowth.txt", +#' source = "csv", +#' schema = structType(structField("raw_items", "string"))) +#' +#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items") +#' model <- spark.fpGrowth(data) +#' +#' # Show frequent itemsets +#' frequent_itemsets <- spark.freqItemsets(model) +#' showDF(frequent_itemsets) +#' +#' # Show association rules +#' association_rules <- spark.associationRules(model) +#' showDF(association_rules) +#' +#' # Predict on new data +#' new_itemsets <- data.frame(items = c("t", "t,s")) +#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items") +#' predict(model, new_data) +#' +#' # Save and load model +#' path <- "/path/to/model" +#' write.ml(model, path) +#' read.ml(path) +#' +#' # Optional arguments +#' baskets_data <- selectExpr(createDataFrame(itemsets), "split(items, ',') as baskets") +#' another_model <- spark.fpGrowth(data, minSupport = 0.1, minConfidence = 0.5, +#' itemsCol = "baskets", numPartitions = 10) +#' } +#' @note spark.fpGrowth since 2.2.0 +setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), + function(data, minSupport = 0.3, minConfidence = 0.8, + itemsCol = "items", numPartitions = NULL) { + if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) { + stop("minSupport should be a number [0, 1].") + } + if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) { + stop("minConfidence should be a number [0, 1].") + } + if (!is.null(numPartitions)) { + numPartitions <- as.integer(numPartitions) + stopifnot(numPartitions > 0) + } + + jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit", + data@sdf, as.numeric(minSupport), as.numeric(minConfidence), + itemsCol, numPartitions) + new("FPGrowthModel", jobj = jobj) + }) + +# Get frequent itemsets. + +#' @param object a fitted FPGrowth model. +#' @return A \code{SparkDataFrame} with frequent itemsets. +#' The \code{SparkDataFrame} contains two columns: +#' \code{items} (an array of the same type as the input column) +#' and \code{freq} (frequency of the itemset). +#' @rdname spark.fpGrowth +#' @aliases freqItemsets,FPGrowthModel-method +#' @export +#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 +setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "freqItemsets")) + }) + +# Get association rules. + +#' @return A \code{SparkDataFrame} with association rules. 
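The association-rule output is an ordinary SparkDataFrame, so it can be collected and filtered locally on its confidence column. This is a sketch under the assumption of a running SparkR session; the transactions mirror the test data used later in this change.

  library(SparkR)
  sparkR.session()

  raw <- createDataFrame(data.frame(items = c("1,2", "1,2", "1,2,3", "1,3")))
  model <- spark.fpGrowth(selectExpr(raw, "split(items, ',') as items"),
                          minSupport = 0.5, minConfidence = 0.8)

  # rule columns: antecedent, consequent, confidence
  rules <- collect(spark.associationRules(model))
  subset(rules, confidence >= 0.9)

  # frequent itemsets come back as (items, freq) and can be ranked locally
  itemsets <- collect(spark.freqItemsets(model))
  head(itemsets[order(-itemsets$freq), ])
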
+#' The \code{SparkDataFrame} contains three columns: +#' \code{antecedent} (an array of the same type as the input column), +#' \code{consequent} (an array of the same type as the input column), +#' and \code{condfidence} (confidence). +#' @rdname spark.fpGrowth +#' @aliases associationRules,FPGrowthModel-method +#' @export +#' @note spark.associationRules(FPGrowthModel) since 2.2.0 +setMethod("spark.associationRules", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "associationRules")) + }) + +# Makes predictions based on generated association rules + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.fpGrowth +#' @aliases predict,FPGrowthModel-method +#' @export +#' @note predict(FPGrowthModel) since 2.2.0 +setMethod("predict", signature(object = "FPGrowthModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the FPGrowth model to the output path. + +#' @param path the directory where the model is saved. +#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' @rdname spark.fpGrowth +#' @aliases write.ml,FPGrowthModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(FPGrowthModel, character) since 2.2.0 +setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 648d363f1a255..d59c890f3e5fd 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{Gamma}, \code{poisson} and \code{tweedie}. +#' +#' Note that there are two ways to specify the tweedie family. +#' \itemize{ +#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; +#' \item When package \code{statmod} is loaded, the tweedie family is specified using the +#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param regParam regularization parameter for L2 regularization. +#' @param var.power the power in the variance function of the Tweedie distribution which provides +#' the relationship between the variance and mean of the distribution. Only +#' applicable to the Tweedie family. +#' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. 
@@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' # can also read back the saved model and print #' savedModel <- read.ml(path) #' summary(savedModel) +#' +#' # fit tweedie model +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", +#' var.power = 1.2, link.power = 0) +#' summary(model) +#' +#' # use the tweedie family from statmod +#' library(statmod) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) +#' summary(model) #' } #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) + # Handle when family = "tweedie" + if (tolower(family) == "tweedie") { + family <- list(family = "tweedie", link = NULL) + } else { + family <- get(family, mode = "function", envir = parent.frame()) + } } if (is.function(family)) { family <- family() @@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), print(family) stop("'family' not recognized") } + # Handle when family = statmod::tweedie() + if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { + var.power <- log(family$variance(exp(1))) + link.power <- log(family$linkfun(exp(1))) + family <- list(family = "tweedie", link = NULL) + } formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), weightCol, regParam) + tol, as.integer(maxIter), weightCol, regParam, + as.double(var.power), as.double(link.power)) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. #' Currently these families are supported: \code{binomial}, \code{gaussian}, -#' \code{Gamma}, and \code{poisson}. +#' \code{poisson}, \code{Gamma}, and \code{tweedie}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. #' @param maxit integer giving the maximal number of IRLS iterations. +#' @param var.power the index of the power variance function in the Tweedie family. +#' @param link.power the index of the power link function in the Tweedie family. #' @return \code{glm} returns a fitted generalized linear model. 
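The tweedie handling added to spark.glm above recovers var.power and link.power from a statmod family object by evaluating its variance and link functions at e. A small local sketch of that identity (requires the statmod package; illustrative, not part of the patch):

  library(statmod)

  fam <- tweedie(var.power = 1.2, link.power = 0)

  # variance(mu) = mu^var.power, so evaluating at e recovers the power
  log(fam$variance(exp(1)))  # 1.2
  # for link.power = 0 the link is log, so this evaluates to 0
  log(fam$linkfun(exp(1)))   # 0
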
#' @rdname glm #' @export @@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @note glm since 1.5.0 #' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, + var.power = 0.0, link.power = 1.0 - var.power) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, + var.power = var.power, link.power = link.power) }) # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). @@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), deviance <- callJMethod(jobj, "rDeviance") df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") iter <- callJMethod(jobj, "rNumIterations") family <- callJMethod(jobj, "rFamily") + aic <- callJMethod(jobj, "rAic") + if (family == "tweedie" && aic == 0) aic <- NA deviance.resid <- if (is.loaded) { NULL } else { diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 40a806c41bad0..82279be6fbe77 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) { numFeatures <- callJMethod(jobj, "numFeatures") features <- callJMethod(jobj, "features") featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") numTrees <- callJMethod(jobj, "numTrees") treeWeights <- callJMethod(jobj, "treeWeights") list(formula = formula, numFeatures = numFeatures, features = features, featureImportances = featureImportances, + maxDepth = maxDepth, numTrees = numTrees, treeWeights = treeWeights, jobj = jobj) @@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) { cat("\nNumber of features: ", x$numFeatures) cat("\nFeatures: ", unlist(x$features)) cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) cat("\nNumber of trees: ", x$numTrees) cat("\nTree weights: ", unlist(x$treeWeights)) @@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method #' @export @@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. 
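Because summary() on tree ensembles now carries maxDepth, it can be read straight off a fitted model. A minimal sketch assuming a running SparkR session; the longley data and the parameters are illustrative, not part of the patch.

  library(SparkR)
  sparkR.session()

  df <- suppressWarnings(createDataFrame(longley))
  model <- spark.randomForest(df, Employed ~ ., "regression", numTrees = 3, maxDepth = 4)

  s <- summary(model)
  s$maxDepth   # maximum tree depth, newly included in the summary list
  s$numTrees
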
#' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), -#' and \code{treeWeights} (tree weights). +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 04a0a6f944412..5dfef8625061b 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -118,6 +118,8 @@ read.ml <- function(path) { new("BisectingKMeansModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) { new("LinearSVCModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) { + new("FPGrowthModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 61773ed3ee8c0..d0a12b7ecec65 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -322,10 +322,19 @@ sparkRHive.init <- function(jsc = NULL) { #' SparkSession or initializes a new SparkSession. #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. -#' When called in an interactive session, this checks for the Spark installation, and, if not +#' +#' When called in an interactive session, this method checks for the Spark installation, and, if not #' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can #' be called manually. #' +#' A default warehouse is created automatically in the current directory when a managed table is +#' created via \code{sql} statement \code{CREATE TABLE}, for example. To change the location of the +#' warehouse, set the named parameter \code{spark.sql.warehouse.dir} to the SparkSession. Along with +#' the warehouse, an accompanied metastore may also be automatically created in the current +#' directory when a new SparkSession is initialized with \code{enableHiveSupport} set to +#' \code{TRUE}, which is the default. For more details, refer to Hive configuration at +#' \url{http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables}. +#' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' @@ -381,6 +390,10 @@ sparkR.session <- function( deployMode <- sparkConfigMap[["spark.submit.deployMode"]] } + if (!exists("spark.r.sql.derby.temp.dir", envir = sparkConfigMap)) { + sparkConfigMap[["spark.r.sql.derby.temp.dir"]] <- tempdir() + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 8d1d165052f7f..d78a10893f92e 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -149,7 +149,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. 
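The warehouse note added to the sparkR.session() documentation above translates into a one-line override at session start; the path below is a placeholder, not a recommended location.

  library(SparkR)

  # named Spark properties passed to sparkR.session() take priority,
  # so the default warehouse directory can be redirected when the session is created
  sparkR.session(spark.sql.warehouse.dir = "/tmp/sparkr-warehouse",
                 enableHiveSupport = TRUE)
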
-#' Note that rows containing any NA values will be removed before calculation. +#' Note that NA values will be ignored in numerical columns before calculation. For +#' columns only containing NA values, an empty list is returned. #' #' @param x A SparkDataFrame. #' @param cols A single column name, or a list of names for multiple columns. diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R new file mode 100644 index 0000000000000..8390bd5e6de72 --- /dev/null +++ b/R/pkg/R/streaming.R @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# streaming.R - Structured Streaming / StreamingQuery class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R +NULL + +#' S4 class that represents a StreamingQuery +#' +#' StreamingQuery can be created by using read.stream() and write.stream() +#' +#' @rdname StreamingQuery +#' @seealso \link{read.stream} +#' +#' @param ssq A Java object reference to the backing Scala StreamingQuery +#' @export +#' @note StreamingQuery since 2.2.0 +#' @note experimental +setClass("StreamingQuery", + slots = list(ssq = "jobj")) + +setMethod("initialize", "StreamingQuery", function(.Object, ssq) { + .Object@ssq <- ssq + .Object +}) + +streamingQuery <- function(ssq) { + stopifnot(class(ssq) == "jobj") + new("StreamingQuery", ssq) +} + +#' @rdname show +#' @export +#' @note show(StreamingQuery) since 2.2.0 +setMethod("show", "StreamingQuery", + function(object) { + name <- callJMethod(object@ssq, "name") + if (!is.null(name)) { + cat(paste0("StreamingQuery '", name, "'\n")) + } else { + cat("StreamingQuery", "\n") + } + }) + +#' queryName +#' +#' Returns the user-specified name of the query. This is specified in +#' \code{write.stream(df, queryName = "query")}. This name, if set, must be unique across all active +#' queries. +#' +#' @param x a StreamingQuery. +#' @return The name of the query, or NULL if not specified. +#' @rdname queryName +#' @name queryName +#' @aliases queryName,StreamingQuery-method +#' @family StreamingQuery methods +#' @seealso \link{write.stream} +#' @export +#' @examples +#' \dontrun{ queryName(sq) } +#' @note queryName(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("queryName", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "name") + }) + +#' @rdname explain +#' @name explain +#' @aliases explain,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ explain(sq) } +#' @note explain(StreamingQuery) since 2.2.0 +setMethod("explain", + signature(x = "StreamingQuery"), + function(x, extended = FALSE) { + cat(callJMethod(x@ssq, "explainInternal", extended), "\n") + }) + +#' lastProgress +#' +#' Prints the most recent progess update of this streaming query in JSON format. 
+#' +#' @param x a StreamingQuery. +#' @rdname lastProgress +#' @name lastProgress +#' @aliases lastProgress,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ lastProgress(sq) } +#' @note lastProgress(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("lastProgress", + signature(x = "StreamingQuery"), + function(x) { + p <- callJMethod(x@ssq, "lastProgress") + if (is.null(p)) { + cat("Streaming query has no progress") + } else { + cat(callJMethod(p, "toString"), "\n") + } + }) + +#' status +#' +#' Prints the current status of the query in JSON format. +#' +#' @param x a StreamingQuery. +#' @rdname status +#' @name status +#' @aliases status,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ status(sq) } +#' @note status(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("status", + signature(x = "StreamingQuery"), + function(x) { + cat(callJMethod(callJMethod(x@ssq, "status"), "toString"), "\n") + }) + +#' isActive +#' +#' Returns TRUE if this query is actively running. +#' +#' @param x a StreamingQuery. +#' @return TRUE if query is actively running, FALSE if stopped. +#' @rdname isActive +#' @name isActive +#' @aliases isActive,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ isActive(sq) } +#' @note isActive(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("isActive", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "isActive") + }) + +#' awaitTermination +#' +#' Waits for the termination of the query, either by \code{stopQuery} or by an error. +#' +#' If the query has terminated, then all subsequent calls to this method will return TRUE +#' immediately. +#' +#' @param x a StreamingQuery. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. +#' @rdname awaitTermination +#' @name awaitTermination +#' @aliases awaitTermination,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ awaitTermination(sq, 10000) } +#' @note awaitTermination(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("awaitTermination", + signature(x = "StreamingQuery"), + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } + }) + +#' stopQuery +#' +#' Stops the execution of this query if it is running. This method blocks until the execution is +#' stopped. +#' +#' @param x a StreamingQuery. 
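Taken together, the StreamingQuery methods in this new file support a simple query lifecycle. The sketch below is illustrative only: it assumes a running SparkR session, uses a socket source purely as a placeholder, and relies on read.stream/write.stream, which are referenced here but defined elsewhere in this change set.

  library(SparkR)
  sparkR.session()

  lines <- read.stream("socket", host = "localhost", port = 9999)
  sq <- write.stream(lines, "console", queryName = "echo")

  queryName(sq)              # "echo"
  isActive(sq)               # TRUE while the query is running
  status(sq)                 # prints the current status as JSON
  lastProgress(sq)           # prints the most recent progress as JSON
  awaitTermination(sq, 1000) # wait up to one second; FALSE if the query is still running
  stopQuery(sq)              # blocks until execution has stopped
  isActive(sq)               # FALSE
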
+#' @rdname stopQuery +#' @name stopQuery +#' @aliases stopQuery,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ stopQuery(sq) } +#' @note stopQuery(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("stopQuery", + signature(x = "StreamingQuery"), + function(x) { + invisible(callJMethod(x@ssq, "stop")) + }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 1f7848f2b413f..d29af00affb98 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -823,7 +823,16 @@ captureJVMException <- function(e, method) { stacktrace <- rawmsg } - if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + # StreamingQueryException could wrap an IllegalArgumentException, so look for that first + if (any(grep("org.apache.spark.sql.streaming.StreamingQueryException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.streaming.StreamingQueryException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "streaming query error - ", first), call. = FALSE) + } else if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] # Extract "Error in ..." message. rmsg <- msg[1] @@ -837,6 +846,32 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such database - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such table - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "parse error - ", first), call. = FALSE) } else { stop(stacktrace, call. 
= FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa85..518fb7bd94043 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -20,6 +20,8 @@ context("SerDe functionality") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index e8d983426a676..919b063bf0693 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,9 +17,12 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } + testOutput <- launchScript("ECHO", "a/b/c", wait = TRUE) abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e5..63f54e1af02b1 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e6204405..25bb2b84266dd 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + 
actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f43..504ded4fc8623 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -26,6 +26,8 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf39..3d53bebab6300 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index caca06933952b..632a90d68177f 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -18,13 +18,15 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. 
- namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMaskedCompletely <- c("cov", "filter", "sample", "not") namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame", "union") + "summary", "transform", "drop", "window", "as.data.frame", "union", "not") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } @@ -55,6 +57,8 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) @@ -73,6 +77,8 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { + skip_on_cran() + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -96,6 +102,8 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { + skip_on_cran() + sc <- sparkR.sparkContext() setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { + skip_on_cran() + sparkR.sparkContext() setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -168,6 +182,8 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { + skip_on_cran() + sparkR.sparkContext() # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") @@ -177,6 +193,13 @@ test_that("add and get file to be downloaded with Spark job on every node", { spark.addFile(path) download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) + + # Test spark.getSparkFiles works well on executors. + seq <- seq(from = 1, to = 10, length.out = 5) + f <- function(seq) { spark.getSparkFiles(filename) } + results <- spark.lapply(seq, f) + for (i in 1:5) { expect_equal(basename(results[[i]]), filename) } + unlink(path) # Test add directory recursively. diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd8..f823ad8e9c985 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. 
if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a58..cbc7087182868 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -284,22 +284,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +299,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1661e987b730f..478012e8828cd 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -255,6 +255,8 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -297,6 +299,8 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R new file mode 100644 index 0000000000000..c38f1133897dd --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib frequent pattern mining") + +# Tests for MLlib frequent pattern mining algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.fpGrowth", { + data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,2", + "1,2,3", + "1,3" + ))), "split(items, ',') as items") + + model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1) + + itemsets <- collect(spark.freqItemsets(model)) + + expected_itemsets <- data.frame( + items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))), + freq = c(2, 2, 3, 3, 4) + ) + + expect_equivalent(expected_itemsets, itemsets) + + expected_association_rules <- data.frame( + antecedent = I(list(list("2"), list("3"))), + consequent = I(list(list("1"), list("1"))), + confidence = c(1, 1) + ) + + expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) + + new_data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,3", + "2,3" + ))), "split(items, ',') as items") + + expected_predictions <- data.frame( + items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))), + prediction = I(list(list(), list(), list("1"))) + ) + + expect_equivalent(expected_predictions, collect(predict(model, new_data))) + + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) + + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) + + unlink(modelPath) + + model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) + expect_equal( + count(spark.freqItemsets(model_without_numpartitions)), + count(spark.freqItemsets(model)) + ) + +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 81a5bdc414927..58924f952c6bf 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -23,6 +23,8 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("formula of spark.glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -77,6 +79,24 @@ test_that("spark.glm and predict", { out <- capture.output(print(summary(model))) expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- 
collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) @@ -177,6 +197,8 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -204,6 +226,8 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -230,10 +254,12 @@ test_that("formula of glm", { }) test_that("glm and predict", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - prediction <- predict(model, training) + prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") vals <- collect(select(prediction, "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) @@ -249,6 +275,24 @@ test_that("glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # tweedie family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15) @@ -256,6 +300,8 @@ test_that("glm and predict", { }) test_that("glm summary", { + skip_on_cran() + # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -305,6 +351,8 @@ test_that("glm summary", { }) test_that("glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index e6fda251ebea2..e0802a9b02d13 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -39,6 +39,7 @@ test_that("spark.gbt", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) 
expect_equal(stats$formula, "Employed ~ .") expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) @@ -53,6 +54,7 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, stats2$numFeatures) expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$numTrees, stats2$numTrees) expect_equal(stats$treeWeights, stats2$treeWeights) @@ -66,6 +68,7 @@ test_that("spark.gbt", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) predictions <- collect(predict(model, data))$prediction @@ -93,6 +96,7 @@ test_that("spark.gbt", { expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) expect_equal(s$numFeatures, 5) expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), @@ -116,6 +120,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) @@ -129,6 +134,7 @@ test_that("spark.randomForest", { tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") write.ml(model, modelPath) @@ -141,6 +147,7 @@ test_that("spark.randomForest", { expect_equal(stats$features, stats2$features) expect_equal(stats$featureImportances, stats2$featureImportances) expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) @@ -153,6 +160,7 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) # Test string prediction values @@ -187,6 +195,8 @@ test_that("spark.randomForest", { stats <- summary(model) expect_equal(stats$numFeatures, 2) expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + # Test numeric prediction values predictions <- collect(predict(model, data))$prediction expect_equal(length(grep("1.0", predictions)), 50) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba4693..1f7f387de08ce 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -39,6 +39,8 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ 
test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 787ef51c501c0..a3b1631e1d119 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { + skip_on_cran() + expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -143,8 +167,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") - setCheckpointDir(sc, tempDir) - checkpoint(rdd2) + setCheckpointDirSC(sc, tempDir) + checkpointRDD(rdd2) 
expect_true(rdd2@env$isCheckpointed) rdd2 <- lapply(rdd2, function(x) x) @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,6 +351,8 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -322,6 +374,8 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 
9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,6 +441,8 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 +592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 
list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1df..cedf4f100c6c4 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. 
Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R index f73fc6baeccef..a40981c188f7a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -18,6 +18,8 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { + skip_on_cran() + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7c096597fea66..47cc34a6c5b75 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,6 +60,7 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR +filesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkSession <- sparkR.session() sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", 
sparkSession) @@ -96,15 +97,21 @@ mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -139,7 +146,71 @@ test_that("structType and structField", { expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) +test_that("structField type strings", { + # positive cases + primitiveTypes <- list(byte = "ByteType", + integer = "IntegerType", + float = "FloatType", + double = "DoubleType", + string = "StringType", + binary = "BinaryType", + boolean = "BooleanType", + timestamp = "TimestampType", + date = "DateType", + tinyint = "ByteType", + smallint = "ShortType", + int = "IntegerType", + bigint = "LongType", + decimal = "DecimalType(10,0)") + + complexTypes <- list("map<string,integer>" = "MapType(StringType,IntegerType,true)", + "array<string>" = "ArrayType(StringType,true)", + "struct<a:string>" = "StructType(StructField(a,StringType,true))") + + typeList <- c(primitiveTypes, complexTypes) + typeStrings <- names(typeList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- typeList[[i]] + testField <- structField("_col", typeString) + expect_is(testField, "structField") + expect_true(testField$nullable()) + expect_equal(testField$dataType.toString(), expected) + } + + # negative cases + primitiveErrors <- list(Byte = "Byte", + INTEGER = "INTEGER", + numeric = "numeric", + character = "character", + raw = "raw", + logical = "logical", + short = "short", + varchar = "varchar", + long = "long", + char = "char") + + complexErrors <- list("map<string, integer>" = " integer", + "array<String>" = "String", + "struct<a:string >" = "string ", + "map <string, integer>" = "map ", + "array< string>" = " string", + "struct<a: string>" = " string") + + errorList <- c(primitiveErrors, complexErrors) + typeStrings <- names(errorList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]]) + expect_error(structField("_col", typeString), expected) + } +}) + test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -227,7 +298,7 @@ test_that("create DataFrame from RDD", { setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) - invisible(insertInto(df, "people")) + insertInto(df, "people") expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, @@ -237,6 +308,8 @@ }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -297,6 +370,8 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for
options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -351,6 +426,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -462,6 +539,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -474,6 +553,8 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -561,6 +642,8 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -574,6 +657,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -591,16 +676,20 @@ test_that("test tableNames and tables", { df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") expect_equal(length(tableNames()), 1) - tables <- tables() + expect_equal(length(tableNames("default")), 1) + tables <- listTables() expect_equal(count(tables), 1) + expect_equal(count(tables()), count(tables)) + expect_true("tableName" %in% colnames(tables())) + expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) - tables <- tables() + tables <- listTables() expect_equal(count(tables), 0) }) @@ -626,12 +715,17 @@ test_that( }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") uncacheTable("table1") clearCache() expect_true(dropTempView("table1")) + + expect_error(uncacheTable("foo"), + "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") }) test_that("insertInto() on a registered table", { @@ -676,6 +770,8 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -683,6 +779,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -693,6 +791,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -722,6 +822,8 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { + 
skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -734,6 +836,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -802,6 +906,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -840,6 +946,17 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", expect_true(is.data.frame(collect(df))) }) +test_that("setCheckpointDir(), checkpoint() on a DataFrame", { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) @@ -1196,7 +1313,16 @@ test_that("column calculation", { test_that("test HiveContext", { setHiveContext(sc) - df <- createExternalTable("json", jsonPath, "json") + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) df2 <- sql("select * from json") @@ -1204,25 +1330,26 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + saveAsTable(df, "json2", "json", "append", path = jsonPath2) df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + saveAsTable(df, "hivetestbl", path = hivetestDataPath) df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + unsetHiveContext() }) @@ -1232,6 +1359,8 @@ test_that("column operators", { c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c5 <- c2 ^ c3 ^ c4 + c6 <- c2 %<=>% c3 + c7 <- !c6 }) test_that("column functions", { @@ -1256,6 +1385,8 @@ test_that("column functions", { c18 <- covar_pop(c, c1) + covar_pop("c", "c1") c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + 
to_timestamp(c, "yyyy") + to_date(c, "yyyy") + c21 <- posexplode_outer(c) + explode_outer(c) + c22 <- not(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1271,6 +1402,11 @@ test_that("column functions", { expect_equal(collect(df2)[[3, 1]], FALSE) expect_equal(collect(df2)[[3, 2]], TRUE) + # Test that input_file_name() + actual_names <- sort(collect(distinct(select(df, input_file_name())))) + expect_equal(length(actual_names), 1) + expect_equal(basename(actual_names[1, 1]), basename(jsonPath)) + df3 <- select(df, between(df$name, c("Apache", "Spark"))) expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) @@ -1339,6 +1475,10 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) # Test to_json(), from_json() + df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") @@ -1354,9 +1494,8 @@ test_that("column functions", { # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) schema2 <- structType(structField("date", "date")) - expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), - error = function(e) { stop(e) }), - paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2))) + expect_equal(s[[1]][[1]], NA) s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) expect_is(s[[1]][[1]]$date, "Date") expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") @@ -1364,6 +1503,41 @@ test_that("column functions", { # check for unparseable df <- as.DataFrame(list(list("a" = ""))) expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + + # check if array type in string is correctly supported. 
+ jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" + df <- as.DataFrame(list(list("people" = jsonArr))) + schema <- structType(structField("name", "string")) + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + + # Test create_array() and create_map() + df <- as.DataFrame(data.frame( + x = c(1.0, 2.0), y = c(-1.0, 3.0), z = c(-2.0, 5.0) + )) + + arrs <- collect(select(df, create_array(df$x, df$y, df$z))) + expect_equal(arrs[, 1], list(list(1, -1, -2), list(2, 3, 5))) + + maps <- collect(select( + df, create_map(lit("x"), df$x, lit("y"), df$y, lit("z"), df$z))) + + expect_equal( + maps[, 1], + lapply( + list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), + as.environment)) + + df <- as.DataFrame(data.frame(is_true = c(TRUE, FALSE, NA))) + expect_equal( + collect(select(df, alias(not(df$is_true), "is_false"))), + data.frame(is_false = c(FALSE, TRUE, NA)) + ) }) test_that("column binary mathfunctions", { @@ -1432,6 +1606,40 @@ test_that("string operators", { expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") + + l4 <- list(list(a = "a.b@c.d 1\\b")) + df4 <- createDataFrame(l4) + expect_equal( + collect(select(df4, split_string(df4$a, "\\s+")))[1, 1], + list(list("a.b@c.d", "1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.")))[1, 1], + list(list("a", "b@c", "d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "@")))[1, 1], + list(list("a.b", "c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], + list(list("a.b@c.d 1", "b")) + ) + + l5 <- list(list(a = "abc")) + df5 <- createDataFrame(l5) + expect_equal( + collect(select(df5, repeat_string(df5$a, 1L)))[1, 1], + "abc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, 3)))[1, 1], + "abcabcabc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, -1)))[1, 1], + "" + ) }) test_that("date functions on a DataFrame", { @@ -1617,6 +1825,28 @@ test_that("group by, agg functions", { expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + # Test collect_list and collect_set + gd3_collections_local <- collect( + agg(gd3, collect_set(df8$age), collect_list(df8$age)) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]), + c(30) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]), + c(30, 30) + ) + + expect_equal( + sort(unlist( + gd3_collections_local[gd3_collections_local$name == "Justin", 3] + )), + c(1, 19) + ) + unlink(jsonPath2) unlink(jsonPath3) }) @@ -1646,6 +1876,160 @@ test_that("pivot GroupedData column", { expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) }) +test_that("test multi-dimensional aggregations with cube and rollup", { + df <- createDataFrame(data.frame( + id = 1:6, + year = c(2016, 2016, 2016, 2017, 2017, 2017), + salary = c(10000, 15000, 20000, 22000, 32000, 21000), + department = c("management", "rnd", "sales", "management", "rnd", "sales") + )) + + actual_cube <- collect( + 
orderBy( + agg( + cube(df, "year", "department"), + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") + ), + "year", "department" + ) + ) + + expected_cube <- data.frame( + year = c(rep(NA, 4), rep(2016, 4), rep(2017, 4)), + department = rep(c(NA, "management", "rnd", "sales"), times = 3), + total_salary = c( + 120000, # Total + 10000 + 22000, 15000 + 32000, 20000 + 21000, # Department only + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + # Mean by department + mean(c(10000, 22000)), mean(c(15000, 32000)), mean(c(20000, 21000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_cube, expected_cube) + + # cube should accept column objects + expect_equal( + count(sum(cube(df, df$year, df$department), "salary")), + 12 + ) + + # cube without columns should result in a single aggregate + expect_equal( + collect(agg(cube(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) + + actual_rollup <- collect( + orderBy( + agg( + rollup(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") + ), + "year", "department" + ) + ) + + expected_rollup <- data.frame( + year = c(NA, rep(2016, 4), rep(2017, 4)), + department = c(NA, rep(c(NA, "management", "rnd", "sales"), times = 2)), + total_salary = c( + 120000, # Total + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_rollup, expected_rollup) + + # cube should accept column objects + expect_equal( + count(sum(rollup(df, df$year, df$department), "salary")), + 9 + ) + + # rollup without columns should result in a single aggregate + expect_equal( + collect(agg(rollup(df), 
expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) +}) + test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(jsonPath) sorted <- arrange(df, df$age) @@ -1691,6 +2075,16 @@ test_that("filter() on a DataFrame", { filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + # test suites for %<=>% + dfNa <- read.json(jsonPathNa) + expect_equal(count(filter(dfNa, dfNa$age %<=>% 60)), 1) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% 60))), 5 - 1) + expect_equal(count(filter(dfNa, dfNa$age %<=>% NULL)), 3) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% NULL))), 5 - 3) + # match NA from two columns + expect_equal(count(filter(dfNa, dfNa$age %<=>% dfNa$height)), 2) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% dfNa$height))), 5 - 2) + # Test stats::filter is working #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) @@ -1793,6 +2187,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { @@ -1850,6 +2256,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(count(unioned2), 12) expect_equal(first(unioned2)$name, "Michael") + df3 <- df2 + names(df3)[1] <- "newName" + expect_error(rbind(df, df3), + "Names of input data frames are different.") + expect_error(rbind(df, df2, df3), + "Names of input data frames are different.") + excepted <- arrange(except(df, df2), desc(df$age)) expect_is(unioned, "SparkDataFrame") expect_equal(count(excepted), 2) @@ -1945,6 +2358,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -1966,6 +2381,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2012,6 +2429,8 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2029,6 +2448,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2050,6 +2471,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2283,6 +2706,8 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2291,6 +2716,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- 
suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2585,8 +3012,8 @@ test_that("coalesce, repartition, numPartitions", { df2 <- repartition(df1, 10) expect_equal(getNumPartitions(df2), 10) - expect_equal(getNumPartitions(coalesce(df2, 13)), 5) - expect_equal(getNumPartitions(coalesce(df2, 7)), 5) + expect_equal(getNumPartitions(coalesce(df2, 13)), 10) + expect_equal(getNumPartitions(coalesce(df2, 7)), 7) expect_equal(getNumPartitions(coalesce(df2, 3)), 3) }) @@ -2708,6 +3135,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -2734,7 +3163,7 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { # more tests for SPARK-16538 createOrReplaceTempView(df, "table") - SparkR::tables() + SparkR::listTables() SparkR::sql("SELECT 1") suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) @@ -2787,6 +3216,8 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -2802,6 +3233,8 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -2822,12 +3255,14 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) expect_error(write.df(df, path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(write.df(df, mode = TRUE), - "mode should be charactor or omitted. It is 'error' by default.") + "mode should be character or omitted. It is 'error' by default.") }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -2843,7 +3278,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # Arguments checking in R side. expect_error(read.df(path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. 
It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) }) @@ -2890,6 +3325,92 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) }) +test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { + expect_equal(currentDatabase(), "default") + expect_error(setCurrentDatabase("default"), NA) + expect_error(setCurrentDatabase("foo"), + "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + dbs <- collect(listDatabases()) + expect_equal(names(dbs), c("name", "description", "locationUri")) + expect_equal(dbs[[1]], "default") +}) + +test_that("catalog APIs, listTables, listColumns, listFunctions", { + tb <- listTables() + count <- count(tables()) + expect_equal(nrow(tb), count) + expect_equal(colnames(tb), c("name", "database", "description", "tableType", "isTemporary")) + + createOrReplaceTempView(as.DataFrame(cars), "cars") + + tb <- listTables() + expect_equal(nrow(tb), count + 1) + tbs <- collect(tb) + expect_true(nrow(tbs[tbs$name == "cars", ]) > 0) + expect_error(listTables("bar"), + "Error in listTables : no such database - Database 'bar' not found") + + c <- listColumns("cars") + expect_equal(nrow(c), 2) + expect_equal(colnames(c), + c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) + expect_equal(collect(c)[[1]][[1]], "speed") + expect_error(listColumns("foo", "default"), + "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + + f <- listFunctions() + expect_true(nrow(f) >= 200) # 250 + expect_equal(colnames(f), + c("name", "database", "description", "className", "isTemporary")) + expect_equal(take(orderBy(f, "className"), 1)$className, + "org.apache.spark.sql.catalyst.expressions.Abs") + expect_error(listFunctions("foo_db"), + "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + + # recoverPartitions does not work with temporary view + expect_error(recoverPartitions("cars"), + "no such table - Table or view 'cars' not found in database 'default'") + expect_error(refreshTable("cars"), NA) + expect_error(refreshByPath("/"), NA) + + dropTempView("cars") +}) + +compare_list <- function(list1, list2) { + # get testthat to show the diff by first making the 2 lists equal in length + expect_equal(length(list1), length(list2)) + l <- max(length(list1), length(list2)) + length(list1) <- l + length(list2) <- l + expect_equal(sort(list1, na.last = TRUE), sort(list2, na.last = TRUE)) +} + +# This should always be the **very last test** in this test file. +test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() + + # Check that it is not creating any extra file. + # Does not check the tempdir which would be cleaned up after. + filesAfter <- list.files(path = sparkRDir, all.files = TRUE) + + expect_true(length(sparkRFilesBefore) > 0) + # first, ensure derby.log is not there + expect_false("derby.log" %in% filesAfter) + # second, ensure only spark-warehouse is created when calling SparkSession, enableHiveSupport = F + # note: currently all other test files have enableHiveSupport = F, so we capture the list of files + # before creating a SparkSession with enableHiveSupport = T at the top of this test file + # (filesBefore). The test here is to compare that (filesBefore) against the list of files before + # any test is run in run-all.R (sparkRFilesBefore).
+ # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the + # same as before any test is run. + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T + # note: as the note above, after running all tests in this file while enableHiveSupport = T, we + # check the list of files again. This time we allow both whitelisted dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R new file mode 100644 index 0000000000000..91df7ac6f9849 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("Structured Streaming") + +# Tests for Structured Streaming functions in SparkR + +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +jsonSubDir <- file.path("sparkr-test", "json", "") +if (.Platform$OS.type == "windows") { + # file.path removes the empty separator on Windows, adds it back + jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) +} +jsonDir <- file.path(tempdir(), jsonSubDir) +dir.create(jsonDir, recursive = TRUE) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern = jsonSubDir, fileext = ".tmp") +writeLines(mockLines, jsonPath) + +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}") +jsonPathNa <- tempfile(pattern = jsonSubDir, fileext = ".tmp") + +schema <- structType(structField("name", "string"), + structField("age", "integer"), + structField("count", "double")) + +test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) + + writeLines(mockLinesNa, jsonPathNa) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) + + stopQuery(q) + 
expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) +}) + +test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") + + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") + expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) + expect_true(any(grepl("\"isTriggerActive\" : ", capture.output(status(q))))) + + expect_equal(queryName(q), "people2") + expect_true(isActive(q)) + + stopQuery(q) +}) + +test_that("Stream other format", { + skip_on_cran() + + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_equal(queryName(q), "people3") + expect_true(any(grepl("\"description\" : \"FileStreamSource[[:print:]]+parquet", + capture.output(lastProgress(q))))) + expect_true(isActive(q)) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) + expect_false(isActive(q)) + + unlink(parquetPath) +}) + +test_that("Non-streaming DataFrame", { + skip_on_cran() + + c <- as.DataFrame(cars) + expect_false(isStreaming(c)) + + expect_error(write.stream(c, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(writeStream : analysis error - 'writeStream' can be called only on ", + "streaming Dataset/DataFrame).*")) +}) + +test_that("Unsupported operation", { + skip_on_cran() + + # memory sink without aggregation + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(start : analysis error - Complete output mode not supported when there ", + "are no streaming aggregations on streaming DataFrames/Datasets).*")) +}) + +test_that("Terminated by error", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) + counts <- count(group_by(df, "name")) + # This would not fail before returning with a StreamingQuery, + # but could dump error log at just about the same time + expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), + NA) + + expect_error(awaitTermination(q, 5 * 1000), + paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", + " 'maxFilesPerTrigger', must be a positive integer).*")) + + expect_true(any(grepl("\"message\" : \"Terminated with exception: Invalid value", + capture.output(status(q))))) + expect_true(any(grepl("Streaming query has no progress", capture.output(lastProgress(q))))) + expect_equal(queryName(q), "people4") + expect_false(isActive(q)) + + stopQuery(q) +}) + +unlink(jsonPath) +unlink(jsonPathNa) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index 
aaa532856c3d9..e2130eaac78dd 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -34,6 +34,8 @@ sparkSession <- sparkR.session(enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e9390..28b7e8e3183fd 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git 
a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665e..4a01e875405ff 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -23,6 +23,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. @@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -167,16 +169,20 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - method <- "getSQLDataType" + skip_on_cran() + + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, - "unknown"), + "col", "unknown", TRUE), error = function(e) { captureJVMException(e, method) }), - "Error in getSQLDataType : illegal argument - Invalid type unknown") + "parse error - .*DataType unknown.*not supported.") }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index ab8d1ca019941..29812f872c784 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -22,6 +22,13 @@ library(SparkR) options("warn" = 2) # Setup global test environment +# Install Spark first to set SPARK_HOME install.spark() +sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") +invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 43c255cff3028..d38ec4f1b6f37 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -182,7 +182,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. 
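As a quick illustration of the `read.df` flow described in the paragraph above, the sketch below loads a JSON file and attaches an extra data source package at session start. This is a minimal sketch, not part of the patch: the file path and the Avro package coordinate are illustrative placeholders.

```r
library(SparkR)

# sparkPackages attaches extra data source connectors when the session starts;
# the Avro coordinate below is a hypothetical example, substitute a real one.
sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.2.0")

# read.df takes the file path and the data source type; the currently active
# session is picked up automatically. "people.json" is just an example file.
people <- read.df("people.json", source = "json")
head(people)
```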
@@ -232,7 +232,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -308,6 +308,21 @@ numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) head(numCyl) ``` +Use `cube` or `rollup` to compute subtotals across multiple dimensions. + +```{r} +mean(cube(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for all possible combinations of grouping columns, while + +```{r} +mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}. + + #### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. @@ -505,6 +520,10 @@ SparkR supports the following machine learning models and algorithms. * Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -653,6 +672,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -672,6 +692,7 @@ gaussian | identity, log, inverse binomial | logit, probit, cloglog (complementary log-log) poisson | log, identity, sqrt gamma | inverse, identity, log +tweedie | power link function There are three ways to specify the `family` argument. @@ -679,7 +700,11 @@ There are three ways to specify the `family` argument. * Family function, e.g. `family = binomial`.
-* Result returned by a family function, e.g. `family = poisson(link = log)` +* Result returned by a family function, e.g. `family = poisson(link = log)`. + +* Note that there are two ways to specify the tweedie family: + a) Set `family = "tweedie"` and specify the `var.power` and `link.power` + b) When package `statmod` is loaded, the tweedie family is specified using the family definition therein, i.e., `tweedie()`. For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). @@ -695,6 +720,18 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` +The following is the same fit using the tweedie family: +```{r} +tweedieGLM1 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 0.0) +summary(tweedieGLM1) +``` +We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: +```{r} +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", + var.power = 1.2, link.power = 0.0) +summary(tweedieGLM2) +``` + #### Isotonic Regression `spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize @@ -866,7 +903,7 @@ perplexity There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -874,7 +911,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -884,11 +921,42 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). 
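The Kolmogorov-Smirnov paragraph above only names the API, so here is a minimal sketch, assuming the `spark.kstest(data, testCol, nullHypothesis, distParams)` interface and that `summary()` is defined for its result; the sample values are made up.

```r
# Compare a numeric column against a standard normal N(0, 1)
df <- createDataFrame(data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)))

ksResult <- spark.kstest(df, "test", "norm", c(0, 1))
summary(ksResult)  # test statistic, p-value and the null hypothesis used
```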
@@ -935,6 +1003,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. 
The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76da..29764f48bd156 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/README.md b/README.md index d995c103738a5..9495de8d2d9b6 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Contributing +## Contributing Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) for information on how to get started contributing to the project. diff --git a/appveyor.yml b/appveyor.yml index 5adf1b4bedb44..bbb27589cad09 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -27,6 +27,9 @@ branches: only_commits: files: - R/ + - sql/core/src/main/scala/org/apache/spark/sql/api/r/ + - core/src/main/scala/org/apache/spark/api/r/ + - mllib/src/main/scala/org/apache/spark/ml/r/ cache: - C:\Users\appveyor\.m2 diff --git a/assembly/pom.xml b/assembly/pom.xml index 73677e870b731..7f339651c9751 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 869c0b202f7f3..9faa7d65f83e4 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -50,7 +50,16 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java -if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java +if not "x%JAVA_HOME%"=="x" ( + set RUNNER="%JAVA_HOME%\bin\java" +) else ( + where /q "%RUNNER%" + if ERRORLEVEL 1 ( + echo Java not found and JAVA_HOME environment variable is not set. + echo Install Java and set JAVA_HOME to point to the Java installation directory. + exit /b 1 + ) +) rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. 
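Many of the test files in this patch gain `skip_on_cran()` guards, and `R/run-tests.sh` above now exports `NOT_CRAN=true`. The sketch below shows how the two fit together, with a hypothetical test body, assuming testthat's usual behaviour of skipping the guarded test unless `NOT_CRAN` is set to "true".

```r
library(testthat)

# During a CRAN check NOT_CRAN is unset, so skip_on_cran() skips the body;
# run-tests.sh sets NOT_CRAN=true before spark-submit, so the full suite runs there.
test_that("a heavier test that needs a live Spark backend", {  # hypothetical test
  skip_on_cran()
  expect_equal(1 + 1, 2)
})
```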
diff --git a/cloud/pom.xml b/cloud/pom.xml index 9baceaa1c1869..59be327bb3ee1 100644 --- a/cloud/pom.xml +++ b/cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c069..066970f24205f 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java index 3312a5bd81a66..819b8a7efbdba 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -28,7 +28,7 @@ /** * The client challenge message, used to initiate authentication. * - * @see README.md + * Please see crypto/README.md for more details of implementation. */ public class ClientChallenge implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java index affdbf450b1d0..caf3a0f3b38cc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -28,7 +28,7 @@ /** * Server's response to client's challenge. * - * @see README.md + * Please see crypto/README.md for more details. */ public class ServerResponse implements Encodable { /** Serialization tag used to catch incorrect payloads. */ diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index f3eaf22c0166e..afc59efaef810 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,10 +18,13 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -208,7 +211,7 @@ private static boolean isSymlink(File file) throws IOException { * The unit is also considered the default if the given string does not specify a unit. */ public static long timeStringAs(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); @@ -256,7 +259,7 @@ public static long timeStringAsSec(String str) { * provided, a direct conversion to the provided unit is attempted. 
*/ public static long byteStringAs(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); @@ -344,4 +347,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) { } } + /** + * Fills a buffer with data read from the channel. + */ + public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException { + int expected = dst.remaining(); + while (dst.hasRemaining()) { + if (channel.read(dst) < 0) { + throw new EOFException(String.format("Not enough bytes in channel (expected %d).", + expected)); + } + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index 9cfee7f08d155..a2cf87d1af7ed 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -42,6 +42,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = config.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return config.entrySet(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index c226d8f3bc8fa..a25078e262efb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Locale; import java.util.Properties; import com.google.common.primitives.Ints; @@ -75,7 +76,9 @@ public String getModuleName() { } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + public String ioMode() { + return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT); + } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index b848a91806fe6..1087b9fa7709b 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -226,6 +226,8 @@ public StreamManager getStreamManager() { callback0.latch.await(60, TimeUnit.SECONDS); assertTrue(callback0.failure instanceof IOException); + // make sure callback1 is called. 
+ callback1.latch.await(60, TimeUnit.SECONDS); // failed at same time as previous assertTrue(callback1.failure instanceof IOException); } diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9f..2de882adcb582 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6daf9609d76dc..c0f1da50f5e65 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? 
block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e47a72c9d16cc..4d48b18970386 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,8 +88,6 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b8ae04eefb972..7a33b6821792c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -216,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. 
- assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd44467..a8488d8d1b704 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 63b87ae35b242..cfbf6a8a9a61b 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -361,9 +361,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. */ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -371,7 +371,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -384,17 +384,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. 
- File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -402,7 +406,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java index 62a6cca4ed4eb..8beb033699471 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -41,6 +41,12 @@ public String get(String name) { return value; } + @Override + public String get(String name, String defaultValue) { + String value = conf.get(name); + return value == null ? defaultValue : value; + } + @Override public Iterable> getAll() { return conf; diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b662..6b81fc2b2b040 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4b..f7e586ee777e1 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e7152..680d0413b1616 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index f13c24ae5e017..4ab5b6889c212 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -46,18 +46,23 @@ public final class Platform { private static final boolean unaligned; static { boolean _unaligned; - // use reflection to access unaligned field - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } catch (Throwable t) { - // We at least know x86 and x64 support unaligned access. 
- String arch = System.getProperty("os.arch", ""); - //noinspection DynamicRegexReplaceableByCompiledPattern - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64")) { + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } } unaligned = _unaligned; } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 10a7cb1d06659..5437e998c085f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -850,11 +850,23 @@ public UTF8String translate(Map dict) { return fromString(sb.toString()); } - private int getDigit(byte b) { - if (b >= '0' && b <= '9') { - return b - '0'; - } - throw new NumberFormatException(toString()); + /** + * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. + */ + public static class LongWrapper { + public long value = 0; + } + + /** + * Wrapper over `int` to allow result of parsing integer from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. + * + * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of + * conversion from `long` to `int` + */ + public static class IntWrapper { + public int value = 0; } /** @@ -862,14 +874,18 @@ private int getDigit(byte b) { * * Note that, in this method we accumulate the result in negative format, and convert it to * positive format at the end, if this string is not started with '-'. This is because min value - * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and - * Integer.MIN_VALUE is '-2147483648'. + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. * * This code is mostly copied from LazyLong.parseLong in Hive. 
+ * + * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would + * be set in `toLongResult` + * @return true if the parsing was successful else false */ - public long toLong() { + public boolean toLong(LongWrapper toLongResult) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -878,7 +894,7 @@ public long toLong() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -897,20 +913,25 @@ public long toLong() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop. if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we - // can just use `result > 0` to check overflow. If result overflows, we should stop and throw - // exception. + // can just use `result > 0` to check overflow. If result overflows, we should stop. if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -918,8 +939,9 @@ public long toLong() { // part will not change the number, but we will verify that the fractional part // is well formed. while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -927,11 +949,12 @@ public long toLong() { if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - return result; + toLongResult.value = result; + return true; } /** @@ -946,10 +969,14 @@ public long toLong() { * * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance * reasons, like Hive does. + * + * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would + * be set in `intWrapper` + * @return true if the parsing was successful else false */ - public int toInt() { + public boolean toInt(IntWrapper intWrapper) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -958,7 +985,7 @@ public int toInt() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -977,20 +1004,25 @@ public int toInt() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. 
+ // result * 10 will definitely be smaller than minValue, and we can stop if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop and - // throw exception. + // we can just use `result > 0` to check overflow. If result overflows, we should stop if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -998,8 +1030,9 @@ public int toInt() { // part will not change the number, but we will verify that the fractional part // is well formed. while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -1007,31 +1040,33 @@ public int toInt() { if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - - return result; + intWrapper.value = result; + return true; } - public short toShort() { - int intValue = toInt(); - short result = (short) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toShort(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + short result = (short) intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } - public byte toByte() { - int intValue = toInt(); - byte result = (byte) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toByte(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + byte result = (byte) intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } @Override diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 6f6e0ef0e4855..c376371abdf90 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -22,9 +22,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; +import java.util.*; import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; @@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException { .writeTo(outputStream); assertEquals("大千世界", outputStream.toString("UTF-8")); } + + @Test + public void testToShort() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (short) 1); + inputToExpectedOutput.put("+1", (short) 1); + inputToExpectedOutput.put("-1", (short) -1); + inputToExpectedOutput.put("0", (short) 0); + inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111); + inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + short value = (short) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper wrapper = new 
IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper)); + assertEquals((short) entry.getValue(), wrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper)); + } + } + + @Test + public void testToByte() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (byte) 1); + inputToExpectedOutput.put("+1",(byte) 1); + inputToExpectedOutput.put("-1", (byte) -1); + inputToExpectedOutput.put("0", (byte) 0); + inputToExpectedOutput.put("111.12345678901234567890", (byte) 111); + inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + byte value = (byte) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper)); + assertEquals((byte) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper)); + } + } + + @Test + public void testToInt() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1); + inputToExpectedOutput.put("+1", 1); + inputToExpectedOutput.put("-1", -1); + inputToExpectedOutput.put("0", 0); + inputToExpectedOutput.put("11111.1234567", 11111); + inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + int value = rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper)); + assertEquals((int) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper)); + } + } + + @Test + public void testToLong() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1L); + inputToExpectedOutput.put("+1", 1L); + inputToExpectedOutput.put("-1", -1L); + inputToExpectedOutput.put("0", 0L); + inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L); + inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + long value = rand.nextLong(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + LongWrapper wrapper = new LongWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + 
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper)); + assertEquals((long) entry.getValue(), wrapper.value); + } + + List negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", + "1234567890123456789012345678901234"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); + } + } } diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 5c1e876ef9afc..94bd2c477a35b 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -25,12 +25,10 @@ # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public dns name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # Options read by executors and drivers running inside the cluster # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos @@ -48,7 +46,6 @@ # - SPARK_WORKER_CORES, to set the number of cores to use on this machine # - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g) # - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT, to use non-default ports for the worker -# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). diff --git a/core/pom.xml b/core/pom.xml index 97a463abbefdd..7f245b5b6384a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml @@ -33,6 +33,10 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro + org.apache.avro avro-mapred diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index fc1f3a80239ba..48cf4b9455e4d 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -60,8 +60,6 @@ protected long getUsed() { /** * Force spill during building. - * - * For testing. 
*/ public void spill() throws IOException { spill(Long.MAX_VALUE, this); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 39fb3b249d731..5f91411749167 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,11 +155,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.get(key); - if (list == null) { - list = new ArrayList<>(1); - sortedConsumers.put(key, list); - } + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a15559e55cbd..323a5d3c52831 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -52,8 +52,7 @@ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these * per-partition files to form a single output file, regions of which are served to reducers. - * Records are not buffered in memory. This is essentially identical to - * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * Records are not buffered in memory. It writes output in a format * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. *

 * This write path is inefficient for shuffles with large numbers of reduce partitions because it
@@ -61,7 +60,7 @@
 * {@link SortShuffleManager} only selects this write path when
 * <ul>
 *    <li>no Ordering is specified,</li>
- *    <li>no Aggregator is specific, and</li>
+ *    <li>no Aggregator is specified, and</li>
 *    <li>the number of partitions is less than
 *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
 * </ul>
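Taken together, the three conditions in the javadoc list above reduce to a small predicate over the shuffle dependency. Below is a minimal, self-contained Scala sketch of that rule; the ShuffleDep case class, the method name, and the threshold default shown as 200 are illustrative assumptions for this example, not SortShuffleManager's actual internals.

object BypassSelectionSketch {
  // Stand-in for the relevant bits of a shuffle dependency (illustrative only).
  final case class ShuffleDep(
      keyOrdering: Option[Ordering[Any]],
      aggregator: Option[Any],
      numPartitions: Int)

  // Mirrors the javadoc above: bypass only when there is no ordering, no aggregator,
  // and the partition count is below spark.shuffle.sort.bypassMergeThreshold.
  def shouldUseBypassPath(dep: ShuffleDep, bypassMergeThreshold: Int = 200): Boolean =
    dep.keyOrdering.isEmpty &&
      dep.aggregator.isEmpty &&
      dep.numPartitions < bypassMergeThreshold

  def main(args: Array[String]): Unit = {
    // A 10-partition shuffle with neither ordering nor aggregation takes the bypass path.
    println(shouldUseBypassPath(ShuffleDep(None, None, 10))) // true
  }
}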
diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index 9307eb93a5b20..dff4f5df68784 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -19,7 +19,9 @@ import org.apache.spark.util.EnumUtil; +import java.util.Collections; import java.util.HashSet; +import java.util.Locale; import java.util.Set; public enum TaskSorting { @@ -30,13 +32,11 @@ public enum TaskSorting { private final Set alternateNames; TaskSorting(String... names) { alternateNames = new HashSet<>(); - for (String n: names) { - alternateNames.add(n); - } + Collections.addAll(alternateNames, names); } public static TaskSorting fromString(String str) { - String lower = str.toLowerCase(); + String lower = str.toLowerCase(Locale.ROOT); for (TaskSorting t: values()) { if (t.alternateNames.contains(lower)) { return t; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f219c5605b643..c14c12664f5ab 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -23,7 +23,6 @@ import org.apache.avro.reflect.Nullable; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -291,8 +290,8 @@ public void loadNext() { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. - if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index b6323c624b7b9..9521ab86a12d5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,6 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -102,8 +101,8 @@ public void loadNext() throws IOException { // to avoid performance overhead. This check is added here in `loadNext()` instead of in // `hasNext()` because it's technically possible for the caller to be relying on // `getNumRecords()` instead of `hasNext()` to know when to stop. 
- if (taskContext != null && taskContext.isInterrupted()) { - throw new TaskKilledException(); + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); } recordLength = din.readInt(); keyPrefix = din.readLong(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 4e83d6d564986..5c91304e49fd7 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -24,7 +24,15 @@

Summary

RDD Blocks Storage Memory + title="Memory used / total available memory for storage of data like RDD partitions cached in memory.">Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores @@ -73,6 +81,14 @@

Executors

Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores Active Tasks diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 1aa03b025be64..8ddb1b7746b68 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -224,6 +224,10 @@ $(document).ready(function () { var allRDDBlocks = 0; var allMemoryUsed = 0; var allMaxMemory = 0; + var allOnHeapMemoryUsed = 0; + var allOnHeapMaxMemory = 0; + var allOffHeapMemoryUsed = 0; + var allOffHeapMaxMemory = 0; var allDiskUsed = 0; var allTotalCores = 0; var allMaxTasks = 0; @@ -242,6 +246,10 @@ $(document).ready(function () { var activeRDDBlocks = 0; var activeMemoryUsed = 0; var activeMaxMemory = 0; + var activeOnHeapMemoryUsed = 0; + var activeOnHeapMaxMemory = 0; + var activeOffHeapMemoryUsed = 0; + var activeOffHeapMaxMemory = 0; var activeDiskUsed = 0; var activeTotalCores = 0; var activeMaxTasks = 0; @@ -260,6 +268,10 @@ $(document).ready(function () { var deadRDDBlocks = 0; var deadMemoryUsed = 0; var deadMaxMemory = 0; + var deadOnHeapMemoryUsed = 0; + var deadOnHeapMaxMemory = 0; + var deadOffHeapMemoryUsed = 0; + var deadOffHeapMaxMemory = 0; var deadDiskUsed = 0; var deadTotalCores = 0; var deadMaxTasks = 0; @@ -274,11 +286,26 @@ $(document).ready(function () { var deadTotalShuffleWrite = 0; var deadTotalBlacklisted = 0; + response.forEach(function (exec) { + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? 
exec.memoryMetrics : memoryMetrics; + }); + response.forEach(function (exec) { allExecCnt += 1; allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -297,6 +324,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -315,6 +346,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -336,6 +371,10 @@ $(document).ready(function () { "allRDDBlocks": allRDDBlocks, "allMemoryUsed": allMemoryUsed, "allMaxMemory": allMaxMemory, + "allOnHeapMemoryUsed": allOnHeapMemoryUsed, + "allOnHeapMaxMemory": allOnHeapMaxMemory, + "allOffHeapMemoryUsed": allOffHeapMemoryUsed, + "allOffHeapMaxMemory": allOffHeapMaxMemory, "allDiskUsed": allDiskUsed, "allTotalCores": allTotalCores, "allMaxTasks": allMaxTasks, @@ -355,6 +394,10 @@ $(document).ready(function () { "allRDDBlocks": activeRDDBlocks, "allMemoryUsed": activeMemoryUsed, "allMaxMemory": activeMaxMemory, + "allOnHeapMemoryUsed": activeOnHeapMemoryUsed, + "allOnHeapMaxMemory": activeOnHeapMaxMemory, + "allOffHeapMemoryUsed": activeOffHeapMemoryUsed, + "allOffHeapMaxMemory": activeOffHeapMaxMemory, "allDiskUsed": activeDiskUsed, "allTotalCores": activeTotalCores, "allMaxTasks": activeMaxTasks, @@ -374,6 +417,10 @@ $(document).ready(function () { "allRDDBlocks": deadRDDBlocks, "allMemoryUsed": deadMemoryUsed, "allMaxMemory": deadMaxMemory, + "allOnHeapMemoryUsed": deadOnHeapMemoryUsed, + "allOnHeapMaxMemory": deadOnHeapMaxMemory, + "allOffHeapMemoryUsed": deadOffHeapMemoryUsed, + "allOffHeapMaxMemory": deadOffHeapMaxMemory, "allDiskUsed": deadDiskUsed, "allTotalCores": deadTotalCores, "allMaxTasks": deadMaxTasks, @@ -412,7 +459,35 @@ $(document).ready(function () { {data: 'rddBlocks'}, { data: function (row, type) { - return type === 'display' ? 
(formatBytes(row.memoryUsed, type) + ' / ' + formatBytes(row.maxMemory, type)) : row.memoryUsed; + if (type !== 'display') + return row.memoryUsed; + else + return (formatBytes(row.memoryUsed, type) + ' / ' + + formatBytes(row.maxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.memoryMetrics.usedOnHeapStorageMemory; + else + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.memoryMetrics.usedOffHeapStorageMemory; + else + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'diskUsed', render: formatBytes}, @@ -484,7 +559,35 @@ $(document).ready(function () { {data: 'allRDDBlocks'}, { data: function (row, type) { - return type === 'display' ? (formatBytes(row.allMemoryUsed, type) + ' / ' + formatBytes(row.allMaxMemory, type)) : row.allMemoryUsed; + if (type !== 'display') + return row.allMemoryUsed + else + return (formatBytes(row.allMemoryUsed, type) + ' / ' + + formatBytes(row.allMaxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOnHeapMemoryUsed; + else + return (formatBytes(row.allOnHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOnHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOffHeapMemoryUsed; + else + return (formatBytes(row.allOffHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOffHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'allDiskUsed', render: formatBytes}, diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 42e2d9abdeb5e..6ba3b092dc658 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -77,7 +77,7 @@ {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 54810edaf1460..1f89306403cd5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -120,6 +120,9 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? 
attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c09..b5c43e5788bc3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy// part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint. + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 319a719efaa79..935d9b1aec615 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -205,7 +205,8 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time, .peak_execution_memory { +.serialization_time, .getting_result_time, .peak_execution_memory, +.on_heap_memory, .off_heap_memory { display: none; } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index e4b9f8111efca..9112d93a86b2a 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -71,13 +71,12 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill every executor on the specified host. 
- * Results in a call to killExecutors for each executor on the host, with the replace - * and force arguments set to true. + * * @return whether the request is acknowledged by the cluster manager. */ def killExecutorsOnHost(host: String): Boolean - /** + /** * Request that the cluster manager kill the specified executor. * @return whether the request is acknowledged by the cluster manager. */ diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1366251d0618f..fcc72ff49276d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 @@ -439,7 +439,7 @@ private[spark] class ExecutorAllocationManager( executorsRemoved } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + - "executorIdsToBeRemoved.mkString(\",\") or no executor eligible to kill!") + s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") Seq.empty[String] } } diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 5c262bcbddf76..7f2c0068174b5 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.isInterrupted) { - throw new TaskKilledException - } else { - delegate.hasNext - } + context.killTaskIfInterrupted() + delegate.hasNext } def next(): T = delegate.next() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index fe912e639bcbc..2a2ce0504dbbf 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -518,71 +518,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - // Check for legacy configs - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - | - SPARK_DAEMON_JAVA_OPTS to set java options for standalone daemons (master or worker) - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorOptsKey, driverOptsKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. 
Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - sys.env.get("SPARK_CLASSPATH").foreach { value => - val warning = - s""" - |SPARK_CLASSPATH was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --driver-class-path to augment the driver classpath - | - spark.executor.extraClassPath to augment the executor classpath - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorClasspathKey, driverClassPathKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_CLASSPATH. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - if (!contains(sparkExecutorInstances)) { - sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => - val warning = - s""" - |SPARK_WORKER_INSTANCES was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --num-executors to specify the number of executors - | - Or set SPARK_EXECUTOR_INSTANCES - | - spark.executor.instances to configure the number of instances in the spark config. - """.stripMargin - logWarning(warning) - - set("spark.executor.instances", value) - } - } - if (contains("spark.master") && get("spark.master").startsWith("yarn-")) { val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " + "instead use \"yarn\" with specified deploy mode." diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 7aba4e5abeb5e..aa48f028fba1b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -366,7 +366,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def setLogLevel(logLevel: String) { // let's allow lowercase or mixed case too - val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + val upperCased = logLevel.toUpperCase(Locale.ROOT) require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), s"Supplied level $logLevel did not match one of:" + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") @@ -1355,7 +1355,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1384,7 +1384,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1419,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. 
*/ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** @@ -1739,6 +1739,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus @@ -1861,8 +1862,8 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Add a set of conda packages (identified by package match specification + * Add a set of conda packages (identified by package match specification * for all tasks to be executed on this SparkContext in the future. */ def addCondaPackages(packages: Seq[String]): Unit = { @@ -1965,6 +1966,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() @@ -2276,6 +2280,24 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelStage(stageId, None) } + /** + * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI + * or through SparkListener.onTaskStart. + * + * @param taskId the task ID to kill. This id uniquely identifies the task attempt. + * @param interruptThread whether to interrupt the thread running the task. + * @param reason the reason for killing the task, which should be a short string. If a task + * is killed multiple times with different reasons, only one reason will be reported. + * + * @return Whether the task was successfully killed. 
+ */ + def killTaskAttempt( + taskId: Long, + interruptThread: Boolean = true, + reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = { + dagScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index cb4f751bc8fb7..aeaad6941cd13 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket +import java.util.Locale import scala.collection.mutable import scala.util.Properties @@ -327,7 +328,8 @@ object SparkEnv extends Logging { "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index f0867ecb16ea3..0b87cd503d4fa 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable { /** * Adds a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable { /** * Adds a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable { } /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(listener: TaskFailureListener): TaskContext /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. 
+ * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { addTaskFailureListener(new TaskFailureListener { @@ -180,6 +184,16 @@ abstract class TaskContext extends Serializable { @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] + /** + * If the task is interrupted, throws TaskKilledException with the reason for the interrupt. + */ + private[spark] def killTaskIfInterrupted(): Unit + + /** + * If the task is interrupted, the reason this task was killed, otherwise None. + */ + private[spark] def getKillReason(): Option[String] + /** * Returns the manager for this task's managed memory. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index dc0d12878550a..8cd1d1c96aa0a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.Properties +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ +/** + * A [[TaskContext]] implementation. + * + * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes + * sure that updates are always visible across threads. The complete & failed flags and their + * callbacks are protected by locking on the context instance. For instance, this ensures + * that you cannot add a completion listener in one thread while we are completing (and calling + * the completion listeners) in another thread. Other state is immutable, however the exposed + * `TaskMetrics` & `MetricsSystem` objects are not thread safe. + */ private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, @@ -48,79 +59,108 @@ private[spark] class TaskContextImpl( /** List of callback functions to execute when the task fails. */ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false + // If defined, the corresponding task has been killed and this option contains the reason. + @volatile private var reasonIfKilled: Option[String] = None // Whether the task has completed. - @volatile private var completed: Boolean = false + private var completed: Boolean = false // Whether the task has failed. - @volatile private var failed: Boolean = false + private var failed: Boolean = false + + // Throwable that caused the task to fail + private var failure: Throwable = _ // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't // hide the exception. 
See SPARK-19276 @volatile private var _fetchFailedException: Option[FetchFailedException] = None - override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener + @GuardedBy("this") + override def addTaskCompletionListener(listener: TaskCompletionListener) + : this.type = synchronized { + if (completed) { + listener.onTaskCompletion(this) + } else { + onCompleteCallbacks += listener + } this } - override def addTaskFailureListener(listener: TaskFailureListener): this.type = { - onFailureCallbacks += listener + @GuardedBy("this") + override def addTaskFailureListener(listener: TaskFailureListener) + : this.type = synchronized { + if (failed) { + listener.onTaskFailure(this, failure) + } else { + onFailureCallbacks += listener + } this } /** Marks the task as failed and triggers the failure listeners. */ - private[spark] def markTaskFailed(error: Throwable): Unit = { - // failure callbacks should only be called once + @GuardedBy("this") + private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process failure callbacks in the reverse order of registration - onFailureCallbacks.reverse.foreach { listener => - try { - listener.onTaskFailure(this, error) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskFailureListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs, Option(error)) + failure = error + invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + _.onTaskFailure(this, error) } } /** Marks the task as completed and triggers the completion listeners. */ - private[spark] def markTaskCompleted(): Unit = { + @GuardedBy("this") + private[spark] def markTaskCompleted(): Unit = synchronized { + if (completed) return completed = true + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) { + _.onTaskCompletion(this) + } + } + + private def invokeListeners[T]( + listeners: Seq[T], + name: String, + error: Option[Throwable])( + callback: T => Unit): Unit = { val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => + // Process callbacks in the reverse order of registration + listeners.reverse.foreach { listener => try { - listener.onTaskCompletion(this) + callback(listener) } catch { case e: Throwable => errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) + logError(s"Error in $name", e) } } if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) + throw new TaskCompletionListenerException(errorMsgs, error) } } /** Marks the task for interruption, i.e. cancellation. 
*/ - private[spark] def markInterrupted(): Unit = { - interrupted = true + private[spark] def markInterrupted(reason: String): Unit = { + reasonIfKilled = Some(reason) + } + + private[spark] override def killTaskIfInterrupted(): Unit = { + val reason = reasonIfKilled + if (reason.isDefined) { + throw new TaskKilledException(reason.get) + } + } + + private[spark] override def getKillReason(): Option[String] = { + reasonIfKilled } - override def isCompleted(): Boolean = completed + @GuardedBy("this") + override def isCompleted(): Boolean = synchronized(completed) override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = interrupted + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8c1b5f7bf0d9b..a76283e33fa65 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case object TaskKilled extends TaskFailedReason { - override def toErrorString: String = "TaskKilled (killed intentionally)" +case class TaskKilled(reason: String) extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala index ad487c4efb87a..9dbf0d493be11 100644 --- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * Exception thrown when a task is explicitly killed (i.e., task failure is expected). */ @DeveloperApi -class TaskKilledException extends RuntimeException +class TaskKilledException(val reason: String) extends RuntimeException { + def this() = this("unknown reason") +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index df2064a310217..70a82df8afd17 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -696,8 +696,8 @@ class JavaSparkContext(val sc: SparkContext) } /** - * Add a set of conda packages (identified by package match specification + * Add a set of conda packages (identified by package match specification * for all tasks to be executed on this SparkContext in the future. 
*/ def addCondaPackages(packages: java.util.List[String]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 6f31d8d553b1a..2b5a2aa479aaa 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -223,7 +223,7 @@ private[spark] class PythonRunner( case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index a1a5eb8cf55e8..295355c7bf018 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.r +import java.io.File import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -127,7 +128,15 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.toString, value.toString) } - val jsc = new JavaSparkContext(sparkConf) + if (sparkEnvirMap.containsKey("spark.r.sql.derby.temp.dir") && + System.getProperty("derby.stream.error.file") == null) { + // This must be set before SparkContext is instantiated. + System.setProperty("derby.stream.error.file", + Seq(sparkEnvirMap.get("spark.r.sql.derby.temp.dir").toString, "derby.log") + .mkString(File.separator)) + } + + val jsc = new JavaSparkContext(SparkContext.getOrCreate(sparkConf)) jars.foreach { jar => jsc.addJar(jar) } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 29e21b3b1aa8a..88118392003e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -347,6 +347,8 @@ private[r] object RRunner { pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) + pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) + pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d01c47e645d..039df75ce74fd 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} +import org.apache.spark.storage._ import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -141,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. 
*/ - private def readBlocks(): Array[ChunkedByteBuffer] = { + private def readBlocks(): Array[BlockData] = { // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. - val blocks = new Array[ChunkedByteBuffer](numBlocks) + val blocks = new Array[BlockData](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -173,7 +173,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } - blocks(pid) = b + blocks(pid) = new ByteBufferBlockData(b, true) case None => throw new SparkException(s"Failed to get $pieceId of $broadcastId") } @@ -219,18 +219,22 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks, SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. 
+ val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } finally { + blocks.foreach(_.dispose()) } - obj } } } @@ -277,12 +281,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[InputStream], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 320af5cf97550..c6307da61c7eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -43,8 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} * Execute using * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS - * *and* SPARK_JAVA_OPTS: + * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port. 
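The unBlockifyObject change above (an Array[InputStream] concatenated through a SequenceInputStream) relies on a standard java.io trick: SequenceInputStream exposes several streams as one logical stream, so a single decompression/deserialization stream can read across block boundaries. Here is a small self-contained Scala sketch of the same pattern, using toy byte-array blocks in place of fetched broadcast pieces; the object and variable names are illustrative only.

import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
import scala.collection.JavaConverters._

object SequenceStreamSketch {
  def main(args: Array[String]): Unit = {
    // Two "blocks" standing in for fetched broadcast pieces.
    val blocks: Array[InputStream] = Array(
      new ByteArrayInputStream("hello ".getBytes("UTF-8")),
      new ByteArrayInputStream("world".getBytes("UTF-8")))

    // Chain them the same way unBlockifyObject does: iterator -> Java enumeration -> one stream.
    val combined = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
    try {
      println(scala.io.Source.fromInputStream(combined, "UTF-8").mkString) // prints "hello world"
    } finally {
      combined.close() // also closes any remaining underlying streams
    }
  }
}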
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f475ce87540aa..9cc321af4bde2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -349,10 +350,32 @@ class SparkHadoopUtil extends Logging { } } catch { case e: IOException => - logDebug("Failed to decode $token: $e", e) + logDebug(s"Failed to decode $token: $e", e) } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index fd28d2b5eea45..ea45e04f7de8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -496,12 +496,17 @@ object SparkSubmit extends CommandLineUtils { // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - if (deployMode == CLIENT) { + // Also add the main application jar and any added jars to classpath in case YARN client + // requires these jars. 
+ if (deployMode == CLIENT || isYarnCluster) { childMainClass = args.mainClass if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { if (args.childArgs != null) { childArgs ++= args.childArgs } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 7693a8764115e..3350987d17a83 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -193,6 +193,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull kubernetesNamespace = Option(kubernetesNamespace) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b55..6d8758a3d3b1d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9012736bc2745..f4235df245128 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. 
- !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5d..d9c8fda99ef97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 946a92882141c..a8d721f3e0d49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -83,7 +83,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} -
          <li><strong>Submit Date:</strong> {app.submitDate}</li>
+          <li><strong>Submit Date:</strong> {UIUtils.formatDate(app.submitDate)}</li>
           <li><strong>State:</strong> {app.state}</li>
  • { if (!app.isFinished) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index e722a24d4a89e..9351c72094e34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -252,7 +252,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => if (w.isAlive()) { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index e48817ebbafdd..00b9d1af373db 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -62,8 +62,8 @@ private[deploy] class Worker( private val forwordMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") - // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` - // methods. + // A separated thread to clean up the workDir and the directories of finished applications. + // Used to provide the implicit parameter of `Future` methods. private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) @@ -578,10 +578,15 @@ private[deploy] class Worker( if (shouldCleanup) { finishedApps -= id appDirectories.remove(id).foreach { dirList => - logInfo(s"Cleaning up local directories for application $id") - dirList.foreach { dir => - Utils.deleteRecursively(new File(dir)) - } + concurrent.Future { + logInfo(s"Cleaning up local directories for application $id") + dirList.foreach { dir => + Utils.deleteRecursively(new File(dir)) + } + }(cleanupThreadExecutor).onFailure { + case e: Throwable => + logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) + }(cleanupThreadExecutor) } shuffleService.applicationRemoved(id) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301eab..b2b26ee107c00 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -72,7 +73,7 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) } override def receive: PartialFunction[Any, Unit] = { @@ -97,11 +98,11 @@ private[spark] class CoarseGrainedExecutorBackend( executor.launchTask(this, taskDesc) } - case KillTask(taskId, _, interruptThread) => + case KillTask(taskId, _, interruptThread, reason) => if (executor == null) { exitExecutor(1, "Received KillTask command but executor was null") } else { - executor.killTask(taskId, interruptThread) + executor.killTask(taskId, interruptThread, reason) } case 
StopExecutor => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 790c1ae942474..51b6c373c4daf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -84,7 +86,20 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) // Pool used for threads that supervise task killing / cancellation private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") @@ -158,7 +173,7 @@ private[spark] class Executor( threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean): Unit = { + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { val taskRunner = runningTasks.get(taskId) if (taskRunner != null) { if (taskReaperEnabled) { @@ -168,7 +183,8 @@ private[spark] class Executor( case Some(existingReaper) => interruptThread && !existingReaper.interruptThread } if (shouldCreateReaper) { - val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) taskReaperForTask(taskId) = taskReaper Some(taskReaper) } else { @@ -178,7 +194,7 @@ private[spark] class Executor( // Execute the TaskReaper from outside of the synchronized block. maybeNewTaskReaper.foreach(taskReaperPool.execute) } else { - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) } } } @@ -189,8 +205,9 @@ private[spark] class Executor( * tasks instead of taking the JVM down. 
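
Aside: the Executor hunk above replaces the plain daemon cached pool with one whose worker threads are UninterruptibleThread, built through Guava's ThreadFactoryBuilder. Below is a minimal, standalone sketch of that construction only; the thread subclass and names are invented stand-ins for illustration, not Spark's UninterruptibleThread.

import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import com.google.common.util.concurrent.ThreadFactoryBuilder

object UninterruptiblePoolSketch {
  // Stand-in for org.apache.spark.util.UninterruptibleThread (illustration only).
  private class SketchWorkerThread(r: Runnable, name: String) extends Thread(r, name)

  def newWorkerPool(): ThreadPoolExecutor = {
    val threadFactory = new ThreadFactoryBuilder()
      .setDaemon(true)                        // idle workers must not keep the JVM alive
      .setNameFormat("task launch worker-%d") // numbered thread names, as in the patch
      .setThreadFactory(new ThreadFactory {
        // The inner factory decides the Thread subclass; the builder then overrides the
        // name and daemon flag, so the placeholder name below is never user-visible.
        override def newThread(r: Runnable): Thread = new SketchWorkerThread(r, "unused")
      })
      .build()
    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
  }
}
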
* @param interruptThread whether to interrupt the task thread */ - def killAllTasks(interruptThread: Boolean) : Unit = { - runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) } def stop(): Unit = { @@ -217,8 +234,8 @@ private[spark] class Executor( val threadName = s"Executor task launch worker for task $taskId" private val taskName = taskDescription.name - /** Whether this task has been killed. */ - @volatile private var killed = false + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None @volatile private var threadId: Long = -1 @@ -239,13 +256,13 @@ private[spark] class Executor( */ @volatile var task: Task[Any] = _ - def kill(interruptThread: Boolean): Unit = { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") + reasonIfKilled = Some(reason) if (task != null) { synchronized { if (!finished) { - task.kill(interruptThread) + task.kill(interruptThread, reason) } } } @@ -296,12 +313,13 @@ private[spark] class Executor( // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. - if (killed) { + val killReason = reasonIfKilled + if (killReason.isDefined) { // Throw an exception rather than returning, because returning within a try{} block // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw new TaskKilledException + throw new TaskKilledException(killReason.get) } logDebug("Task " + taskId + "'s epoch is " + task.epoch) @@ -358,9 +376,7 @@ private[spark] class Executor( } else 0L // If the task has been killed, let's fail it. 
- if (task.killed) { - throw new TaskKilledException - } + task.context.killTaskIfInterrupted() val resultSer = env.serializer.newInstance() val beforeSerialization = System.currentTimeMillis() @@ -426,15 +442,18 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId)") + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.killed => - logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskFailedReason @@ -512,7 +531,8 @@ private[spark] class Executor( */ private class TaskReaper( taskRunner: TaskRunner, - val interruptThread: Boolean) + val interruptThread: Boolean, + val reason: String) extends Runnable { private[this] val taskId: Long = taskRunner.taskId @@ -533,7 +553,7 @@ private[spark] class Executor( // Only attempt to kill the task once. If interruptThread = false then a second kill // attempt would be a no-op and if interruptThread = true then it may not be safe or // effective to interrupt multiple times: - taskRunner.kill(interruptThread = interruptThread) + taskRunner.kill(interruptThread = interruptThread, reason = reason) // Monitor the killed task until it exits. The synchronization logic here is complicated // because we don't want to synchronize on the taskRunner while possibly taking a thread // dump, but we also need to be careful to avoid races between checking whether the task diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acdac..a3ce3d1ccc5e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. 
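
Aside: the task-kill changes above thread a human-readable reason from kill() through to the TaskKilled status update. A simplified, self-contained sketch of that pattern follows; the class and method names are invented stand-ins, not Spark's TaskRunner or TaskContext.

// Reason-carrying kill exception, analogous in spirit to TaskKilledException(reason).
class SketchTaskKilledException(val reason: String) extends RuntimeException(reason)

class SketchRunner {
  // If non-None, the task has been asked to stop and this holds the reason.
  @volatile private var reasonIfKilled: Option[String] = None

  def kill(reason: String): Unit = {
    reasonIfKilled = Some(reason)
  }

  // Polled by the task body at safe points, similar in spirit to killTaskIfInterrupted().
  def throwIfKilled(): Unit = {
    reasonIfKilled.foreach(r => throw new SketchTaskKilledException(r))
  }
}
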
+ internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index a177e66645c7d..e5d60a7ef0984 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -18,6 +18,9 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit +import java.util.regex.PatternSyntaxException + +import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} @@ -65,6 +68,13 @@ private object ConfigHelpers { def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b" + def regexFromString(str: String, key: String): Regex = { + try str.r catch { + case e: PatternSyntaxException => + throw new IllegalArgumentException(s"$key should be a regex, but was $str", e) + } + } + } /** @@ -137,6 +147,14 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Creates a [[ConfigEntry]] with a function to determine the default value */ + def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_ (entry)) + entry + } + /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. 
@@ -214,4 +232,7 @@ private[spark] case class ConfigBuilder(key: String) { new FallbackConfigEntry(key, _doc, _public, fallback) } + def regexConf: TypedConfigBuilder[Regex] = { + new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index 4f3e42bb3c94e..e86712e84d6ac 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -78,7 +78,24 @@ private class ConfigEntryWithDefault[T] ( def readFrom(reader: ConfigReader): T = { reader.get(key).map(valueConverter).getOrElse(_defaultValue) } +} + +private class ConfigEntryWithDefaultFunction[T] ( + key: String, + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + + override def defaultValue: Option[T] = Some(_defaultFunction()) + override def defaultValueString: String = stringConverter(_defaultFunction()) + + def readFrom(reader: ConfigReader): T = { + reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + } } private class ConfigEntryWithDefaultString[T] ( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0928f569e4906..e1c9e7b36bbfb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -268,10 +268,18 @@ package object config { ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + "driver and executor environments contain sensitive information. When this regex matches " + - "a property, its value is redacted from the environment UI and various logs like YARN " + - "and event logs.") - .stringConf - .createWithDefault("(?i)secret|password") + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") + .regexConf + .createWithDefault("(?i)secret|password".r) + + private[spark] val STRING_REDACTION_PATTERN = + ConfigBuilder("spark.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands.") + .regexConf + .createOptional private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") @@ -288,4 +296,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. 
Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 2394cf361c33a..7efa9416362a0 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -121,6 +121,13 @@ abstract class FileCommitProtocol { def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = { fs.delete(path, recursive) } + + /** + * Called on the driver after a task commits. This can be used to access task commit messages + * before the job has finished. These same task commit messages will be passed to commitJob() + * if the entire job succeeds. + */ + def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index 659ad5d0bad8c..376ff9bb19f74 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -179,62 +179,3 @@ object SparkHadoopMapReduceWriter extends Logging { } } } - -private[spark] -object SparkHadoopWriterUtils { - - private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - def createJobID(time: Date, id: Int): JobID = { - val jobtrackerID = createJobTrackerID(time) - new JobID(jobtrackerID, id) - } - - def createJobTrackerID(time: Date): String = { - new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { - val validationDisabled = disableOutputSpecValidation.value - val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } - - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - (context.taskMetrics().outputMetrics, bytesWrittenCallback) - } - - def maybeUpdateOutputMetrics( - outputMetrics: OutputMetrics, - callback: () => Long, - recordsWritten: Long): Unit = { - if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetrics.setBytesWritten(callback()) - outputMetrics.setRecordsWritten(recordsWritten) - } - } - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. 
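
Aside: the spark.checkpoint.compress entry introduced above is a user-facing flag, off by default; when enabled, reliable RDD checkpoint data goes through the codec selected by spark.io.compression.codec. A minimal usage sketch, with placeholder master, app name, and checkpoint directory:

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointCompressSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")                       // placeholder for a local run
      .setAppName("checkpoint-compress-sketch")
      .set("spark.checkpoint.compress", "true")    // default is false, per createWithDefault(false)
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoints")        // placeholder directory
    val rdd = sc.parallelize(1 to 1000000).map(_ * 2)
    rdd.checkpoint()                               // checkpoint files are written compressed
    rdd.count()                                    // action that materializes the checkpoint
    sc.stop()
  }
}
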
- */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) -} diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala rename to core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 46e22b215b8ee..acc9c38571007 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -15,19 +15,18 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.internal.io import java.io.IOException -import java.text.NumberFormat -import java.text.SimpleDateFormat +import java.text.{NumberFormat, SimpleDateFormat} import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType +import org.apache.spark.SerializableWritable import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.SparkHadoopWriterUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD import org.apache.spark.util.SerializableJobConf diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala new file mode 100644 index 0000000000000..de828a6d6156e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.util.DynamicVariable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics + +/** + * A helper object that provide common utils used during saving an RDD using a Hadoop OutputFormat + * (both from the old mapred API and the new mapreduce API) + */ +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + (context.taskMetrics().outputMetrics, bytesWrittenCallback) + } + + def maybeUpdateOutputMetrics( + outputMetrics: OutputMetrics, + callback: () => Long, + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. 
+ */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 2e991ce394c42..0cb16f0627b72 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io._ +import java.util.Locale import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.LZ4BlockOutputStream @@ -66,13 +67,13 @@ private[spark] object CompressionCodec { } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val codecClass = + shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { - case e: ClassNotFoundException => None - case e: IllegalArgumentException => None + case _: ClassNotFoundException | _: IllegalArgumentException => None } codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " + s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 3fd812e9fcfe8..4216b2627309e 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -39,7 +39,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb8..fce556fd0382c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} @@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR } val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328ad..88bba2fdbd1c6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis } val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => 
TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 22454e50b14b4..23e31823f4930 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink import java.net.InetSocketAddress -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric } val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } @@ -67,7 +67,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 773e074336cb0..7fa4ba7622980 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} @@ -42,7 +42,7 @@ private[spark] class Slf4jSink( } val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df7023..305fd9a6de10d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,11 +56,12 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val blocksNum = openBlocks.blockIds.length + val blocks = for (i <- (0 until blocksNum).view) + yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + logTrace(s"Registered streamId $streamId with $blocksNum buffers") + 
responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index df520f804b4c3..25f7bcb9801b9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -60,7 +60,7 @@ object SparkTransportConf { new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) - + override def get(name: String, defaultValue: String): String = conf.get(name, defaultValue) override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = { conf.getAll.toMap.asJava.entrySet() } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index d47b75544fdba..4e036c2ed49b5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -47,7 +47,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo blockManager.get[T](blockId) match { case Some(block) => block.data.asInstanceOf[Iterator[T]] case None => - throw new Exception("Could not compute split, block " + blockId + " not found") + throw new Exception(s"Could not compute split, block $blockId of RDD $id not found") } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 52ce03ff8cde9..58762cc0838cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -37,7 +37,8 @@ import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, + SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e9092739b298a..9f8019b80a4dd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -116,7 +116,7 @@ private object ParallelCollectionRDD { */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") + throw new IllegalArgumentException("Positive number of partitions required") } // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1b..63a87e7f09d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import 
org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index e0a29b48314fb..37c67cee55f90 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } else { + fileStream + } + } val serializer = 
env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428e..ab72addb2466b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 0ba95169529e6..97eed540b8f59 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory { * * The life-cycle of an endpoint is: * - * constructor -> onStart -> receive* -> onStop + * {@code constructor -> onStart -> receive* -> onStop} * * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] @@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a + * unmatched message, `SparkException` will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. */ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2c9a976e76939..0557b7a3cc0b7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{ThreadUtils, Utils} /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + * An exception thrown if RpcTimeout modifies a `TimeoutException`. 
*/ private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) extends TimeoutException(message) { initCause(cause) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index ff5e39a8dcbc8..b316e5443f639 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -236,7 +236,8 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { - onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " + + s"in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 692ed8083475c..aab177f257a8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -187,6 +187,13 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") @@ -600,7 +607,7 @@ class DAGScheduler( * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name * - * @throws Exception when the job fails + * @note Throws `Exception` when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -637,7 +644,7 @@ class DAGScheduler( * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD - * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param evaluator `ApproximateEvaluator` to receive the partial results * @param callSite where in the user program this job was called * @param timeout maximum time to wait for the job, in milliseconds * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name @@ -731,6 +738,15 @@ class DAGScheduler( eventProcessLoop.post(StageCancelled(stageId, reason)) } + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) + } + /** * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since * the last fetch failure. 
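
Aside: the DAGScheduler hunks above replace the hard-coded Stage.MAX_CONSECUTIVE_FETCH_FAILURES with a configurable limit and add a killTaskAttempt() hook. The config key and its default of 4 come from the patch; everything else in this sketch is a placeholder.

import org.apache.spark.{SparkConf, SparkContext}

object StageRetrySketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")                          // placeholder for a local run
      .setAppName("stage-retry-sketch")
      .set("spark.stage.maxConsecutiveAttempts", "8") // patch default is 4
    val sc = new SparkContext(conf)

    // Assuming the SparkContext-level wrapper for DAGScheduler.killTaskAttempt added by this
    // change set; the task id is a placeholder you would take from the UI or listener events.
    // sc.killTaskAttempt(42L, interruptThread = true, reason = "killed for illustration")

    sc.stop()
  }
}
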
@@ -1282,8 +1298,9 @@ class DAGScheduler( s"longer running") } + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = - failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId) || + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || disallowStageRetryForTest if (shouldAbortStage) { @@ -1292,7 +1309,7 @@ class DAGScheduler( } else { s"""$failedStage (${failedStage.name}) |has failed the maximum allowable number of - |times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. + |times: $maxConsecutiveStageAttempts. |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ") } abortStage(failedStage, abortMessage, None) @@ -1345,7 +1362,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | TaskKilled | UnknownReason => + case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } @@ -1726,4 +1743,7 @@ private[spark] object DAGScheduler { // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index af9bdefc967ef..a7dbf87915b27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.Locale import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -251,11 +252,17 @@ private[spark] class EventLoggingListener( private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { - // "Spark Properties" entry will always exist because the map is always populated with it. - val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties")) - val redactedEnvironmentDetails = event.environmentDetails + - ("Spark Properties" -> redactedProps) - SparkListenerEnvironmentUpdate(redactedEnvironmentDetails) + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. 
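
Aside: the EventLoggingListener change here redacts every group in environmentDetails rather than only "Spark Properties", and matching now applies to both keys and values (see the spark.redaction.regex doc change earlier). A simplified sketch of that per-group mapping; redactPair and the "[redacted]" placeholder are invented for illustration and are not Spark's Utils.redact.

import scala.util.matching.Regex

object RedactionSketch {
  // Invented stand-in for the redaction helper; "[redacted]" is a placeholder value.
  private def redactPair(regex: Regex, kv: (String, String)): (String, String) = kv match {
    case (key, value) if regex.findFirstIn(key).isDefined || regex.findFirstIn(value).isDefined =>
      (key, "[redacted]")
    case other => other
  }

  // Same shape as the patch: keep the group names, redact sensitive entries in every group.
  def redactAll(
      regex: Regex,
      environmentDetails: Map[String, Seq[(String, String)]]): Map[String, Seq[(String, String)]] = {
    environmentDetails.map { case (group, props) => group -> props.map(redactPair(regex, _)) }
  }

  def main(args: Array[String]): Unit = {
    val redacted = redactAll(
      "(?i)secret|password".r,
      Map("Spark Properties" -> Seq("spark.ui.port" -> "4040", "my.password" -> "hunter2")))
    println(redacted) // the value of my.password is replaced by the placeholder
  }
}
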
+ val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) } } @@ -316,7 +323,7 @@ private[spark] object EventLoggingListener extends Logging { } private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala index d1ac7131baba5..47f3527a32c01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -42,7 +42,7 @@ private[spark] trait ExternalClusterManager { /** * Create a scheduler backend for the given SparkContext and scheduler. This is - * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. + * called after task scheduler is created using `ExternalClusterManager.createTaskScheduler()`. * @param sc SparkContext * @param masterURL the master URL * @param scheduler TaskScheduler that will be used with the scheduler backend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 2a69a6c5e8790..1181371ab425a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -37,24 +37,24 @@ private[spark] class Pool( val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] - var weight = initWeight - var minShare = initMinShare + val weight = initWeight + val minShare = initMinShare var runningTasks = 0 - var priority = 0 + val priority = 0 // A pool's stage id is used to break the tie in scheduling. var stageId = -1 - var name = poolName + val name = poolName var parent: Pool = null - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = { schedulingMode match { case SchedulingMode.FAIR => new FairSchedulingAlgorithm() case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() case _ => - val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." + val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." 
throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index e53c4fb5b4778..5f3c280ec31ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} -import java.util.{NoSuchElementException, Properties} +import java.util.{Locale, NoSuchElementException, Properties} import scala.util.control.NonFatal import scala.xml.{Node, XML} @@ -142,7 +142,8 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) defaultValue: SchedulingMode, fileName: String): SchedulingMode = { - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase + val xmlSchedulingMode = + (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT) val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " + s"Fair Scheduler configuration file: $fileName, using " + s"the default schedulingMode: $defaultValue for pool: $poolName" @@ -180,20 +181,23 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) } override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + val poolName = if (properties != null) { + properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + } else { + DEFAULT_POOL_NAME } + var parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + // we will create a new pool that user has configured in app + // instead of being defined in xml file + parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. 
Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } parentPool.addSchedulable(manager) logInfo("Added task set " + manager.name + " tasks to pool " + poolName) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 8801a761afae3..22db3350abfa7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + /** + * Requests that an executor kills a running task. + * + * @param taskId Id of the task. + * @param executorId Id of the executor the task is running on. + * @param interruptThread Whether the executor should interrupt the task thread. + * @param reason The reason for the task kill. + */ + def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 4331addb44172..bc2e530716686 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -87,8 +87,13 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) - extends SparkListenerEvent +case class SparkListenerBlockManagerAdded( + time: Long, + blockManagerId: BlockManagerId, + maxMem: Long, + maxOnHeapMem: Option[Long] = None, + maxOffHeapMem: Option[Long] = None) extends SparkListenerEvent { +} @DeveloperApi case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 32e5df6d75f4f..290fd073caf27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -87,23 +87,12 @@ private[scheduler] abstract class Stage( * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - private val fetchFailedAttemptIds = new HashSet[Int] + val fetchFailedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { fetchFailedAttemptIds.clear() } - /** - * Check whether we should abort the failedStage due to multiple consecutive fetch failures. - * - * This method updates the running set of failed stage attempts and returns - * true if the number of failures exceeds the allowable number of failures. - */ - private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { - fetchFailedAttemptIds.add(stageAttemptId) - fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES - } - /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. 
*/ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -128,8 +117,3 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ def findMissingPartitions(): Seq[Int] } - -private[scheduler] object Stage { - // The number of consecutive failures allowed before a stage is aborted - val MAX_CONSECUTIVE_FETCH_FAILURES = 4 -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 70213722aae4f..5c337b992c840 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -89,8 +89,8 @@ private[spark] abstract class Task[T]( TaskContext.setTaskContext(context) taskThread = Thread.currentThread() - if (_killed) { - kill(interruptThread = false) + if (_reasonIfKilled != null) { + kill(interruptThread = false, _reasonIfKilled) } new CallerContext( @@ -149,7 +149,7 @@ private[spark] abstract class Task[T]( def preferredLocations: Seq[TaskLocation] = Nil - // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskSetManager. var epoch: Long = -1 // Task context, to be initialized in run(). @@ -158,17 +158,17 @@ private[spark] abstract class Task[T]( // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ - // A flag to indicate whether the task is killed. This is used in case context is not yet - // initialized when kill() is invoked. - @volatile @transient private var _killed = false + // If non-null, this task has been killed and the reason is as specified. This is used in case + // context is not yet initialized when kill() is invoked. + @volatile @transient private var _reasonIfKilled: String = null protected var _executorDeserializeTime: Long = 0 protected var _executorDeserializeCpuTime: Long = 0 /** - * Whether the task has been killed. + * If defined, this task has been killed and this option contains the reason. */ - def killed: Boolean = _killed + def reasonIfKilled: Option[String] = Option(_reasonIfKilled) /** * Returns the amount of time spent deserializing the RDD and function to be run. @@ -182,14 +182,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. - // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. - } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } @@ -201,10 +198,11 @@ private[spark] abstract class Task[T]( * be called multiple times. * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. 
*/ - def kill(interruptThread: Boolean) { - _killed = true + def kill(interruptThread: Boolean, reason: String) { + require(reason != null) + _reasonIfKilled = reason if (context != null) { - context.markInterrupted() + context.markInterrupted(reason) } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 78aa5c40010cc..c98b87148e404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.Properties import scala.collection.JavaConverters._ @@ -86,7 +87,10 @@ private[spark] object TaskDescription { dataOut.writeInt(taskDescription.properties.size()) taskDescription.properties.asScala.foreach { case (key, value) => dataOut.writeUTF(key) - dataOut.writeUTF(value) + // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values + val bytes = value.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) } // Write the task. The task is already serialized, so write it directly to the byte buffer. @@ -124,7 +128,11 @@ private[spark] object TaskDescription { val properties = new Properties() val numProperties = dataIn.readInt() for (i <- 0 until numProperties) { - properties.setProperty(dataIn.readUTF(), dataIn.readUTF()) + val key = dataIn.readUTF() + val valueLength = dataIn.readInt() + val valueBytes = new Array[Byte](valueLength) + dataIn.readFully(valueBytes) + properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) } // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 59680139e7af3..9843eab4f1346 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -70,11 +70,13 @@ class TaskInfo( var killed = false - private[spark] def markGettingResult(time: Long = System.currentTimeMillis) { + private[spark] def markGettingResult(time: Long) { gettingResultTime = time } - private[spark] def markFinished(state: TaskState, time: Long = System.currentTimeMillis) { + private[spark] def markFinished(state: TaskState, time: Long) { + // finishTime should be set larger than 0, otherwise "finished" below will return false. + assert(time > 0) finishTime = time if (state == TaskState.FAILED) { failed = true diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cd13eebe74a99..3de7d1f7de22b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,6 +54,13 @@ private[spark] trait TaskScheduler { // Cancel a stage. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit + /** + * Kills a task attempt. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Set the DAG scheduler for upcalls. 
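// Self-contained sketch of the length-prefixed encoding that TaskDescription now uses for
// property values: DataOutputStream.writeUTF is capped at 65535 encoded bytes, so long values
// are written as an explicit byte count followed by raw UTF-8 bytes (SPARK-19796). The object
// name here is illustrative.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object LongStringCodec {
  def write(out: DataOutputStream, value: String): Unit = {
    val bytes = value.getBytes(StandardCharsets.UTF_8)
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  def read(in: DataInputStream): String = {
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  def main(args: Array[String]): Unit = {
    val huge = "x" * 100000 // > 64 KB; writeUTF would throw UTFDataFormatException here
    val buf = new ByteArrayOutputStream()
    write(new DataOutputStream(buf), huge)
    val back = read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray)))
    assert(back == huge)
  }
}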
This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bfbcfa1aa386f..1b6bc9139f9c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * It can also work with a local setup by using a `LocalSchedulerBackend` and setting * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking * up to launch speculative tasks, etc. * @@ -56,8 +56,9 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val maxTaskFailures: Int, private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) - extends TaskScheduler with Logging -{ + extends TaskScheduler with Logging { + + import TaskSchedulerImpl._ def this(sc: SparkContext) = { this( @@ -130,16 +131,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val mapOutputTracker = SparkEnv.get.mapOutputTracker - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null + private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO - private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO") - val schedulingMode: SchedulingMode = try { - SchedulingMode.withName(schedulingModeConf.toUpperCase) - } catch { - case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf") - } + private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, SchedulingMode.FIFO.toString) + val schedulingMode: SchedulingMode = + try { + SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT)) + } catch { + case e: java.util.NoSuchElementException => + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") + } + + val rootPool: Pool = new Pool("", schedulingMode, 0, 0) // This is a var so that we can reset it for testing purposes. 
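// A minimal sketch of why the block above parses spark.scheduler.mode with Locale.ROOT:
// toUpperCase without a locale uses the JVM default, and under e.g. a Turkish locale "fifo"
// upper-cases to "FİFO" (dotted capital I), which no longer matches the FIFO enum value.
// The demo enum below is a stand-in, not Spark's SchedulingMode.
import java.util.Locale

object SchedulingModeDemo extends Enumeration {
  val FIFO, FAIR = Value

  def parse(conf: String): Value =
    try {
      withName(conf.toUpperCase(Locale.ROOT))
    } catch {
      case _: java.util.NoSuchElementException =>
        throw new IllegalArgumentException(s"Unrecognized spark.scheduler.mode: $conf")
    }

  def main(args: Array[String]): Unit = {
    assert(parse("fair") == FAIR) // locale-independent, unlike plain toUpperCase
  }
}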
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -150,8 +153,6 @@ private[spark] class TaskSchedulerImpl private[scheduler]( def initialize(backend: SchedulerBackend) { this.backend = backend - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { schedulingMode match { case SchedulingMode.FIFO => @@ -159,7 +160,8 @@ private[spark] class TaskSchedulerImpl private[scheduler]( case SchedulingMode.FAIR => new FairSchedulableBuilder(rootPool, conf) case _ => - throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode") + throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " + + s"$schedulingMode") } } schedulableBuilder.buildPools() @@ -172,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - speculationScheduler.scheduleAtFixedRate(new Runnable { + speculationScheduler.scheduleWithFixedDelay(new Runnable { override def run(): Unit = Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } @@ -239,7 +241,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( // simply abort the stage. tsm.runningTasksSet.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -247,6 +249,18 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + logInfo(s"Killing task $taskId: $reason") + val execId = taskIdToExecutorId.get(taskId) + if (execId.isDefined) { + backend.killTask(taskId, execId.get, interruptThread, reason) + true + } else { + logWarning(s"Could not kill task $taskId because no task with that ID was found.") + false + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -467,7 +481,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( taskState: TaskState, reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() @@ -683,16 +697,19 @@ private[spark] class TaskSchedulerImpl private[scheduler]( private[spark] object TaskSchedulerImpl { + + val SCHEDULER_MODE_PROPERTY = "spark.scheduler.mode" + /** * Used to balance containers across hosts. * * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource + * resource offers representing the order in which the offers should be used. The resource * offers are ordered such that we'll allocate one container on each host before allocating a * second container on any host, and so on, in order to reduce the damage if a host fails. 
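// Small sketch of the scheduling change above: scheduleWithFixedDelay measures the interval
// from the end of the previous run, so a slow checkSpeculatableTasks() pass cannot queue up
// back-to-back executions the way scheduleAtFixedRate can. The interval and sleep values
// below are illustrative only.
import java.util.concurrent.{Executors, TimeUnit}

object SpeculationLoopSketch {
  def main(args: Array[String]): Unit = {
    val scheduler = Executors.newSingleThreadScheduledExecutor()
    val intervalMs = 100L
    scheduler.scheduleWithFixedDelay(new Runnable {
      override def run(): Unit = {
        Thread.sleep(250) // stand-in for a check that overruns the interval
        println(s"speculation check finished at ${System.nanoTime()}")
      }
    }, intervalMs, intervalMs, TimeUnit.MILLISECONDS)
    Thread.sleep(2000)
    scheduler.shutdownNow()
  }
}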
* - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] + * For example, given {@literal }, {@literal } and + * {@literal }, returns {@literal [o1, o5, o4, o2, o6, o3]}. */ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { val _keyList = new ArrayBuffer[K](map.size) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 19ebaf817e24e..a41b059fa7dec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,11 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.math.{max, min} +import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ @@ -31,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.collection.MedianHeap /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -63,6 +63,8 @@ private[spark] class TaskSetManager( // Limit of bytes for total size of results (default is 1GB) val maxResultSize = Utils.getMaxResultSize(conf) + val speculationEnabled = conf.getBoolean("spark.speculation", false) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -78,16 +80,16 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 + private[scheduler] var tasksSuccessful = 0 - var weight = 1 - var minShare = 0 + val weight = 1 + val minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId val name = "TaskSet_" + taskSet.id var parent: Pool = null - var totalResultSize = 0L - var calculatedTasks = 0 + private var totalResultSize = 0L + private var calculatedTasks = 0 private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => @@ -95,17 +97,21 @@ private[spark] class TaskSetManager( } } - val runningTasksSet = new HashSet[Long] + private[scheduler] val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size + def someAttemptSucceeded(tid: Long): Boolean = { + successful(taskInfos(tid).index) + } + // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie // state in order to continue to track and account for the running tasks. // TODO: We should kill any running task attempts when the task set manager becomes a zombie. - var isZombie = false + private[scheduler] var isZombie = false // Set of pending tasks for each executor. 
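// Standalone sketch of the round-robin contract that the prioritizeContainers scaladoc above
// describes: take one offer per host before taking a second from any host, so losing a single
// host costs as few allocated containers as possible. This mirrors the documented behaviour
// only; it is not Spark's implementation, and plain Map iteration order is unspecified.
object RoundRobinOffers {
  def interleave[K, T](offersByHost: Map[K, Seq[T]]): List[T] = {
    val maxLen = if (offersByHost.isEmpty) 0 else offersByHost.values.map(_.length).max
    (0 until maxLen).flatMap { i =>
      offersByHost.values.collect { case offers if i < offers.length => offers(i) }
    }.toList
  }

  def main(args: Array[String]): Unit = {
    val offers = Map("h1" -> Seq("o1", "o2", "o3"), "h2" -> Seq("o4"), "h3" -> Seq("o5", "o6"))
    // One offer from each host comes before any host's second offer, e.g. o1, o4, o5, o2, o6, o3.
    println(interleave(offers))
  }
}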
These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -129,17 +135,22 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - var pendingTasksWithNoPrefs = new ArrayBuffer[Int] + private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] + private val allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] + private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] + private val taskInfos = new HashMap[Long, TaskInfo] + + // Use a MedianHeap to record durations of successful tasks so we know when to launch + // speculative tasks. This is only used when speculation is enabled, to avoid the overhead + // of inserting into the heap when the heap won't be used. + val successfulTaskDurations = new MedianHeap() // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = @@ -148,7 +159,7 @@ private[spark] class TaskSetManager( // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception // was printed. This should ideally be an LRU map that can drop old exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() + private val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks val epoch = sched.mapOutputTracker.getEpoch @@ -169,20 +180,22 @@ private[spark] class TaskSetManager( * This allows a performance optimization, of skipping levels that aren't relevant (eg., skip * PROCESS_LOCAL if no tasks could be run PROCESS_LOCAL for the current set of executors). */ - var myLocalityLevels = computeValidLocalityLevels() - var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + private[scheduler] var myLocalityLevels = computeValidLocalityLevels() + + // Time to wait at each level + private[scheduler] var localityWaits = myLocalityLevels.map(getLocalityWait) // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level + private var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + private var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - var emittedTaskSizeWarning = false + private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. 
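// The MedianHeap introduced above lets the speculation check read the median successful task
// duration in O(1) instead of copying and sorting all task infos on every pass. Below is a
// minimal standalone two-heap median tracker in the same spirit; it is an illustration, not
// org.apache.spark.util.collection.MedianHeap itself.
import scala.collection.mutable

class SimpleMedianTracker {
  private val lower = mutable.PriorityQueue.empty[Double] // max-heap: smaller half
  private val upper = mutable.PriorityQueue.empty[Double](Ordering[Double].reverse) // min-heap: larger half

  def insert(x: Double): Unit = {
    if (lower.isEmpty || x <= lower.head) lower.enqueue(x) else upper.enqueue(x)
    // Rebalance so the two halves differ in size by at most one.
    if (lower.size > upper.size + 1) upper.enqueue(lower.dequeue())
    else if (upper.size > lower.size) lower.enqueue(upper.dequeue())
  }

  def median: Double = {
    require(lower.nonEmpty, "median of an empty tracker")
    if (lower.size == upper.size) (lower.head + upper.head) / 2.0 else lower.head
  }
}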
*/ private def addPendingTask(index: Int) { @@ -667,7 +680,7 @@ private[spark] class TaskSetManager( */ def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) - info.markGettingResult() + info.markGettingResult(clock.getTimeMillis()) sched.dagScheduler.taskGettingResult(info) } @@ -695,22 +708,23 @@ private[spark] class TaskSetManager( def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index - info.markFinished(TaskState.FINISHED) + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) + if (speculationEnabled) { + successfulTaskDurations.insert(info.duration) + } removeRunningTask(tid) - // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the - // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not - // "deserialize" the value when holding a lock to avoid blocking other threads. So we call - // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. - // Note: "result.value()" only deserializes the value when it's called at the first time, so - // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) + // Kill any other attempts for the same task (since those are unnecessary now that one // attempt completed successfully). for (attemptInfo <- taskAttempts(index) if attemptInfo.running) { logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true) + sched.backend.killTask( + attemptInfo.taskId, + attemptInfo.executorId, + interruptThread = true, + reason = "another attempt succeeded") } if (!successful(index)) { tasksSuccessful += 1 @@ -726,6 +740,13 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. 
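// Sketch of the speculation test that the median heap above feeds: once enough tasks have
// finished (SPECULATION_QUANTILE of the set), any task still running longer than a multiple
// of the median successful duration becomes a speculation candidate. The constants mirror
// TaskSetManager's defaults, but the function itself is only an illustration.
object SpeculationCheckSketch {
  val SPECULATION_QUANTILE = 0.75
  val SPECULATION_MULTIPLIER = 1.5
  val minTimeToSpeculation = 100L // ms

  def shouldSpeculate(
      numTasks: Int,
      tasksSuccessful: Int,
      medianDurationMs: Double,
      runtimeMs: Long): Boolean = {
    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
    val threshold = math.max(SPECULATION_MULTIPLIER * medianDurationMs, minTimeToSpeculation)
    tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0 && runtimeMs > threshold
  }

  def main(args: Array[String]): Unit = {
    // 8 of 10 tasks done with a 200 ms median: a task running for 400 ms is a candidate.
    assert(shouldSpeculate(numTasks = 10, tasksSuccessful = 8, medianDurationMs = 200, runtimeMs = 400))
  }
}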
+ sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) maybeFinishTaskSet() } @@ -739,7 +760,7 @@ private[spark] class TaskSetManager( return } removeRunningTask(tid) - info.markFinished(state) + info.markFinished(state, clock.getTimeMillis()) val index = info.index copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty @@ -917,11 +938,10 @@ private[spark] class TaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) + var medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 2898cd7d17ca0..6b49bd699a13a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -40,7 +40,7 @@ private[spark] object CoarseGrainedClusterMessages { // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String) extends CoarseGrainedClusterMessage case class KillExecutorsOnHost(host: String) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 94abe30bb12f2..dc82bb7704727 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. 
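// The reason string threaded through the KillTask message above eventually surfaces to users.
// A hedged usage sketch, assuming the SparkContext.killTaskAttempt wrapper that accompanies
// this API in the same release; the task id below is a placeholder (real ids come from task
// events or the web UI), so this is a shape-of-the-call example rather than a working recipe.
import org.apache.spark.{SparkConf, SparkContext}

object KillTaskAttemptSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("kill-sketch"))
    try {
      val taskId = 0L // placeholder
      val accepted = sc.killTaskAttempt(taskId, interruptThread = true,
        reason = "straggler exceeded our own latency budget")
      println(s"kill request accepted: $accepted")
    } finally {
      sc.stop()
    }
  }
}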
private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -132,10 +136,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case ReviveOffers => makeOffers() - case KillTask(taskId, executorId, interruptThread) => + case KillTask(taskId, executorId, interruptThread, reason) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send( + KillTask(taskId, executorId, interruptThread, reason)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") @@ -222,12 +227,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { - // Filter out executors under killing - val activeExecutors = executorDataMap.filterKeys(executorIsAlive) - val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toIndexedSeq - launchTasks(scheduler.resourceOffers(workOffers)) + // Make sure no executor is killed while some task is launching on it + val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(executorIsAlive) + val workOffers = activeExecutors.map { case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + }.toIndexedSeq + scheduler.resourceOffers(workOffers) + } + if (!taskDescs.isEmpty) { + launchTasks(taskDescs) + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -240,12 +251,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - // Filter out executors under killing - if (executorIsAlive(executorId)) { - val executorData = executorDataMap(executorId) - val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) - launchTasks(scheduler.resourceOffers(workOffers)) + // Make sure no executor is killed while some task is launching on it + val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { + // Filter out executors under killing + if (executorIsAlive(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = IndexedSeq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + scheduler.resourceOffers(workOffers) + } else { + Seq.empty + } + } + if (!taskDescs.isEmpty) { + launchTasks(taskDescs) } } @@ -398,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -414,8 +434,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp 
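// Standalone sketch of the locking pattern that the revised makeOffers() uses: decide which
// tasks go where while holding the backend lock (so an executor cannot be released halfway
// through the decision), then send the launch messages outside the lock to keep the critical
// section short. The types here are simplified stand-ins, not Spark's.
import scala.collection.mutable

class OfferLoopSketch {
  private val lock = new Object
  private val aliveExecutors = mutable.Map[String, Int]() // executorId -> free cores

  def registerExecutor(id: String, cores: Int): Unit = lock.synchronized {
    aliveExecutors(id) = cores
  }

  def removeExecutor(id: String): Unit = lock.synchronized {
    aliveExecutors -= id
  }

  // Pretend scheduler: one "task" per free core.
  private def schedule(offers: Seq[(String, Int)]): Seq[(String, String)] =
    offers.flatMap { case (id, cores) => (1 to cores).map(i => (id, s"$id-task-$i")) }

  def makeOffers(): Unit = {
    val launches = lock.synchronized { // decide while the executor set is stable
      schedule(aliveExecutors.toSeq)
    }
    launches.foreach { case (execId, task) => // act outside the lock
      println(s"launching $task on $execId")
    }
  }
}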
driverEndpoint.send(ReviveOffers) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason)) } override def defaultParallelism(): Int = { @@ -471,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -508,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -573,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. 
val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7befdb0c1f64d..0529fe9eed4da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,14 +126,14 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. 
Reason: " + reason) try { @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd4608..35509bc2f85b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -34,7 +34,7 @@ private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) -private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String) private case class StopExecutor() @@ -70,8 +70,8 @@ private[spark] class LocalEndpoint( reviveOffers() } - case KillTask(taskId, interruptThread) => - executor.killTask(taskId, interruptThread) + case KillTask(taskId, interruptThread, reason) => + executor.killTask(taskId, interruptThread, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -143,8 +143,9 @@ private[spark] class LocalSchedulerBackend( override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localEndpoint.send(KillTask(taskId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + localEndpoint.send(KillTask(taskId, interruptThread, reason)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index cdd3b8d8512b1..78dabb42ac9d2 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,20 +16,23 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} import scala.collection.JavaConverters._ +import com.google.common.io.ByteStreams import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ -import org.apache.spark.network.util.CryptoUtils +import org.apache.spark.network.util.{CryptoUtils, JavaUtils} /** * A util class for manipulating IO encryption and 
decryption streams. @@ -48,12 +51,27 @@ private[spark] object CryptoStreamUtils extends Logging { os: OutputStream, sparkConf: SparkConf, key: Array[Byte]): OutputStream = { - val properties = toCryptoConf(sparkConf) - val iv = createInitializationVector(properties) + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) os.write(iv) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoOutputStream(transformationStr, properties, os, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `WritableByteChannel` for encryption. + */ + def createWritableChannel( + channel: WritableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): WritableByteChannel = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) + val helper = new CryptoHelperChannel(channel) + + helper.write(ByteBuffer.wrap(iv)) + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)) } /** @@ -63,12 +81,27 @@ private[spark] object CryptoStreamUtils extends Logging { is: InputStream, sparkConf: SparkConf, key: Array[Byte]): InputStream = { - val properties = toCryptoConf(sparkConf) val iv = new Array[Byte](IV_LENGTH_IN_BYTES) - is.read(iv, 0, iv.length) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoInputStream(transformationStr, properties, is, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + ByteStreams.readFully(is, iv) + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Wrap a `ReadableByteChannel` for decryption. + */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) } def toCryptoConf(conf: SparkConf): Properties = { @@ -102,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. 
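// The channel helpers above follow the usual "IV first" framing: the writer emits a random IV
// in the clear, and the reader must consume exactly IV_LENGTH bytes before building its cipher
// stream. A self-contained round trip with plain javax.crypto illustrates the framing (AES/CTR
// here is an assumption standing in for the configured transformation; no Spark classes used).
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.security.SecureRandom
import javax.crypto.{Cipher, CipherInputStream, CipherOutputStream, KeyGenerator}
import javax.crypto.spec.IvParameterSpec

object IvFramedStreams {
  def main(args: Array[String]): Unit = {
    val key = { val kg = KeyGenerator.getInstance("AES"); kg.init(128); kg.generateKey() }
    val iv = new Array[Byte](16); new SecureRandom().nextBytes(iv)

    // Write side: IV in the clear, then the encrypted payload.
    val bytesOut = new ByteArrayOutputStream()
    bytesOut.write(iv)
    val enc = Cipher.getInstance("AES/CTR/NoPadding")
    enc.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv))
    val cos = new CipherOutputStream(bytesOut, enc)
    cos.write("hello block manager".getBytes("UTF-8")); cos.close()

    // Read side: read the IV fully first, then decrypt the rest.
    val in = new ByteArrayInputStream(bytesOut.toByteArray)
    val ivIn = new Array[Byte](16)
    var off = 0
    while (off < ivIn.length) off += in.read(ivIn, off, ivIn.length - off)
    val dec = Cipher.getInstance("AES/CTR/NoPadding")
    dec.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(ivIn))
    val decIn = new CipherInputStream(in, dec)
    val plainOut = new ByteArrayOutputStream()
    val buf = new Array[Byte](64)
    var n = decIn.read(buf)
    while (n != -1) { plainOut.write(buf, 0, n); n = decIn.read(buf) }
    assert(new String(plainOut.toByteArray, "UTF-8") == "hello block manager")
  }
}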
+ */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 03815631a604c..e15166d11c243 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import java.util.Locale import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -244,7 +245,8 @@ class KryoDeserializationStream( kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. - case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + case e: KryoException + if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") => throw new EOFException } } @@ -384,9 +386,16 @@ private[serializer] object KryoSerializer { classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], + classOf[Array[Boolean]], classOf[Array[Byte]], classOf[Array[Short]], + classOf[Array[Int]], classOf[Array[Long]], + classOf[Array[Float]], + classOf[Array[Double]], + classOf[Array[Char]], + classOf[Array[String]], + classOf[Array[Array[String]]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 008b0387899f6..cb8b1cc077637 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -77,7 +77,7 @@ abstract class Serializer { * position = 0 * serOut.write(obj1) * serOut.flush() - * position = # of bytes writen to stream so far + * position = # of bytes written to stream so far * obj1Bytes = output[0:position-1] * serOut.write(obj2) * serOut.flush() @@ -125,7 +125,7 @@ abstract class SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -abstract class SerializationStream { +abstract class SerializationStream extends Closeable { /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream /** Writes the object representing the key of a key-value pair. */ @@ -133,7 +133,7 @@ abstract class SerializationStream { /** Writes the object representing the value of a key-value pair. */ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit - def close(): Unit + override def close(): Unit def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { @@ -149,14 +149,14 @@ abstract class SerializationStream { * A stream for reading serialized objects. 
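// The extra classOf[Array[...]] registrations added to KryoSerializer above matter mainly when
// spark.kryo.registrationRequired=true, where any unregistered class fails the job. A hedged
// sketch of the application-side configuration (MyRecord and the app name are illustrative):
import org.apache.spark.SparkConf

object KryoRegistrationSketch {
  case class MyRecord(id: Int, value: Double)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("kryo-registration-sketch")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrationRequired", "true")
      // Application classes still need explicit registration; primitive and String array
      // types are now covered by KryoSerializer's built-in list.
      .registerKryoClasses(Array(classOf[MyRecord], classOf[Array[MyRecord]]))
    println(conf.toDebugString)
  }
}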
*/ @DeveloperApi -abstract class DeserializationStream { +abstract class DeserializationStream extends Closeable { /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T /** Reads the object representing the key of a key-value pair. */ def readKey[T: ClassTag](): T = readObject[T]() /** Reads the object representing the value of a key-value pair. */ def readValue[T: ClassTag](): T = readObject[T]() - def close(): Unit + override def close(): Unit /** * Read the elements of this stream through an iterator. This can only be called once, as diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 96b288b9cfb81..bb7ed8709ba8a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -148,14 +148,14 @@ private[spark] class SerializerManager( /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -167,30 +167,26 @@ private[spark] class SerializerManager( val byteStream = new BufferedOutputStream(outputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag]( blockId: BlockId, - values: Iterator[T], - allowEncryption: Boolean = true): ChunkedByteBuffer = { - dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]], - allowEncryption = allowEncryption) + values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } /** Serializes into a chunked byte buffer. 
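// With SerializationStream and DeserializationStream now extending Closeable, they plug into
// the usual resource-management helpers. A small round trip with the built-in JavaSerializer
// showing the writeAll(...).close() and asIterator calls that SerializerManager relies on:
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializerStreamSketch {
  def main(args: Array[String]): Unit = {
    val instance = new JavaSerializer(new SparkConf()).newInstance()

    val out = new ByteArrayOutputStream()
    instance.serializeStream(out).writeAll(Iterator("a", "b", "c")).close()

    val values = instance
      .deserializeStream(new ByteArrayInputStream(out.toByteArray))
      .asIterator
      .toList
    assert(values == List("a", "b", "c"))
  }
}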
*/ def dataSerializeWithExplicitClassTag( blockId: BlockId, values: Iterator[_], - classTag: ClassTag[_], - allowEncryption: Boolean = true): ChunkedByteBuffer = { + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = getSerializer(classTag, autoPick).newInstance() - val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream - ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -200,15 +196,13 @@ private[spark] class SerializerManager( */ def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream, - maybeEncrypted: Boolean = true) + inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) val autoPick = !blockId.isInstanceOf[StreamBlockId] - val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapForCompression(blockId, decrypted)) + .deserializeStream(wrapForCompression(blockId, inputStream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 8b2e26cdd94fb..ba3e0e395e958 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -95,8 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. + // Create an ExternalSorter to sort the data. val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 91858f0912b65..15540485170d0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -61,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver( /** * Remove data file and index file that contain the output data from one map. - * */ + */ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { @@ -132,7 +132,7 @@ private[spark] class IndexShuffleBlockResolver( * replace them with new ones. * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. 
- * */ + */ def writeIndexFileAndCommit( shuffleId: Int, mapId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5e977a16febe1..bfb4dc698e325 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Obtains a [[ShuffleHandle]] to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 5c03609e5e5e5..1279b281ad8d8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -70,7 +70,13 @@ private[spark] object AllRDDResource { address = status.blockManagerId.hostPort, memoryUsed = status.memUsedByRdd(rddId), memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId) + diskUsed = status.diskUsedByRdd(rddId), + onHeapMemoryUsed = Some( + if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + offHeapMemoryUsed = Some( + if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + onHeapMemoryRemaining = status.onHeapMemRemaining, + offHeapMemoryRemaining = status.offHeapMemRemaining ) } ) } else { None diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66b..f17b637754826 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". 
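// registerShuffle above chooses the bypass-merge-sort path via SortShuffleWriter
// .shouldBypassMergeSort(conf, dependency). The condition it encodes, no map-side combine and
// a partition count at or below spark.shuffle.sort.bypassMergeThreshold (default 200), can be
// sketched standalone like this; the real check lives in SortShuffleWriter.
object BypassMergeSortSketch {
  def shouldBypassMergeSort(
      bypassMergeThreshold: Int,
      mapSideCombine: Boolean,
      numPartitions: Int): Boolean = {
    !mapSideCombine && numPartitions <= bypassMergeThreshold
  }

  def main(args: Array[String]): Unit = {
    assert(shouldBypassMergeSort(bypassMergeThreshold = 200, mapSideCombine = false, numPartitions = 10))
    assert(!shouldBypassMergeSort(bypassMergeThreshold = 200, mapSideCombine = true, numPartitions = 10))
  }
}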
+ withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 5b9227350edaa..56d8e51732ffd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -75,7 +75,14 @@ class ExecutorSummary private[spark]( val totalShuffleWrite: Long, val isBlacklisted: Boolean, val maxMemory: Long, - val executorLogs: Map[String, String]) + val executorLogs: Map[String, String], + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, @@ -111,7 +118,11 @@ class RDDDataDistribution private[spark]( val address: String, val memoryUsed: Long, val memoryRemaining: Long, - val diskUsed: Long) + val diskUsed: Long, + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val onHeapMemoryRemaining: Option[Long], + val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( val blockName: String, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index dd8f5bacb9f6e..3db59837fbebd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag -import com.google.common.collect.ConcurrentHashMultiset +import com.google.common.collect.{ConcurrentHashMultiset, ImmutableMultiset} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -340,7 +340,7 @@ private[storage] class BlockInfoManager extends Logging { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).get + readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) } val writeLocks = synchronized { writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) @@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging { blocksWithReleasedLocks } + /** Returns the number of locks held by the given task. Used only for testing. 
*/ + private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = { + readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) + + writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0) + } + /** * Returns the number of blocks tracked. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 45b73380806dd..3219969bcd06f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap @@ -35,7 +36,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -48,13 +49,61 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer - /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], val readMethod: DataReadMethod.Value, val bytes: Long) +/** + * Abstracts away how blocks are stored and provides different ways to read the underlying block + * data. Callers should call [[dispose()]] when they're done with the block. + */ +private[spark] trait BlockData { + + def toInputStream(): InputStream + + /** + * Returns a Netty-friendly wrapper for the block's data. + * + * Please see `ManagedBuffer.convertToNetty()` for more details. + */ + def toNetty(): Object + + def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer + + def toByteBuffer(): ByteBuffer + + def size: Long + + def dispose(): Unit + +} + +private[spark] class ByteBufferBlockData( + val buffer: ChunkedByteBuffer, + val shouldDispose: Boolean) extends BlockData { + + override def toInputStream(): InputStream = buffer.toInputStream(dispose = false) + + override def toNetty(): Object = buffer.toNetty + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + buffer.copy(allocator) + } + + override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer + + override def size: Long = buffer.size + + override def dispose(): Unit = { + if (shouldDispose) { + buffer.dispose() + } + } + +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -94,15 +143,15 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. 
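// The BlockData trait introduced above separates "where the block's bytes live" from "how a
// caller reads them", so the disk store can hand back file-backed data without first loading
// it into a ChunkedByteBuffer. A toy array-backed implementation shows the shape of the
// contract; it is an illustration, not one of the BlockManager's real implementations.
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.ByteBuffer

trait SimpleBlockData {
  def toInputStream(): InputStream
  def toByteBuffer(): ByteBuffer
  def size: Long
  def dispose(): Unit
}

class ArrayBlockData(bytes: Array[Byte]) extends SimpleBlockData {
  override def toInputStream(): InputStream = new ByteArrayInputStream(bytes)
  override def toByteBuffer(): ByteBuffer = ByteBuffer.wrap(bytes)
  override def size: Long = bytes.length.toLong
  override def dispose(): Unit = () // nothing off-heap to free in this toy version
}

object ArrayBlockData {
  def main(args: Array[String]): Unit = {
    val block = new ArrayBlockData("some block".getBytes("UTF-8"))
    try {
      // Typical caller pattern: read, then always dispose, mirroring how BlockManager pairs
      // reads with releaseLockAndDispose.
      println(s"block of ${block.size} bytes, first byte = ${block.toInputStream().read()}")
    } finally {
      block.dispose()
    }
  }
}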
// However, since we use this only for reporting and logging, what we actually want here is // the absolute maximum value that `maxMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. - private val maxMemory = - memoryManager.maxOnHeapStorageMemory + memoryManager.maxOffHeapStorageMemory + private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory + private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. @@ -180,7 +229,8 @@ private[spark] class BlockManager( val idFromMaster = master.registerBlockManager( id, - maxMemory, + maxOnHeapMemory, + maxOffHeapMemory, slaveEndpoint) blockManagerId = if (idFromMaster != null) idFromMaster else id @@ -258,7 +308,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo(s"BlockManager $blockManagerId re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager(blockManagerId, maxOnHeapMemory, maxOffHeapMemory, slaveEndpoint) reportAllBlocks() } @@ -304,7 +354,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case Some(blockData) => + new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true) case None => // If this block manager receives a request for a block that it doesn't have then it's // likely that the master has outdated block statuses for this block. Therefore, we send @@ -317,6 +368,9 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. + * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. 
*/ override def putBlockData( blockId: BlockId, @@ -460,21 +514,22 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { + val diskData = diskStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { - val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskData.toInputStream())(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map { _.toInputStream(dispose = false) } + .getOrElse { diskData.toInputStream() } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, + releaseLockAndDispose(blockId, diskData)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -485,7 +540,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[BlockData] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -493,9 +548,9 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. - Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + val buf = new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -507,7 +562,7 @@ private[spark] class BlockManager( * Must be called while holding a read lock on the block. * Releases the read lock upon exception; keeps the read lock upon successful return. 
*/ - private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = { val level = info.level logDebug(s"Level for block $blockId is $level") // In order, try to read the serialized bytes from memory, then from disk, then fall back to @@ -522,17 +577,19 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerializeWithExplicitClassTag( - blockId, memoryStore.getValues(blockId).get, info.classTag) + new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag), true) } else { handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { - memoryStore.getBytes(blockId).get + new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) } else if (level.useDisk && diskStore.contains(blockId)) { - val diskBytes = diskStore.getBytes(blockId) - maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + val diskData = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map(new ByteBufferBlockData(_, false)) + .getOrElse(diskData) } else { handleLocalReadFailure(blockId) } @@ -755,43 +812,18 @@ private[spark] class BlockManager( /** * Put a new block of serialized bytes to the block manager. * - * @param encrypt If true, asks the block manager to encrypt the data block before storing, - * when I/O encryption is enabled. This is required for blocks that have been - * read from unencrypted sources, since all the BlockManager read APIs - * automatically do decryption. + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @return true if the block was stored or false if an error occurred. */ def putBytes[T: ClassTag]( blockId: BlockId, bytes: ChunkedByteBuffer, level: StorageLevel, - tellMaster: Boolean = true, - encrypt: Boolean = false): Boolean = { + tellMaster: Boolean = true): Boolean = { require(bytes != null, "Bytes is null") - - val bytesToStore = - if (encrypt && securityManager.ioEncryptionKey.isDefined) { - try { - val data = bytes.toByteBuffer - val in = new ByteBufferInputStream(data, true) - val byteBufOut = new ByteBufferOutputStream(data.remaining()) - val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf, - securityManager.ioEncryptionKey.get) - try { - ByteStreams.copy(in, out) - } finally { - in.close() - out.close() - } - new ChunkedByteBuffer(byteBufOut.toByteBuffer) - } finally { - bytes.dispose() - } - } else { - bytes - } - - doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster) + doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster) } /** @@ -800,6 +832,9 @@ private[spark] class BlockManager( * * If the block already exists, this method will not overwrite it. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param keepReadLock if true, this method will hold the read lock when it returns (even if the * block already exists). If false, this method will hold no locks when it * returns. 
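For orientation, a hedged sketch of how the read path above might be exercised once it returns BlockData; `bm`, `blockId` and `payload` are assumed to exist and are not taken from the patch.

// Sketch, not from the patch: store serialized bytes, then read them back as a BlockData.
val bytes = new ChunkedByteBuffer(java.nio.ByteBuffer.wrap(payload))
bm.putBytes(blockId, bytes, StorageLevel.DISK_ONLY)

bm.getLocalBytes(blockId).foreach { data =>
  try {
    val in = data.toInputStream() // decrypts transparently when I/O encryption is enabled
    // ... consume the stream ...
    in.close()
  } finally {
    // getLocalBytes acquires a read lock; release it and dispose the buffer together.
    bm.releaseLockAndDispose(blockId, data)
  }
}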
@@ -819,8 +854,9 @@ private[spark] class BlockManager( val replicationFuture = if (level.replication > 1) { Future { // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bytes, level, classTag) + // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing + // buffers that are owned by the caller. + replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag) }(futureExecutionContext) } else { null @@ -843,7 +879,15 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + val memoryMode = level.memoryMode + memoryStore.putBytes(blockId, size, memoryMode, () => { + if (memoryMode == MemoryMode.OFF_HEAP && + bytes.chunks.exists(buffer => !buffer.isDirect)) { + bytes.copy(Platform.allocateDirectBuffer) + } else { + bytes + } + }) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -991,8 +1035,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -1007,8 +1052,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - partiallySerializedValues.finishWritingToStream(fileOutputStream) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + partiallySerializedValues.finishWritingToStream(out) } size = diskStore.getSize(blockId) } else { @@ -1018,8 +1064,9 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1072,29 +1119,29 @@ private[spark] class BlockManager( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + diskData: BlockData): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against a race condition where two readers both try to // put values read from disk into the MemoryStore. blockInfo.synchronized { if (memoryStore.contains(blockId)) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { val allocator = level.memoryMode match { case MemoryMode.ON_HEAP => ByteBuffer.allocate _ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ } - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. 
So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy(allocator) + diskData.toChunkedByteBuffer(allocator) }) if (putSucceeded) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { None @@ -1170,7 +1217,7 @@ private[spark] class BlockManager( blockId: BlockId, existingReplicas: Set[BlockManagerId], maxReplicas: Int): Unit = { - logInfo(s"Pro-actively replicating $blockId") + logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") blockInfoManager.lockForReading(blockId).foreach { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( @@ -1179,10 +1226,14 @@ private[spark] class BlockManager( useOffHeap = info.level.useOffHeap, deserialized = info.level.deserialized, replication = maxReplicas) + // we know we are called as a result of an executor removal, so we refresh peer cache + // this way, we won't try to replicate to a missing executor with a stale reference + getPeers(forceFetch = true) try { replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { - releaseLock(blockId) + logDebug(s"Releasing lock for $blockId") + releaseLockAndDispose(blockId, data) } } } @@ -1193,7 +1244,7 @@ private[spark] class BlockManager( */ private def replicate( blockId: BlockId, - data: ChunkedByteBuffer, + data: BlockData, level: StorageLevel, classTag: ClassTag[_], existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { @@ -1207,7 +1258,6 @@ private[spark] class BlockManager( replication = 1) val numPeersToReplicateTo = level.replication - 1 - val startTime = System.nanoTime var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas @@ -1235,7 +1285,7 @@ private[spark] class BlockManager( peer.port, peer.executorId, blockId, - new NettyManagedBuffer(data.toNetty), + new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1262,7 +1312,6 @@ private[spark] class BlockManager( numPeersToReplicateTo - peersReplicatedTo.size) } } - logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { @@ -1318,10 +1367,11 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { fileOutputStream => + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, - fileOutputStream, + out, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => @@ -1413,6 +1463,11 @@ private[spark] class BlockManager( } } + def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { + blockInfoManager.unlock(blockId) + data.dispose() + } + def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f942798550..1ea0d378cbe87 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,31 +17,52 @@ package org.apache.spark.storage -import 
org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import java.io.InputStream +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.io.ChunkedByteBuffer /** - * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] * so that the corresponding block's read lock can be released once this buffer's references * are released. * + * If `dispose` is set to true, the [[BlockData]]will be disposed when the buffer's reference + * count drops to zero. + * * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, - chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + data: BlockData, + dispose: Boolean) extends ManagedBuffer { + + private val refCount = new AtomicInteger(1) + + override def size(): Long = data.size + + override def nioByteBuffer(): ByteBuffer = data.toByteBuffer() + + override def createInputStream(): InputStream = data.toInputStream() + + override def convertToNetty(): Object = data.toNetty() override def retain(): ManagedBuffer = { - super.retain() + refCount.incrementAndGet() val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this - } + } override def release(): ManagedBuffer = { blockInfoManager.unlock(blockId) - super.release() + if (refCount.decrementAndGet() == 0 && dispose) { + data.dispose() + } + this } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 3ca690db9e79f..ea5d8423a588c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -57,11 +57,12 @@ class BlockManagerMaster( */ def registerBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") val updatedId = driverEndpoint.askSync[BlockManagerId]( - RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) logInfo(s"Registered BlockManager $updatedId") updatedId } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 84c04d22600ad..6f85b9e4d6c73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -71,8 +71,8 @@ class BlockManagerMasterEndpoint( logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - context.reply(register(blockManagerId, maxMemSize, slaveEndpoint)) + case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => + context.reply(register(blockManagerId, maxOnHeapMemSize, 
maxOffHeapMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -276,7 +276,8 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) + new StorageStatus(blockManagerId, info.maxMem, Some(info.maxOnHeapMem), + Some(info.maxOffHeapMem), info.blocks.asScala) }.toArray } @@ -338,7 +339,8 @@ class BlockManagerMasterEndpoint( */ private def register( idWithoutTopologyInfo: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id @@ -359,14 +361,15 @@ class BlockManagerMasterEndpoint( case None => } logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) + id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id)) blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, + Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) id } @@ -464,10 +467,13 @@ object BlockStatus { private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val maxMem: Long, + val maxOnHeapMem: Long, + val maxOffHeapMem: Long, val slaveEndpoint: RpcEndpointRef) extends Logging { + val maxMem = maxOnHeapMem + maxOffHeapMem + private var _lastSeenMs: Long = timeMs private var _remainingMem: Long = maxMem @@ -491,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already. 
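A condensed view of the registration handshake after this change; the names come from the patch, while the surrounding wiring (`memoryManager`, `master`, `id`, `slaveEndpoint`) is assumed from the enclosing classes.

// The slave now reports the two storage pools separately instead of a single pre-summed value.
val maxOnHeap = memoryManager.maxOnHeapStorageMemory
val maxOffHeap = memoryManager.maxOffHeapStorageMemory
val confirmedId = master.registerBlockManager(id, maxOnHeap, maxOffHeap, slaveEndpoint)
// On the master, BlockManagerInfo.maxMem is maxOnHeap + maxOffHeap, so consumers that only
// understand a single total keep working.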
val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -514,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. 
- val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0aea438e7f473..0c0ff144596ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -58,7 +58,8 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, sender: RpcEndpointRef) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index c5ba9af3e2658..197a01762c0c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -26,35 +26,39 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager) override val metricRegistry = new MetricRegistry() override val sourceName = "BlockManager" - metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).sum - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val memUsed = storageStatusList.map(_.memUsed).sum - memUsed / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum - diskSpaceUsed / 1024 / 1024 - } - }) + private def registerGauge(name: String, func: BlockManagerMaster => Long): Unit = { + metricRegistry.register(name, new Gauge[Long] { + override def getValue: Long = func(blockManager.master) / 1024 / 1024 + }) + } + + registerGauge(MetricRegistry.name("memory", "maxMem_MB"), + _.getStorageStatus.map(_.maxMem).sum) + + registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"), + 
_.getStorageStatus.map(_.maxOnHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"), + _.getStorageStatus.map(_.maxOffHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingMem_MB"), + _.getStorageStatus.map(_.memRemaining).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOnHeapMem_MB"), + _.getStorageStatus.map(_.onHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOffHeapMem_MB"), + _.getStorageStatus.map(_.offHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "memUsed_MB"), + _.getStorageStatus.map(_.memUsed).sum) + + registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"), + _.getStorageStatus.map(_.onHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"), + _.getStorageStatus.map(_.offHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"), + _.getStorageStatus.map(_.diskUsed).sum) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bb8a684b4c7a8..353eac60df171 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -53,6 +53,46 @@ trait BlockReplicationPolicy { numReplicas: Int): List[BlockManagerId] } +object BlockReplicationUtils { + // scalastyle:off line.size.limit + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage. Please see + * here. + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + // scalastyle:on line.size.limit + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + indices.map(_ - 1).toList + } + + /** + * Get a random sample of size m from the elems + * + * @param elems + * @param m number of samples needed + * @param r random number generator + * @tparam T + * @return a random list of size m. If there are fewer than m elements in elems, we just + * randomly shuffle elems + */ + def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = { + if (elems.size > m) { + getSampleIds(elems.size, m, r).map(elems(_)) + } else { + r.shuffle(elems).toList + } + } +} + @DeveloperApi class RandomBlockReplicationPolicy extends BlockReplicationPolicy @@ -67,6 +107,7 @@ class RandomBlockReplicationPolicy * @param peersReplicatedTo Set of peers already replicated to * @param blockId BlockId of the block being replicated. This can be used as a source of * randomness if needed. + * @param numReplicas Number of peers we need to replicate to * @return A prioritized list of peers. 
Lower the index of a peer, higher its priority */ override def prioritize( @@ -78,7 +119,7 @@ class RandomBlockReplicationPolicy val random = new Random(blockId.hashCode) logDebug(s"Input peers : ${peers.mkString(", ")}") val prioritizedPeers = if (peers.size > numReplicas) { - getSampleIds(peers.size, numReplicas, random).map(peers(_)) + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) } else { if (peers.size < numReplicas) { logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") @@ -88,26 +129,96 @@ class RandomBlockReplicationPolicy logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") prioritizedPeers } +} + +@DeveloperApi +class BasicBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { - // scalastyle:off line.size.limit /** - * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see - * here. + * Method to prioritize a bunch of candidate peers of a block manager. This implementation + * replicates the behavior of block replication in HDFS. For a given number of replicas needed, + * we choose a peer within the rack, one outside and remaining blockmanagers are chosen at + * random, in that order till we meet the number of replicas needed. + * This works best with a total replication factor of 3, like HDFS. * - * @param n total number of indices - * @param m number of samples needed - * @param r random number generator - * @return list of m random unique indices + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority */ - // scalastyle:on line.size.limit - private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { - val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => - val t = r.nextInt(i) + 1 - if (set.contains(t)) set + i else set + t + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + + logDebug(s"Input peers : $peers") + logDebug(s"BlockManagerId : $blockManagerId") + + val random = new Random(blockId.hashCode) + + // if block doesn't have topology info, we can't do much, so we randomly shuffle + // if there is, we see what's needed from peersReplicatedTo and based on numReplicas, + // we choose whats needed + if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) { + // no topology info for the block. 
The best we can do is randomly choose peers + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we have topology information, we see what is left to be done from peersReplicatedTo + val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo) + val doneOutsideRack = peersReplicatedTo.exists { p => + p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo + } + + if (doneOutsideRack && doneWithinRack) { + // we are done, we just return a random sample + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we separate peers within and outside rack + val (inRackPeers, outOfRackPeers) = peers + .filter(_.host != blockManagerId.host) + .partition(_.topologyInfo == blockManagerId.topologyInfo) + + val peerWithinRack = if (doneWithinRack) { + // we are done with in-rack replication, so don't need anymore peers + Seq.empty + } else { + if (inRackPeers.isEmpty) { + Seq.empty + } else { + Seq(inRackPeers(random.nextInt(inRackPeers.size))) + } + } + + val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) { + Seq.empty + } else { + if (outOfRackPeers.isEmpty) { + Seq.empty + } else { + Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size))) + } + } + + val priorityPeers = peerWithinRack ++ peerOutsideRack + val numRemainingPeers = numReplicas - priorityPeers.size + val remainingPeers = if (numRemainingPeers > 0) { + val rPeers = peers.filter(p => !priorityPeers.contains(p)) + BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random) + } else { + Seq.empty + } + + (priorityPeers ++ remainingPeers).toList + } + } - // we shuffle the result to ensure a random arrangement within the sample - // to avoid any bias from set implementations - r.shuffle(indices.map(_ - 1).toList) } + } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed02..c6656341fcd15 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,48 +17,67 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, IOException, RandomAccessFile} +import java.io._ import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap -import com.google.common.io.Closeables +import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf +import com.google.common.io.{ByteStreams, Closeables, Files} +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. 
*/ -private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { +private[spark] class DiskStore( + conf: SparkConf, + diskManager: DiskBlockManager, + securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val blockSizes = new ConcurrentHashMap[String, Long]() - def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length - } + def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) /** * Invokes the provided callback function to write the specific block. * * @throws IllegalStateException if the block already exists in the disk store. */ - def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = { if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val fileOutputStream = new FileOutputStream(file) + val out = new CountingWritableChannel(openForWrite(file)) var threwException: Boolean = true try { - writeFunc(fileOutputStream) + writeFunc(out) + blockSizes.put(blockId.name, out.getCount) threwException = false } finally { try { - Closeables.close(fileOutputStream, threwException) + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } } finally { if (threwException) { remove(blockId) @@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { - put(blockId) { fileOutputStream => - val channel = fileOutputStream.getChannel - Utils.tryWithSafeFinally { - bytes.writeFully(channel) - } { - channel.close() - } + put(blockId) { channel => + bytes.writeFully(channel) } } - def getBytes(blockId: BlockId): ChunkedByteBuffer = { + def getBytes(blockId: BlockId): BlockData = { val file = diskManager.getFile(blockId.name) - val channel = new RandomAccessFile(file, "r").getChannel - Utils.tryWithSafeFinally { - // For small files, directly read rather than memory map - if (file.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(file.length.toInt) - channel.position(0) - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException("Reached EOF before filling buffer\n" + - s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + val blockSize = getSize(blockId) + + securityManager.getIOEncryptionKey() match { + case Some(key) => + // Encrypted blocks cannot be memory mapped; return a special object that does decryption + // and provides InputStream / FileRegion implementations for reading the data. + new EncryptedBlockData(file, blockSize, conf, key) + + case _ => + val channel = new FileInputStream(file).getChannel() + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. 
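A short sketch of how callers use the channel-based DiskStore API introduced above; it mirrors the BlockManager call sites in this patch, with `values` and `classTag` assumed to be in scope.

// Writers receive a WritableByteChannel, which may wrap an encrypting channel when
// spark.io.encryption.enabled is set, so they no longer touch the file directly.
diskStore.put(blockId) { channel =>
  val out = Channels.newOutputStream(channel)
  serializerManager.dataSerializeStream(blockId, out, values.iterator)(classTag)
}
// DiskStore counts the bytes as they are written, so getSize() is a map lookup, not a file stat.
val sizeOnDisk = diskStore.getSize(blockId)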
+ Utils.tryWithSafeFinally { + val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) + } { + channel.close() + } + } else { + Utils.tryWithSafeFinally { + new ByteBufferBlockData( + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) + } { + channel.close() } } - buf.flip() - new ChunkedByteBuffer(buf) - } else { - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) - } - } { - channel.close() } } def remove(blockId: BlockId): Boolean = { + blockSizes.remove(blockId.name) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e val file = diskManager.getFile(blockId.name) file.exists() } + + private def openForWrite(file: File): WritableByteChannel = { + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + } + +} + +private class EncryptedBlockData( + file: File, + blockSize: Long, + conf: SparkConf, + key: Array[Byte]) extends BlockData { + + override def toInputStream(): InputStream = Channels.newInputStream(open()) + + override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize) + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val source = open() + try { + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, Int.MaxValue) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(source, chunk) + chunk.flip() + chunks += chunk + } + + new ChunkedByteBuffer(chunks.toArray) + } finally { + source.close() + } + } + + override def toByteBuffer(): ByteBuffer = { + // This is used by the block transfer service to replicate blocks. The upload code reads + // all bytes into memory to send the block to the remote executor, so it's ok to do this + // as long as the block fits in a Java array. 
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + val dst = ByteBuffer.allocate(blockSize.toInt) + val in = open() + try { + JavaUtils.readFully(in, dst) + dst.flip() + dst + } finally { + Closeables.close(in, true) + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = { } + + private def open(): ReadableByteChannel = { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } + } + +} + +private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) + extends AbstractReferenceCounted with FileRegion { + + private var _transferred = 0L + + private val buffer = ByteBuffer.allocateDirect(64 * 1024) + buffer.flip() + + override def count(): Long = blockSize + + override def position(): Long = 0 + + override def transfered(): Long = _transferred + + override def transferTo(target: WritableByteChannel, pos: Long): Long = { + assert(pos == transfered(), "Invalid position.") + + var written = 0L + var lastWrite = -1L + while (lastWrite != 0) { + if (!buffer.hasRemaining()) { + buffer.clear() + source.read(buffer) + buffer.flip() + } + if (buffer.hasRemaining()) { + lastWrite = target.write(buffer) + written += lastWrite + } else { + lastWrite = 0 + } + } + + _transferred += written + written + } + + override def deallocate(): Unit = source.close() +} + +private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { + + private var count = 0L + + def getCount: Long = count + + override def write(src: ByteBuffer): Int = { + val written = sink.write(src) + if (written > 0) { + count += written + } + written + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 798658a15b797..ac60f795915a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() @@ -41,7 +42,7 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { } def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus.toSeq + deadExecutorStorageStatus } /** Update storage status list to reflect updated block statuses */ @@ -74,8 +75,10 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { synchronized { val blockManagerId = blockManagerAdded.blockManagerId val executorId = blockManagerId.executorId - val maxMem = blockManagerAdded.maxMem - val storageStatus = new StorageStatus(blockManagerId, maxMem) + // The onHeap and offHeap memory are always defined for new applications, + // but they can be missing if we are replaying old event logs. 
+ val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, + blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) executorIdToStorageStatus(executorId) = storageStatus // Try to remove the dead storage status if same executor register the block manager twice. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e12f2e6095d5a..e9694fdbca2de 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,7 +35,12 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { +@deprecated("This class may be removed or made private in a future release.", "2.2.0") +class StorageStatus( + val blockManagerId: BlockManagerId, + val maxMemory: Long, + val maxOnHeapMem: Option[Long], + val maxOffHeapMem: Option[Long]) { /** * Internal representation of the blocks stored in this block manager. @@ -46,25 +51,21 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - /** - * Storage information of the blocks that entails memory, disk, and off-heap memory usage. - * - * As with the block maps, we store the storage information separately for RDD blocks and - * non-RDD blocks for the same reason. In particular, RDD storage information is stored - * in a map indexed by the RDD ID to the following 4-tuple: - * - * (memory size, disk size, storage level) - * - * We assume that all the blocks that belong to the same RDD have the same storage level. - * This field is not relevant to non-RDD blocks, however, so the storage information for - * non-RDD blocks contains only the first 3 fields (in the same order). - */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) + private case class RddStorageInfo(memoryUsage: Long, diskUsage: Long, level: StorageLevel) + private val _rddStorageInfo = new mutable.HashMap[Int, RddStorageInfo] + + private case class NonRddStorageInfo(var onHeapUsage: Long, var offHeapUsage: Long, + var diskUsage: Long) + private val _nonRddStorageInfo = NonRddStorageInfo(0L, 0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. */ - def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { - this(bmid, maxMem) + def this( + bmid: BlockManagerId, + maxMemory: Long, + maxOnHeapMem: Option[Long], + maxOffHeapMem: Option[Long], + initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMemory, maxOnHeapMem, maxOffHeapMem) initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } } @@ -176,26 +177,57 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + /** Return the max memory can be used by this block manager. */ + def maxMem: Long = maxMemory + /** Return the memory remaining in this block manager. 
*/ def memRemaining: Long = maxMem - memUsed + /** Return the memory used by caching RDDs */ + def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) + /** Return the memory used by this block manager. */ - def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) - /** Return the memory used by caching RDDs */ - def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + /** Return the on-heap memory remaining in this block manager. */ + def onHeapMemRemaining: Option[Long] = + for (m <- maxOnHeapMem; o <- onHeapMemUsed) yield m - o + + /** Return the off-heap memory remaining in this block manager. */ + def offHeapMemRemaining: Option[Long] = + for (m <- maxOffHeapMem; o <- offHeapMemUsed) yield m - o + + /** Return the on-heap memory used by this block manager. */ + def onHeapMemUsed: Option[Long] = onHeapCacheSize.map(_ + _nonRddStorageInfo.onHeapUsage) + + /** Return the off-heap memory used by this block manager. */ + def offHeapMemUsed: Option[Long] = offHeapCacheSize.map(_ + _nonRddStorageInfo.offHeapUsage) + + /** Return the memory used by on-heap caching RDDs */ + def onHeapCacheSize: Option[Long] = maxOnHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if !storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } + + /** Return the memory used by off-heap caching RDDs */ + def offHeapCacheSize: Option[Long] = maxOffHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } /** Return the disk space used by this block manager. */ - def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ - def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) /** * Update the relevant storage info, taking into account any existing status for this block. 
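As a usage sketch (not in the patch), monitoring code consuming the new Option-valued accessors can fall back to the combined figures when the on-heap/off-heap split is unknown, which is what the new BlockManagerSource gauges earlier in this change do.

// Sketch: aggregate on-heap / off-heap usage across executors; missing values (e.g. when
// replaying an old event log) are treated as zero, mirroring the new metrics gauges.
val statuses = blockManager.master.getStorageStatus
val onHeapUsedMB = statuses.map(_.onHeapMemUsed.getOrElse(0L)).sum / 1024 / 1024
val offHeapUsedMB = statuses.map(_.offHeapMemUsed.getOrElse(0L)).sum / 1024 / 1024
val totalUsedMB = statuses.map(_.memUsed).sum / 1024 / 1024 // always defined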
@@ -210,10 +242,12 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, _) => (mem, disk) } + .map { case RddStorageInfo(mem, disk, _) => (mem, disk) } .getOrElse((0L, 0L)) - case _ => - _nonRddStorageInfo + case _ if !level.useOffHeap => + (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) + case _ if level.useOffHeap => + (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) @@ -225,30 +259,40 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, level) + _rddStorageInfo(rddId) = RddStorageInfo(newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk) + if (!level.useOffHeap) { + _nonRddStorageInfo.onHeapUsage = newMem + } else { + _nonRddStorageInfo.offHeapUsage = newMem + } + _nonRddStorageInfo.diskUsage = newDisk } } - } /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun + * API that will cause errors if one attempts to read from the disposed buffer. However, neither + * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put + * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of + * off-heap memory or huge numbers of open files. There's unfortunately no standard API to + * manually dispose of these kinds of buffers. 
*/ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() - } + logTrace(s"Disposing of $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + } + } + + private def cleanDirectBuffer(buffer: DirectBuffer) = { + val cleaner = buffer.cleaner() + if (cleaner != null) { + cleaner.clean() } } diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index fb54dd66a39a9..90e3af2d0ec74 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -344,7 +344,7 @@ private[spark] class MemoryStore( val serializationStream: SerializationStream = { val autoPick = !blockId.isInstanceOf[StreamBlockId] val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index bdbdba5780856..edf328b5ae538 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 7d31ac54a7177..bf4cf79e9faa3 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -117,7 +117,7 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", + sparkUser = getSparkUser, completed = false )) )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index d161843dd2230..79b0d81af52b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -342,7 +342,7 @@ private[spark] object UIUtils extends Logging { completed: Int, failed: Int, skipped: Int, - killed: Int, + reasonToNumKilled: Map[String, Int], total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) // started + completed can be > total when there are speculative tasks @@ -354,7 +354,10 @@ private[spark] object UIUtils extends Logging { {completed}/{total} { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { if (killed > 0) s"($killed killed)" } + { reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s"($count killed: $reason)" + } + }
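To make the new progress-bar text concrete, here is a small self-contained example of the killed-task summary it renders; the map contents are invented.

// Sketch: one "(N killed: reason)" fragment per kill reason, largest count first.
val reasonToNumKilled = Map("another attempt succeeded" -> 3, "Stage cancelled" -> 1)
val killedSummary = reasonToNumKilled.toSeq.sortBy(-_._2).map {
  case (reason, count) => s"($count killed: $reason)"
}.mkString(" ")
// killedSummary == "(3 killed: another attempt succeeded) (1 killed: Stage cancelled)"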
    @@ -443,7 +446,7 @@ private[spark] object UIUtils extends Logging { val xml = XML.loadString(s"""$desc""") // Verify that this has only anchors and span (we are wrapping in span) - val allowedNodeLabels = Set("a", "span") + val allowedNodeLabels = Set("a", "span", "br") val illegalNodes = xml \\ "_" filterNot { case node: Node => allowedNodeLabels.contains(node.label) } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a9480cc220c8d..8b75f5d8fe1a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -124,7 +124,7 @@ private[spark] abstract class WebUI( /** Bind to the HTTP server behind this web interface. */ def bind(): Unit = { - assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") + assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name)) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 70b3ffd95e605..8c18464e6477a 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en * A SparkListener that prepares information to be displayed on the EnvironmentTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index c6a07445f2a35..6ce3f511e89c7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -42,14 +43,15 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { - threadTrace1.threadName.toLowerCase < threadTrace2.threadName.toLowerCase + threadTrace1.threadName.toLowerCase(Locale.ROOT) < + threadTrace2.threadName.toLowerCase(Locale.ROOT) } else { v1 > v2 } }.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { - case Some(blockedByThreadId) => + case Some(_) =>
    Blocked by Thread {thread.blockedByThreadId} {thread.blockedByLock} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 2d1691e55c428..b7cbed468517c 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive @@ -40,7 +40,8 @@ private[ui] case class ExecutorSummaryInfo( totalShuffleRead: Long, totalShuffleWrite: Long, isBlacklisted: Int, - maxMemory: Long, + maxOnHeapMem: Long, + maxOffHeapMem: Long, executorLogs: Map[String, String]) @@ -48,24 +49,56 @@ private[ui] class ExecutorsPage( parent: ExecutorsTab, threadDumpEnabled: Boolean) extends WebUIPage("") { - private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val content =
[Scala XML literal garbled during extraction; recoverable content: the rewritten content block adds a "Show Additional Metrics" checkbox alongside the executors page markup and removes the previous inline markup.]
    UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } } private[spark] object ExecutorsPage { + private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + + "storage of data like RDD partitions cached in memory." + private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + + "storage of data like RDD partitions cached in memory." + /** Represent an executor's info as a map given a storage status index */ def getExecInfo( listener: ExecutorsListener, @@ -81,6 +114,16 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -104,7 +147,8 @@ private[spark] object ExecutorsPage { taskSummary.shuffleWrite, taskSummary.isBlacklisted, maxMem, - taskSummary.executorLogs + taskSummary.executorLogs, + memoryMetrics ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 8ae712f8ed323..aabf6e0c63c02 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -62,9 +62,10 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() var executorEvents = new ListBuffer[SparkListenerEvent]() private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) @@ -137,7 +138,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // could have failed half-way through. The correct fix would be to keep track of the // metrics added by each attempt, but this is much more complicated. 
return - case e: ExceptionFailure => + case _: ExceptionFailure => taskSummary.tasksFailed += 1 case _ => taskSummary.tasksComplete += 1 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d217f558045f2..18be0870746e9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -630,8 +630,8 @@ private[ui] class JobPagedTable( {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, - failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks, - total = job.numTasks - job.numSkippedTasks)} + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index fe6ca1099e6b0..2b0816e35747d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -34,9 +34,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { listener.synchronized { val activeStages = listener.activeStages.values.toSeq val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq + val completedStages = listener.completedStages.reverse val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse val numFailedStages = listener.numFailedStages val subPath = "stages" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index cd1b02addc789..382a6f979f2e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -133,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} {v.failedTasks} - {v.killedTasks} + {v.reasonToNumKilled.values.sum} {v.succeededTasks} {if (stageData.hasInput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ff9e5e9411ca..3131c4a1eb7d4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.Date +import java.util.{Date, Locale} import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -77,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'content': '
    + case JobFailed(_) => failedJobs += jobData trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED @@ -284,7 +285,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - if (!stage.submissionTime.isEmpty) { + if (stage.submissionTime.isDefined) { jobData.completedStageIndices.add(stage.stageId) } } else { @@ -371,8 +372,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => execSummary.succeededTasks += 1 - case TaskKilled => - execSummary.killedTasks += 1 + case kill: TaskKilled => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -385,9 +387,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 None - case TaskKilled => - stageData.numKilledTasks += 1 - Some(TaskKilled.toErrorString) + case kill: TaskKilled => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + Some(kill.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -422,8 +425,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => jobData.numCompletedTasks += 1 - case TaskKilled => - jobData.numKilledTasks += 1 + case kill: TaskKilled => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => jobData.numFailedTasks += 1 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff17775008acc..19325a2dc9169 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -142,7 +142,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.size > 0 + val hasAccumulators = externalAccumulables.nonEmpty val summary =
    @@ -339,7 +339,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { + if (validTasks.isEmpty) { None } else { @@ -786,8 +786,8 @@ private[ui] object StagePage { info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) + val executorOverhead = metrics.executorDeserializeTime + + metrics.resultSerializationTime math.max( 0, totalExecutionTime - metrics.executorRunTime - executorOverhead - @@ -872,7 +872,7 @@ private[ui] class TaskDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - private var _slicedTaskIds: Set[Long] = null + private var _slicedTaskIds: Set[Long] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index e1fa9043b6a15..256b726fa7eea 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -300,7 +300,7 @@ private[ui] class StagePagedTable( {UIUtils.makeProgressBar(started = stageData.numActiveTasks, completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)} + skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} @@ -412,7 +412,7 @@ private[ui] class StageDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = null + private var _slicedStageIds: Set[Int] = _ override def dataSize: Int = data.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index c1f25114371f1..181465bdf9609 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -42,7 +42,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" val stageId = Option(request.getParameter("id")).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id)) + sc.foreach(_.cancelStage(id, "killed via the Web UI")) // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. 
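The JobProgressListener hunks above repeat one idiom for every TaskKilled end reason: the immutable Map[String, Int] keyed by kill reason is replaced with a copy whose count for that reason is incremented. Pulled out as a tiny sketch (the helper name is ours):

  // Bump the count for a kill reason in an immutable map.
  def addKillReason(reasonToNumKilled: Map[String, Int], reason: String): Map[String, Int] =
    reasonToNumKilled.updated(reason, reasonToNumKilled.getOrElse(reason, 0) + 1)

  // addKillReason(Map.empty, "killed via the Web UI")
  //   == Map("killed via the Web UI" -> 1)
  // addKillReason(Map("killed via the Web UI" -> 1), "killed via the Web UI")
  //   == Map("killed via the Web UI" -> 2)

The "killed via the Web UI" string is the reason the StagesTab handler above now passes to cancelStage, which is what ultimately shows up in these per-reason counts.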
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 073f7edfc2fe9..ac1a74ad8029d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -32,7 +32,7 @@ private[spark] object UIData { var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 - var killedTasks : Int = 0 + var reasonToNumKilled : Map[String, Int] = Map.empty var inputBytes : Long = 0 var inputRecords : Long = 0 var outputBytes : Long = 0 @@ -64,7 +64,7 @@ private[spark] object UIData { var numCompletedTasks: Int = 0, var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, - var numKilledTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: @@ -78,7 +78,7 @@ private[spark] object UIData { var numCompleteTasks: Int = _ var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ - var numKilledTasks: Int = _ + var reasonToNumKilled: Map[String, Int] = Map.empty var executorRunTime: Long = _ var executorCpuTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 227e940c9c50c..a1a0c729b9240 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -147,7 +147,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { /** Header fields for the worker table */ private def workerHeader = Seq( "Host", - "Memory Usage", + "On Heap Memory Usage", + "Off Heap Memory Usage", "Disk Usage") /** Render an HTML row representing a worker */ @@ -155,8 +156,12 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {worker.address} - {Utils.bytesToString(worker.memoryUsed)} - ({Utils.bytesToString(worker.memoryRemaining)} Remaining) + {Utils.bytesToString(worker.onHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.onHeapMemoryRemaining.getOrElse(0L))} Remaining) + + + {Utils.bytesToString(worker.offHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.offHeapMemoryRemaining.getOrElse(0L))} Remaining) {Utils.bytesToString(worker.diskUsed)} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 76d7c6d414bcf..aa84788f1df88 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -151,7 +151,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Render a stream block */ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { val replications = block._2 - assert(replications.size > 0) // This must be true because it's the result of "groupBy" + assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be6..148efb134e14f 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 00e0cf257cd4a..a65ec75cc5db6 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Returns the name of this accumulator, can only be called after registration. */ final def name: Option[String] = { - assertMetadataNotNull() - metadata.name + if (atDriverSide) { + AccumulatorContext.get(id).flatMap(_.metadata.name) + } else { + assertMetadataNotNull() + metadata.name + } } /** @@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = + (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) || + getClass.getSimpleName == "SQLMetric" + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + copyAcc.metadata = metadata + } copyAcc } else { this @@ -263,23 +275,13 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } /** - * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers. * * @since 2.0.0 */ diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index dce2ac63a664c..50dc948e6c410 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -23,11 +23,10 @@ import java.nio.ByteBuffer import org.apache.spark.storage.StorageUtils /** - * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). + * Reads data from a ByteBuffer. 
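The AccumulatorV2 hunk above strips the name from internal accumulators before they are serialized to executors: anything whose name starts with InternalAccumulator.METRICS_PREFIX, or whose class is SQLMetric, qualifies. A hedged sketch of that check, assuming the "internal.metrics." prefix (consistent with the "internal.metrics.updatedBlockStatuses" entry blacklisted later in this patch):

  // Internal accumulators drop their name before shipping to executors; the driver can
  // recover it from AccumulatorContext by id when name is called on the driver side.
  val metricsPrefix = "internal.metrics."   // assumed value of InternalAccumulator.METRICS_PREFIX

  def isInternalAcc(name: Option[String], simpleClassName: String): Boolean =
    name.exists(_.startsWith(metricsPrefix)) || simpleClassName == "SQLMetric"

  // isInternalAcc(Some("internal.metrics.executorRunTime"), "LongAccumulator") == true
  // isInternalAcc(Some("myCounter"), "LongAccumulator") == false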
*/ private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) +class ByteBufferInputStream(private var buffer: ByteBuffer) extends InputStream { override def read(): Int = { @@ -72,9 +71,6 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f */ private def cleanUp() { if (buffer != null) { - if (dispose) { - StorageUtils.dispose(buffer) - } buffer = null } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4b4d2d10cbf8d..8296c4294242c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -182,7 +182,9 @@ private[spark] object JsonProtocol { ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ - ("Timestamp" -> blockManagerAdded.time) + ("Timestamp" -> blockManagerAdded.time) ~ + ("Maximum Onheap Memory" -> blockManagerAdded.maxOnHeapMem) ~ + ("Maximum Offheap Memory" -> blockManagerAdded.maxOffHeapMem) } def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { @@ -264,8 +266,7 @@ private[spark] object JsonProtocol { ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ - ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -281,7 +282,15 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) + ("Accumulables" -> accumulablesToJson(taskInfo.accumulables)) + } + + private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses") + + def accumulablesToJson(accumulables: Traversable[AccumulableInfo]): JArray = { + JArray(accumulables + .filterNot(_.name.exists(accumulableBlacklist.contains)) + .toList.map(accumulableInfoToJson)) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { @@ -376,7 +385,7 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) + val accumUpdates = accumulablesToJson(exceptionFailure.accumUpdates) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ @@ -390,6 +399,8 @@ private[spark] object JsonProtocol { ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) + case taskKilled: TaskKilled => + ("Kill Reason" -> taskKilled.reason) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -603,7 +614,9 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - SparkListenerBlockManagerAdded(time, blockManagerId, maxMem) + val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap 
Memory").map(_.extract[Long]) + val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { @@ -877,7 +890,10 @@ private[spark] object JsonProtocol { })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost - case `taskKilled` => TaskKilled + case `taskKilled` => + val killReason = Utils.jsonOption(json \ "Kill Reason") + .map(_.extract[String]).getOrElse("unknown reason") + TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4e..ce06e18879a49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e29..27922b31949b6 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. 
*/ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1af34e3da231f..4d37db96dfc37 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -740,7 +740,11 @@ private[spark] object Utils extends Logging { * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { @@ -2585,18 +2589,45 @@ private[spark] object Utils extends Logging { } } - private[util] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + /** + * Redact the sensitive values in the given map. If a map key matches the redaction pattern then + * its value is replaced with a dummy text. + */ def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { - val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + val redactionPattern = conf.get(SECRET_REDACTION_PATTERN) redact(redactionPattern, kvs) } + /** + * Redact the sensitive information in the given string. + */ + def redact(conf: SparkConf, text: String): String = { + if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text + val regex = conf.get(STRING_REDACTION_PATTERN).get + regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { - kvs.map { kv => - redactionPattern.findFirstIn(kv._1) - .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } - .getOrElse(kv) + // If the sensitive information regex matches with either the key or the value, redact the value + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accounting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, user would have to make the spark.redaction.regex property + // more specific. 
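The redact body that follows matches the redaction pattern against both the key and the value, for exactly the reason spelled out in the comment above. A minimal sketch of that behavior (the sample regex "(?i)secret|password" and the property values are illustrative; the replacement string matches the constant shown earlier in this hunk):

  import scala.util.matching.Regex

  val replacementText = "*********(redacted)"

  // Redact the value whenever the pattern matches either the key or the value.
  def redactSketch(pattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] =
    kvs.map { case (key, value) =>
      pattern.findFirstIn(key)
        .orElse(pattern.findFirstIn(value))
        .map(_ => (key, replacementText))
        .getOrElse((key, value))
    }

  // redactSketch("(?i)secret|password".r, Seq(
  //   "spark.app.name" -> "Spark shell",
  //   "sun.java.command" -> "... HADOOP_CREDSTORE_PASSWORD=secret_password ..."))
  // keeps the first pair unchanged and redacts the second, because its value matches.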
+ kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala new file mode 100644 index 0000000000000..6e57c3c5bee8c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.PriorityQueue + +/** + * MedianHeap is designed to be used to quickly track the median of a group of numbers + * that may contain duplicates. Inserting a new number has O(log n) time complexity and + * determining the median has O(1) time complexity. + * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf + * stores the smaller half of all numbers while the largerHalf stores the larger half. + * The sizes of two heaps need to be balanced each time when a new number is inserted so + * that their sizes will not be different by more than 1. Therefore each time when + * findMedian() is called we check if two heaps have the same size. If they do, we should + * return the average of the two top values of heaps. Otherwise we return the top of the + * heap which has one more element. + */ +private[spark] class MedianHeap(implicit val ord: Ordering[Double]) { + + /** + * Stores all the numbers less than the current median in a smallerHalf, + * i.e median is the maximum, at the root. + */ + private[this] var smallerHalf = PriorityQueue.empty[Double](ord) + + /** + * Stores all the numbers greater than the current median in a largerHalf, + * i.e median is the minimum, at the root. + */ + private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse) + + def isEmpty(): Boolean = { + smallerHalf.isEmpty && largerHalf.isEmpty + } + + def size(): Int = { + smallerHalf.size + largerHalf.size + } + + def insert(x: Double): Unit = { + // If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf. + if (isEmpty) { + largerHalf.enqueue(x) + } else { + // If the number is larger than current median, it should be inserted into largerHalf, + // otherwise smallerHalf. 
+ if (x > median) { + largerHalf.enqueue(x) + } else { + smallerHalf.enqueue(x) + } + } + rebalance() + } + + private[this] def rebalance(): Unit = { + if (largerHalf.size - smallerHalf.size > 1) { + smallerHalf.enqueue(largerHalf.dequeue()) + } + if (smallerHalf.size - largerHalf.size > 1) { + largerHalf.enqueue(smallerHalf.dequeue) + } + } + + def median: Double = { + if (isEmpty) { + throw new NoSuchElementException("MedianHeap is empty.") + } + if (largerHalf.size == smallerHalf.size) { + (largerHalf.head + smallerHalf.head) / 2.0 + } else if (largerHalf.size > smallerHalf.size) { + largerHalf.head + } else { + smallerHalf.head + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7572cac39317c..2f905c8af0f63 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -86,7 +86,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Copy this buffer into a new ByteBuffer. + * Convert this buffer to a ByteBuffer. If this buffer is backed by a single chunk, its underlying + * data will not be copied. Instead, it will be duplicated. If this buffer is backed by multiple + * chunks, the data underlying this buffer will be copied into a new byte buffer. As a result, it + * is suggested to use this method only if the caller does not need to manage the memory + * underlying this buffer. * * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. */ @@ -132,10 +136,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. + * See [[StorageUtils.dispose]] for more information. 
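Stepping back to the MedianHeap added above, a short usage sketch (insertion is O(log n), median lookup O(1); the implicit Ordering[Double] comes from the standard library):

  val heap = new MedianHeap()
  Seq(3.0, 1.0, 4.0, 1.0, 5.0).foreach(heap.insert)
  // Sorted, the inserted values are 1, 1, 3, 4, 5:
  assert(heap.size() == 5)
  assert(heap.median == 3.0)

  // With an even count the two halves balance and the median is the mean of their tops:
  heap.insert(2.0)
  assert(heap.median == 2.5)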
*/ def dispose(): Unit = { if (!disposed) { @@ -143,6 +145,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { disposed = true } } + } /** diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json new file mode 100644 index 0000000000000..0f94e3b255dbc --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -0,0 +1,148 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.167:51487", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "1", + "hostPort" : "172.22.0.167:51490", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3152, + "totalGCTime" : 68, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "0", + "hostPort" : "172.22.0.167:51491", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2551, + "totalGCTime" : 116, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", + "stderr" : 
"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 5914a1c2c4b6d..0f94e3b255dbc 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -17,10 +17,16 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 } }, { "id" : "driver", @@ -41,8 +47,14 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, - "executorLogs" : { } + "maxMemory" : 908381388, + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -62,10 +74,16 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 } }, { "id" : "0", @@ -86,10 +104,16 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "memoryMetrics": { + 
"usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 } }, { "id" : "3", @@ -110,9 +134,15 @@ "totalShuffleRead" : 0, "totalShuffleWrite" : 0, "isBlacklisted" : true, - "maxMemory" : 384093388, + "maxMemory" : 908381388, "executorLogs" : { "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 } } ] diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000 index 7566c9fc0a20b..57cfc5b973129 100755 --- a/core/src/test/resources/spark-events/app-20161116163331-0000 +++ b/core/src/test/resources/spark-events/app-20161116163331-0000 @@ -1,15 +1,15 @@ {"Event":"SparkListenerLogStart","Spark Version":"2.1.0-SNAPSHOT"} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":384093388,"Timestamp":1479335611477} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":908381388,"Timestamp":1479335611477,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","Java Version":"1.8.0_92 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.task.maxTaskAttemptsPerExecutor":"3","spark.blacklist.enabled":"TRUE","spark.driver.host":"172.22.0.167","spark.blacklist.task.maxTaskAttemptsPerNode":"3","spark.eventLog.enabled":"TRUE","spark.driver.port":"51459","spark.repl.class.uri":"spark://172.22.0.167:51459/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"3","spark.scheduler.mode":"FIFO","spark.eventLog.overwrite":"TRUE","spark.blacklist.stage.maxFailedTasksPerExecutor":"3","spark.executor.id":"driver","spark.blacklist.application.maxFailedExecutorsPerNode":"2","spark.submit.deployMode":"client","spark.master":"local-cluster[4,4,1024]","spark.home":"/Users/Jose/IdeaProjects/spark","spark.eventLog.dir":"/Users/jose/logs","spark.sql.catalogImplementation":"in-memory","spark.eventLog.compress":"FALSE","spark.blacklist.application.maxFailedTasksPerExecutor":"1","spark.blacklist.timeout":"1000000","spark.app.id":"app-20161116163331-0000","spark.task.maxFailures":"4"},"System Properties":{"java.io.tmpdir":"/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle 
Corporation","java.vm.specification.version":"1.8","user.home":"/Users/Jose","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib","user.dir":"/Users/Jose/IdeaProjects/spark","java.library.path":"/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.92-b14","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_92-b14","java.vm.info":"mixed mode","java.ext.dirs":"/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","io.netty.maxDirectMemory":"0","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.11.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"jose","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i 
/Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","java.version":"1.8.0_92","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/conf/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/jars/*":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar":"System Classpath"}} {"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20161116163331-0000","Timestamp":1479335609916,"User":"jose"} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615320,"Executor ID":"3","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout","stderr":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":384093388,"Timestamp":1479335615387} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":908381388,"Timestamp":1479335615387,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} 
{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615393,"Executor ID":"2","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout","stderr":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"}}} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615443,"Executor ID":"1","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout","stderr":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":384093388,"Timestamp":1479335615448} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":908381388,"Timestamp":1479335615448,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerExecutorAdded","Timestamp":1479335615462,"Executor ID":"0","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout","stderr":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"}}} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":384093388,"Timestamp":1479335615496} -{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":384093388,"Timestamp":1479335615515} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":908381388,"Timestamp":1479335615496,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":908381388,"Timestamp":1479335615515,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} {"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1479335616467,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent 
IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0],"Properties":{}} {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},"Properties":{}} {"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479335616657,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6d03ee091e4ed..ddbcb2d19dcbb 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -243,7 +243,7 @@ private[spark] object AccumulatorSuite { import InternalAccumulator._ /** - * Create a long accumulator and register it to [[AccumulatorContext]]. + * Create a long accumulator and register it to `AccumulatorContext`. 
*/ def createLongAccum( name: String, @@ -258,7 +258,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46f..ee70a3399efed 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index fb8d701ebda8a..91355f7362900 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io.{FileDescriptor, InputStream} import java.lang import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,20 +30,29 @@ import org.apache.spark.internal.Logging object DebugFilesystem extends Logging { // Stores the set of active streams and their creation sites. 
- private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + private val openStreams = mutable.Map.empty[FSDataInputStream, Throwable] - def clearOpenStreams(): Unit = { + def addOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.put(stream, new Throwable()) + } + + def clearOpenStreams(): Unit = openStreams.synchronized { openStreams.clear() } - def assertNoOpenStreams(): Unit = { - val numOpen = openStreams.size() + def removeOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.remove(stream) + } + + def assertNoOpenStreams(): Unit = openStreams.synchronized { + val numOpen = openStreams.values.size if (numOpen > 0) { - for (exc <- openStreams.values().asScala) { + for (exc <- openStreams.values) { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } - throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", + openStreams.values.head) } } } @@ -59,8 +67,7 @@ class DebugFilesystem extends LocalFileSystem { override def open(f: Path, bufferSize: Int): FSDataInputStream = { val wrapped: FSDataInputStream = super.open(f, bufferSize) - openStreams.put(wrapped, new Throwable()) - + addOpenStream(wrapped) new FSDataInputStream(wrapped.getWrappedStream) { override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) @@ -97,7 +104,7 @@ class DebugFilesystem extends LocalFileSystem { override def close(): Unit = { wrapped.close() - openStreams.remove(wrapped) + removeOpenStream(wrapped) } override def read(): Int = wrapped.read() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 55913f549d0e5..2071f90cfeed5 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark -import org.scalatest.{Ignore, Matchers} +import org.scalatest.Matchers import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.{Millis, Span} +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -28,8 +29,8 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -@Ignore -class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext + with EncryptionFunSuite { val clusterUrl = "local-cluster[2,1,1024]" @@ -150,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } - private def testCaching(storageLevel: StorageLevel): Unit = { - sc = new SparkContext(clusterUrl, "test") + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { + sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) @@ -188,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> 
StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - test(testName) { - testCaching(storageLevel) + encryptionTest(testName) { conf => + testCaching(conf, storageLevel) } } @@ -232,7 +233,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("recover from node failures") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + import DistributedSuite.{failOnMarkedIdentity, markNodeIfIdentity} DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) @@ -242,7 +243,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("recover from repeated node failures during shuffle-map") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + import DistributedSuite.{failOnMarkedIdentity, markNodeIfIdentity} DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { @@ -254,7 +255,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("recover from repeated node failures during shuffle-reduce") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + import DistributedSuite.{failOnMarkedIdentity, markNodeIfIdentity} DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { @@ -273,7 +274,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("recover from node failures with replication") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + import DistributedSuite.{failOnMarkedIdentity, markNodeIfIdentity} DistributedSuite.amMaster = true // Using more than two nodes so we don't have a symmetric communication pattern and might // cache a partially correct list of peers. diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index eb3fb99747d12..fe944031bc948 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh /** * This suite creates an external shuffle server and routes all shuffle fetches through it. * Note that failures in this suite may arise due to changes in Spark that invalidate expectations - * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders. */ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 24ec99c7e5e60..1dd89bcbe36bc 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +/** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. 
*/ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index e626ed3621d60..58b865969f517 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers @@ -239,7 +239,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.toLowerCase.contains("serializable")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("serializable")) } test("shuffle with different compression settings (SPARK-3426)") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index f97a112ec1276..7e26139a2bead 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark import java.io.File -import java.net.MalformedURLException +import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -26,13 +26,15 @@ import scala.concurrent.duration._ import scala.concurrent.Await import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.Utils @@ -538,10 +540,83 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has + // started, kill and re-schedule it. The second run of the task will complete immediately. + // If this test times out, then the first version of the task wasn't killed successfully. 
+ def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + SparkContextSuite.isTaskStarted = false + SparkContextSuite.taskKilled = false + SparkContextSuite.taskSucceeded = false + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + if (!SparkContextSuite.taskKilled) { + SparkContextSuite.taskKilled = true + sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang") + } + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) { + SparkContextSuite.taskSucceeded = true + } + } + } + sc.addSparkListener(listener) + eventually(timeout(20.seconds)) { + sc.parallelize(1 to 1).foreach { x => + // first attempt will hang + if (!SparkContextSuite.isTaskStarted) { + SparkContextSuite.isTaskStarted = true + blockFn + } + // second attempt succeeds immediately + } + } + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.taskSucceeded) + } + } + + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + + "open streams to help debugging") { + val fs = new DebugFilesystem() + fs.initialize(new URI("file:///"), new Configuration()) + val file = File.createTempFile("SPARK19446", "temp") + Files.write(Array.ofDim[Byte](1000), file) + val path = new Path("file:///" + file.getCanonicalPath) + val stream = fs.open(path) + val exc = intercept[RuntimeException] { + DebugFilesystem.assertNoOpenStreams() + } + assert(exc != null) + assert(exc.getCause() != null) + stream.close() + } } object SparkContextSuite { @volatile var cancelJob = false @volatile var cancelStage = false @volatile var isTaskStarted = false + @volatile var taskKilled = false + @volatile var taskSucceeded = false } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6646068d5080b..46f9ac6b0273a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import java.util.Locale + import scala.util.Random import org.scalatest.Assertions @@ -24,8 +26,10 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.io.ChunkedByteBuffer // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,7 +47,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends SparkFunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") @@ -61,9 +65,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } - test("Accessing TorrentBroadcast variables in a local cluster") { + encryptionTest("Accessing TorrentBroadcast variables in a 
local cluster") { conf => val numSlaves = 4 - val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -85,7 +88,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b => + new ChunkedByteBuffer(b).toInputStream(dispose = true) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } @@ -127,7 +132,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val thrown = intercept[IllegalStateException] { sc.broadcast(Seq(1, 2, 3)) } - assert(thrown.getMessage.toLowerCase.contains("stopped")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("stopped")) } test("Forbid broadcasting RDD directly") { @@ -137,9 +142,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } - test("Cache broadcast to disk") { - val conf = new SparkConf() - .setMaster("local") + encryptionTest("Cache broadcast to disk") { conf => + conf.setMaster("local") .setAppName("test") .set("spark.memory.useLegacyMode", "true") .set("spark.storage.memoryFraction", "0.0") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..ab24a76e20a30 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9417930d02405..a43839a8815f9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,8 +21,10 @@ import java.io._ import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.Path import 
org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,6 +36,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -148,6 +151,17 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("print the right queue name") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "--conf", "spark.yarn.queue=thequeue", + "userjar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.queue should be ("thequeue") + appArgs.toString should include ("thequeue") + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", @@ -213,7 +227,12 @@ class SparkSubmitSuite childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include regex ("--jar .*thejar.jar") mainClass should be ("org.apache.spark.deploy.yarn.Client") - classpath should have length (0) + + // In yarn cluster mode, also adding jars to classpath + classpath(0) should endWith ("thejar.jar") + classpath(1) should endWith ("one.jar") + classpath(2) should endWith ("two.jar") + classpath(3) should endWith ("three.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.driver.memory") should be ("4g") @@ -388,6 +407,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 9839dcf8535db..bf7480d79f8a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice 
(SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ec580a44b8e76..456158d41b93f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. 
assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index dcf83cb530a91..95acb9a54440f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -153,7 +153,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", - "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors" + "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", + "executor memory usage" -> "applications/app-20161116163331-0000/executors" // Todo: enable this test when logging the even of onBlockUpdated. 
See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) @@ -564,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -589,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 8150fff2d018d..efcad140350b9 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -110,14 +111,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. - executor.killAllTasks(true) + executor.killAllTasks(true, "test") executorSuiteHelper.latch2.countDown() if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { fail("executor did not send second status update in time") } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` - assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.testFailedReason === TaskKilled("test")) assert(executorSuiteHelper.taskState === TaskState.KILLED) } finally { @@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. 
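
Several of the hunks above converge on the same new behaviour: a task attempt can be killed with a human-readable reason (sc.killTaskAttempt(taskId, interruptThread, reason), Executor.killAllTasks(interrupt, reason)), and that reason comes back as TaskKilled(reason) when the attempt ends. The following is only a rough standalone sketch of the pattern that the SparkContextSuite test exercises; the object name, app name, and reason string are invented for illustration, and sharing @volatile flags between driver and task only works because the master is local, exactly as in that test.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}

// Illustrative sketch: kill the hanging first attempt of a task with a reason,
// then let the re-scheduled second attempt finish immediately.
object KillTaskAttemptSketch {
  @volatile private var firstAttemptStarted = false  // set by the task body (same JVM, local master)
  @volatile private var killed = false                // make sure we only kill the first attempt

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("kill-task-sketch").setMaster("local"))
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
        if (!killed) {
          killed = true
          // Wait until the task body is actually running before killing it.
          while (!firstAttemptStarted) Thread.sleep(10)
          sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt hangs; retrying")
        }
      }
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        // The reason string supplied above is reported here as TaskKilled(reason).
        println(s"attempt ${taskEnd.taskInfo.attemptNumber} ended: ${taskEnd.reason}")
      }
    })
    sc.parallelize(1 to 1).foreach { _ =>
      if (!firstAttemptStarted) {   // first attempt blocks until it is killed
        firstAttemptStarted = true
        Thread.sleep(9999999)
      }                             // the re-scheduled attempt returns immediately
    }
    sc.stop()
  }
}

The one-shot killed flag mirrors the guard in the suite above: without it the listener would also kill the re-scheduled attempt and the job would never complete.
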
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 71eed464880b5..b72cd8be24206 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.internal.config +import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.SparkConfWithEnv @@ -98,6 +96,21 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(bytes) === 1L) } + test("conf entry: regex") { + val conf = new SparkConf() + val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) + + conf.set(rConf, "[0-9a-f]{8}".r) + assert(conf.get(rConf).toString === "[0-9a-f]{8}") + + conf.set(rConf.key, "[0-9a-f]{4}") + assert(conf.get(rConf).toString === "[0-9a-f]{4}") + + conf.set(rConf.key, "[.") + val e = intercept[IllegalArgumentException](conf.get(rConf)) + assert(e.getMessage.contains("regex should be a regex, but was")) + } + test("conf entry: string seq") { val conf = new SparkConf() val seq = ConfigBuilder(testKey("seq")).stringConf.toSequence.createWithDefault(Seq()) @@ -120,7 +133,7 @@ class ConfigEntrySuite extends SparkFunSuite { val conf = new SparkConf() val transformationConf = ConfigBuilder(testKey("transformation")) .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .createWithDefault("FOO") assert(conf.get(transformationConf) === "foo") @@ -240,4 +253,12 @@ class ConfigEntrySuite extends SparkFunSuite { testEntryRef(nullConf, ref(nullConf)) } + test("conf entry : default function") { + var data = 0 + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("intval")).intConf.createWithDefaultFunction(() => data) + assert(conf.get(iConf) === 0) + data = 2 + assert(conf.get(iConf) === 2) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a2..7f20206202cb9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8eaf9dfcf49b1..a10941b579fe2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -110,8 +110,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val cancelledStages = new HashSet[Int]() val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def 
schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start() = {} override def stop() = {} override def executorHeartbeatReceived( @@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { cancelledStages += stageId } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -542,8 +544,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // make sure that the DAGScheduler doesn't crash when the TaskScheduler // doesn't implement killTask() val noKillTaskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = { @@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( @@ -801,7 +807,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) @@ -813,7 +819,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // map output, for the next iteration through the loop scheduler.resubmitFailedStages() - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + if (attempt < scheduler.maxConsecutiveStageAttempts - 1) { assert(scheduler.runningStages.nonEmpty) assert(!ended) } else { @@ -847,11 +853,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, // stage 2 fails. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + if (attempt < scheduler.maxConsecutiveStageAttempts / 2) { // Now we should have a new taskSet, for a new attempt of stage 1. 
// Fail all these tasks with FetchFailure completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) @@ -859,8 +865,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) // Fail stage 2 - completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, - shuffleDepTwo) + completeNextStageWithFetchFailure(2, + attempt - scheduler.maxConsecutiveStageAttempts / 2, shuffleDepTwo) } // this will trigger a resubmission of stage 0, since we've lost some of its @@ -872,7 +878,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) // Succeed stage2 with a "42" - completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + completeNextResultStageWithSuccess(2, scheduler.maxConsecutiveStageAttempts / 2) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -895,7 +901,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(finalRdd, Array(0)) // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts - 1) { // Make each task in stage 0 success completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index e87cebf0cf358..ba56af8215cd7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -73,12 +73,14 @@ private class DummySchedulerBackend extends SchedulerBackend { private class DummyTaskScheduler extends TaskScheduler { var initialized = false - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 83ed12752074d..e51e6a0d3ff6b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -31,6 +31,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.internal.io.SparkHadoopWriter import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -175,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { 
assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer assert( outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 520736ab64270..4901062a78553 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -31,6 +31,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { val LOCAL = "local" val APP_NAME = "PoolSuite" val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file" + val TEST_POOL = "testPool" def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { @@ -40,7 +41,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0) } - def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int) { + def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int): Unit = { val taskSetQueue = rootPool.getSortedTaskSetQueue val nextTaskSetToSchedule = taskSetQueue.find(t => (t.runningTasks + t.tasksSuccessful) < t.numTasks) @@ -201,12 +202,102 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { verifyPool(rootPool, "pool_with_surrounded_whitespace", 3, 2, FAIR) } + /** + * spark.scheduler.pool property should be ignored for the FIFO scheduler, + * because pools are only needed for fair scheduling. + */ + test("FIFO scheduler uses root pool and not spark.scheduler.pool property") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FIFO, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty("spark.scheduler.pool", TEST_POOL) + + // When FIFO Scheduler is used and task sets are submitted, they should be added to + // the root pool, and no additional pools should be created + // (even though there's a configured default pool). 
+ schedulableBuilder.addTaskSetManager(taskSetManager0, properties) + schedulableBuilder.addTaskSetManager(taskSetManager1, properties) + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + assert(rootPool.schedulableQueue.size === 2) + assert(rootPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + assert(rootPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler uses default pool when spark.scheduler.pool property is not set") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + // Submit a new task set manager with pool properties set to null. This should result + // in the task set manager getting added to the default pool. + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + + val defaultPool = rootPool.getSchedulableByName(schedulableBuilder.DEFAULT_POOL_NAME) + assert(defaultPool !== null) + assert(defaultPool.schedulableQueue.size === 1) + assert(defaultPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + + // When a task set manager is submitted with spark.scheduler.pool unset, it should be added to + // the default pool (as above). + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager1, new Properties()) + + assert(defaultPool.schedulableQueue.size === 2) + assert(defaultPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler creates a new pool when spark.scheduler.pool property points to " + + "a non-existent pool") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + + val taskSetManager = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, TEST_POOL) + + // The fair scheduler should create a new pool with default values when spark.scheduler.pool + // points to a pool that doesn't exist yet (this can happen when the file that pools are read + // from isn't set, or when that file doesn't contain the pool name specified + // by spark.scheduler.pool). 
+ schedulableBuilder.addTaskSetManager(taskSetManager, properties) + + verifyPool(rootPool, TEST_POOL, schedulableBuilder.DEFAULT_MINIMUM_SHARE, + schedulableBuilder.DEFAULT_WEIGHT, schedulableBuilder.DEFAULT_SCHEDULING_MODE) + val testPool = rootPool.getSchedulableByName(TEST_POOL) + assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) + } + + test("Pool should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + new Pool("TestPool", SchedulingMode.NONE, 0, 1) + } + } + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { - assert(rootPool.getSchedulableByName(poolName) != null) - assert(rootPool.getSchedulableByName(poolName).minShare === expectedInitMinShare) - assert(rootPool.getSchedulableByName(poolName).weight === expectedInitWeight) - assert(rootPool.getSchedulableByName(poolName).schedulingMode === expectedSchedulingMode) + val selectedPool = rootPool.getSchedulableByName(poolName) + assert(selectedPool !== null) + assert(selectedPool.minShare === expectedInitMinShare) + assert(selectedPool.weight === expectedInitWeight) + assert(selectedPool.schedulingMode === expectedSchedulingMode) } - } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 398ac3d6202db..8300607ea888b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -95,12 +95,12 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } /** - * A map from partition -> results for all tasks of a job when you call this test framework's + * A map from partition to results for all tasks of a job when you call this test framework's * [[submit]] method. Two important considerations: * * 1. If there is a job failure, results may or may not be empty. If any tasks succeed before * the job has failed, they will get included in `results`. Instead, check for job failure by - * checking [[failure]]. (Also see [[assertDataStructuresEmpty()]]) + * checking [[failure]]. (Also see `assertDataStructuresEmpty()`) * * 2. This only gets cleared between tests. So you'll need to do special handling if you submit * more than one job in one test. @@ -410,7 +410,8 @@ private[spark] abstract class MockBackend( } } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { // We have to implement this b/c of SPARK-15385. // Its OK for this to be a no-op, because even if a backend does implement killTask, // it really can only be "best-effort" in any case, and the scheduler should be robust to that. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 7004128308af9..b22da565d86e7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. 
We won't end up running this; we just want to collect // accumulator updates from it. - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), @@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(res === Array("testPropValue,testPropValue")) } + test("immediately call a completion listener if the context is completed") { + var invocations = 0 + val context = TaskContext.empty() + context.markTaskCompleted() + context.addTaskCompletionListener(_ => invocations += 1) + assert(invocations == 1) + context.markTaskCompleted() + assert(invocations == 1) + } + + test("immediately call a failure listener if the context has failed") { + var invocations = 0 + var lastError: Throwable = null + val error = new RuntimeException + val context = TaskContext.empty() + context.markTaskFailed(error) + context.addTaskFailureListener { (_, e) => + lastError = e + invocations += 1 + } + assert(lastError == error) + assert(invocations == 1) + context.markTaskFailed(error) + assert(lastError == error) + assert(invocations == 1) + } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 9f1fe0515732e..97487ce1d2ca8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException} import java.nio.ByteBuffer import java.util.Properties @@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite { val originalProperties = new Properties() originalProperties.put("property1", "18") originalProperties.put("property2", "test value") + // SPARK-19796 -- large property values (like a large job description for a long sql query) + // can cause problems for DataOutputStream, make sure we handle correctly + val sb = new StringBuilder() + (0 to 10000).foreach(_ => sb.append("1234567890")) + val largeString = sb.toString() + originalProperties.put("property3", largeString) + // make sure we've got a good test case + intercept[UTFDataFormatException] { + val out = new DataOutputStream(new ByteArrayOutputStream()) + try { + out.writeUTF(largeString) + } finally { + out.close() + } + } // Create a dummy byte buffer for the task. 
val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ae0bcd9b8860..8b9d45f734cda 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -75,9 +75,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") - confs.foreach { case (k, v) => - conf.set(k, v) - } + confs.foreach { case (k, v) => conf.set(k, v) } sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) setupHelper() @@ -904,4 +902,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskDescs.size === 1) assert(taskDescs.head.executorId === "exec2") } + + test("TaskScheduler should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + val taskScheduler = setupScheduler( + TaskSchedulerImpl.SCHEDULER_MODE_PROPERTY -> SchedulingMode.NONE.toString) + taskScheduler.initialize(new FakeSchedulerBackend) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2c2cda9f318eb..db14c9acfdce5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,8 +22,10 @@ import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Matchers.{anyInt, anyString} +import org.mockito.Matchers.{any, anyInt, anyString} import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.internal.config @@ -192,6 +194,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) + clock.advance(1) // Tell it the task has finished manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) @@ -377,6 +380,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock + clock.advance(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -394,6 +398,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock + clock.advance(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted @@ -427,6 +432,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // affinity to exec1 on host1 - which we will fail. 
val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) val clock = new ManualClock + clock.advance(1) // We don't directly use the application blacklist, but its presence triggers blacklisting // within the taskset. val mockListenerBus = mock(classOf[LiveListenerBus]) @@ -551,7 +557,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val manager = new TaskSetManager(sched, taskSet, 1, clock = new ManualClock) + val clock = new ManualClock() + clock.advance(1) + val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) sched.addExecutor("execA", "host1") manager.executorAdded() sched.addExecutor("execC", "host2") @@ -671,7 +679,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} }) // Keep track of the number of tasks that are resubmitted, @@ -887,6 +899,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => @@ -904,6 +917,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(task.executorId === k) } assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) // Complete the 3 tasks and leave 1 task in running for (id <- Set(0, 1, 2)) { manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) @@ -927,7 +941,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt - verify(sched.backend).killTask(3, "exec2", true) + verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. 
@@ -941,6 +955,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => @@ -961,6 +976,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg tasks += task } assert(sched.startedTasks.toSet === (0 until 5).toSet) + clock.advance(1) // Complete 3 tasks and leave 2 tasks in running for (id <- Set(0, 1, 2)) { manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) @@ -1013,14 +1029,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask = originalTasks(speculativeTask.index) - verify(sched.backend).killTask(origTask.taskId, "exec2", true) + verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. assert(sched.endedTasks(3) === Success) // also because the scheduler is a mock, our manager isn't notified about the task killed event, // so we do that manually - manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test")) // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage assert(manager.tasksSuccessful === 4) assert(!manager.isZombie) @@ -1037,11 +1053,35 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask2 = originalTasks(speculativeTask2.index) - verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded") assert(manager.tasksSuccessful === 5) assert(manager.isZombie) } + + test("SPARK-19868: DagScheduler only notified of taskEnd when state is ready") { + // dagScheduler.taskEnded() is async, so it may *seem* ok to call it before we've set all + // appropriate state, eg. isZombie. However, this sets up a race that could go the wrong way. + // This is a super-focused regression test which checks the zombie state as soon as + // dagScheduler.taskEnded() is called, to ensure we haven't introduced a race. 
+ sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler = mockDAGScheduler + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption.isDefined) + // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon + manager.handleSuccessfulTask(0, createTaskResult(0)) + } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -1092,8 +1132,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, TaskCommitDenied(0, 2, 0)) - tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, - TaskKilled) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) // Make sure that the blacklist ignored all of the task failures above, since they aren't // the fault of the executor where the task was running. diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 0f3a4a03618ed..608052f5ed855 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.security -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.channels.Channels import java.nio.charset.StandardCharsets.UTF_8 -import java.util.UUID +import java.nio.file.Files +import java.util.{Arrays, Random, UUID} import com.google.common.io.ByteStreams @@ -121,6 +123,46 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { } } + test("crypto stream wrappers") { + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = createConf() + val key = createKey(conf) + val file = Files.createTempFile("crypto", ".test").toFile() + + val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) + try { + ByteStreams.copy(new ByteArrayInputStream(testData), outStream) + } finally { + outStream.close() + } + + val inStream = createCryptoInputStream(new FileInputStream(file), conf, key) + try { + val inStreamData = ByteStreams.toByteArray(inStream) + assert(Arrays.equals(inStreamData, testData)) + } finally { + inStream.close() + } + + val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key) + try { + val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData)) + ByteStreams.copy(inByteChannel, outChannel) + } finally { + outChannel.close() + } + + val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key) + try { + val inChannelData = 
ByteStreams.toByteArray(Channels.newInputStream(inChannel)) + assert(Arrays.equals(inChannelData, testData)) + } finally { + inChannel.close() + } + } + private def createConf(extra: (String, String)*): SparkConf = { val conf = new SparkConf() extra.foreach { case (k, v) => conf.set(k, v) } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala new file mode 100644 index 0000000000000..3f52dc41abf6d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +trait EncryptionFunSuite { + + this: SparkFunSuite => + + /** + * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok + * for the test to modify the provided SparkConf. + */ + final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + Seq(false, true).foreach { encrypt => + test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(conf) + } + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a30653bb36fa1..7c3922e47fbb9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -76,6 +76,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("basic types") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -106,6 +109,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("pairs") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -130,12 +136,16 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("Scala data structures") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) check(List[Int](1, 2, 3)) + check(Seq[Int](1, 2, 3)) check(List[String]()) check(List[String]("x", 
"y", "z")) check(None) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index 4ce3b941bea55..99882bf76e29d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** * Tests to ensure that [[Serializer]] implementations obey the API contracts for methods that * describe properties of the serialized stream, such as - * [[Serializer.supportsRelocationOfSerializedObjects]]. + * `Serializer.supportsRelocationOfSerializedObjects`. */ class SerializerPropertiesSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 70ae9be118fc0..c100803279eaf 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.Locale + import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions @@ -28,6 +30,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.Logging import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -36,6 +39,7 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils trait BlockManagerReplicationBehavior extends SparkFunSuite with Matchers @@ -43,6 +47,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite with LocalSparkContext { val conf: SparkConf + protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) @@ -55,7 +60,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
@@ -372,9 +376,10 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite storageLevels.foreach { storageLevel => // Put the block into one of the stores - val blockId = new TestBlockId( - "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) - stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val blockId = TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase(Locale.ROOT)) + val testValue = Array.fill[Byte](blockSize)(1) + stores(0).putSingle(blockId, testValue, storageLevel) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -386,12 +391,23 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite testStore => blockLocations.contains(testStore.blockManagerId.executorId) }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId - assert( - testStore.getLocalValues(blockId).isDefined, s"$blockId was not found in $testStoreName") - testStore.releaseLock(blockId) + val blockResultOpt = testStore.getLocalValues(blockId) + assert(blockResultOpt.isDefined, s"$blockId was not found in $testStoreName") + val localValues = blockResultOpt.get.data.toSeq + assert(localValues.size == 1) + assert(localValues.head === testValue) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") + val memoryStore = testStore.memoryStore + if (memoryStore.contains(blockId) && !storageLevel.deserialized) { + memoryStore.getBytes(blockId).get.chunks.foreach { byteBuffer => + assert(storageLevel.useOffHeap == byteBuffer.isDirect, + s"memory mode ${storageLevel.memoryMode} is not compatible with " + + byteBuffer.getClass.getSimpleName) + } + } + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) // Assert that block status in the master for this store has expected storage level @@ -459,9 +475,8 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav conf.set("spark.storage.replication.proactive", "true") conf.set("spark.storage.exceptionOnPinLeak", "true") - (2 to 5).foreach{ i => - // flakes in palantir/spark - ignore(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { + (2 to 5).foreach { i => + test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { testProactiveReplication(i) } } @@ -482,27 +497,61 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav assert(blockLocations.size === replicationFactor) // remove a random blockManager - val executorsToRemove = blockLocations.take(replicationFactor - 1) + val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet logInfo(s"Removing $executorsToRemove") - executorsToRemove.foreach{exec => - master.removeExecutor(exec.executorId) + initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm => + master.removeExecutor(bm.blockManagerId.executorId) + bm.stop() // giving enough time for replication to happen and new block be reported to master - Thread.sleep(200) + eventually(timeout(5 seconds), interval(100 millis)) { + val newLocations = master.getLocations(blockId).toSet + assert(newLocations.size === replicationFactor) + } } - // giving enough time for replication complete and locks released - Thread.sleep(500) - - val newLocations = master.getLocations(blockId).toSet + val newLocations = 
eventually(timeout(5 seconds), interval(100 millis)) { + val _newLocations = master.getLocations(blockId).toSet + assert(_newLocations.size === replicationFactor) + _newLocations + } logInfo(s"New locations : $newLocations") - assert(newLocations.size === replicationFactor) - // there should only be one common block manager between initial and new locations - assert(newLocations.intersect(blockLocations.toSet).size === 1) - - // check if all the read locks have been released - initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => - val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) - assert(locks.size === 0, "Read locks unreleased!") + + // new locations should not contain stopped block managers + assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)), + "New locations contain stopped block managers.") + + // Make sure all locks have been released. + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0) + } } } } + +class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + // number of racks to test with + val numRacks = 3 + + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return random topology + */ + override def getTopologyForHost(hostname: String): Option[String] = { + Some(s"/Rack-${Utils.random.nextInt(numRacks)}") + } +} + +class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { + val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set( + "spark.storage.replication.policy", + classOf[BasicBlockReplicationPolicy].getName) + conf.set( + "spark.storage.replication.topologyMapper", + classOf[DummyTopologyMapper].getName) +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 64a67b4c4cbab..a8b9604899838 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -49,7 +51,8 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { + with 
PrivateMethodTester with LocalSparkContext with ResetSystemProperties + with EncryptionFunSuite { import BlockManagerSuite._ @@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, - transferService: Option[BlockTransferService] = Option.empty): BlockManager = { - conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) - val serializer = new KryoSerializer(conf) + transferService: Option[BlockTransferService] = Option.empty, + testConf: Option[SparkConf] = None): BlockManager = { + val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) + bmConf.set("spark.testing.memory", maxMem.toString) + bmConf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(bmConf) + val encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(bmConf)) + } else { + None + } + val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) - val memManager = UnifiedMemoryManager(conf, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) - val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = UnifiedMemoryManager(bmConf, numCores = 1) + val serializerManager = new SerializerManager(serializer, bmConf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf, + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManager @@ -610,8 +621,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("on-disk storage") { - store = makeBlockManager(1200) + encryptionTest("on-disk storage") { _conf => + store = makeBlockManager(1200, testConf = Some(_conf)) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -623,34 +634,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } - test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) + encryptionTest("disk and memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = conf) } - test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) + encryptionTest("disk and memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = conf) } - test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) + encryptionTest("disk and memory storage with serialization") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = conf) } - test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + 
encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = conf) } - test("disk and off-heap memory storage") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + encryptionTest("disk and off-heap memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = conf) } - test("disk and off-heap memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = conf) } def testDiskAndMemoryStorage( storageLevel: StorageLevel, - getAsBytes: Boolean): Unit = { - store = makeBlockManager(12000) + getAsBytes: Boolean, + testConf: SparkConf): Unit = { + store = makeBlockManager(12000, testConf = Some(testConf)) val accessMethod = if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) @@ -678,8 +690,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } - test("LRU with mixed storage levels") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) @@ -700,8 +712,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } - test("in-memory LRU with streams") { - store = makeBlockManager(12000) + encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +740,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1337,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 800c3899f1a72..4000218e71a8b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,34 +18,35 @@ package 
org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions +import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -class BlockReplicationPolicySuite extends SparkFunSuite +class RandomBlockReplicationPolicyBehavior extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { // Implicitly convert strings to BlockIds for test clarity. - private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + val replicationPolicy: BlockReplicationPolicy = new RandomBlockReplicationPolicy + + val blockId = "test-block" /** * Test if we get the required number of peers when using random sampling from - * RandomBlockReplicationPolicy + * BlockReplicationPolicy */ - test(s"block replication - random block replication policy") { + test("block replication - random block replication policy") { val numBlockManagers = 10 val storeSize = 1000 - val blockManagers = (1 to numBlockManagers).map { i => - BlockManagerId(s"store-$i", "localhost", 1000 + i, None) - } + val blockManagers = generateBlockManagerIds(numBlockManagers, Seq("/Rack-1")) val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) - val replicationPolicy = new RandomBlockReplicationPolicy - val blockId = "test-block" - (1 to 10).foreach {numReplicas => + (1 to 10).foreach { numReplicas => logDebug(s"Num replicas : $numReplicas") val randomPeers = replicationPolicy.prioritize( candidateBlockManager, @@ -68,7 +69,69 @@ class BlockReplicationPolicySuite extends SparkFunSuite logDebug(s"Random peers : ${secondPass.mkString(", ")}") assert(secondPass.toSet.size === numReplicas) } + } + + /** + * Returns a sequence of [[BlockManagerId]], whose rack is randomly picked from the given `racks`. + * Note that, each rack will be picked at least once from `racks`, if `count` is greater or equal + * to the number of `racks`. 
+ */ + protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { + val randomizedRacks: Seq[String] = Random.shuffle( + racks ++ racks.length.until(count).map(_ => racks(Random.nextInt(racks.length))) + ) + (0 until count).map { i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(randomizedRacks(i))) + } } +} +class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplicationPolicyBehavior { + override val replicationPolicy = new BasicBlockReplicationPolicy + + test("All peers in the same rack") { + val racks = Seq("/default-rack") + val numBlockManager = 10 + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(numBlockManager, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 10001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + assert(prioritizedPeers.forall(p => p.host != blockManager.host)) + } + } + + test("Peers in 2 racks") { + val racks = Seq("/Rack-1", "/Rack-2") + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(10, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 9001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + val priorityPeers = prioritizedPeers.take(2) + assert(priorityPeers.forall(p => p.host != blockManager.host)) + if(numReplicas > 1) { + // both these conditions should be satisfied when numReplicas > 1 + assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) + assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4d..67fc084e8a13d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause error when we tried to re-open the filestore and the // memory-mapped byte buffer tot he file has not been GC on Windows. 
assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) + assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" -> readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + ).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } 
+ + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd2..f7b3a2754f0ea 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,9 +33,13 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) @@ -43,7 +47,7 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) @@ -51,4 +55,17 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { assert(new File(Utils.getLocalDir(conf)).exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. 
+ assert(message.contains(s"$path1,$path2")) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index e5733aebf607c..da198f946fd64 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -27,7 +27,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, and remove (for non-RDD blocks) private def storageStatus1: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.blocks.isEmpty) assert(status.rddBlocks.isEmpty) assert(status.memUsed === 0L) @@ -74,7 +74,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.rddBlocks.isEmpty) status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) @@ -252,9 +252,9 @@ class StorageSuite extends SparkFunSuite { // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations private def stockStorageStatuses: Seq[StorageStatus] = { - val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) - val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) - val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L, Some(2000L), Some(0L)) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L, Some(3000L), Some(0L)) status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) @@ -332,4 +332,81 @@ class StorageSuite extends SparkFunSuite { assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) } + private val offheap = StorageLevel.OFF_HEAP + // For testing add, update, remove, get, and contains etc. 
for both RDD and non-RDD onheap + // and offheap blocks + private def storageStatus3: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, Some(1000L), Some(1000L)) + assert(status.rddBlocks.isEmpty) + status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(TestBlockId("man"), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(0, 0), BlockStatus(offheap, 10L, 0L)) + status.addBlock(RDDBlockId(1, 1), BlockStatus(offheap, 100L, 0L)) + status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L)) + status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L)) + status + } + + test("storage memUsed, diskUsed with on-heap and off-heap blocks") { + val status = storageStatus3 + def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum + def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum + + def actualOnHeapMemUsed: Long = + status.blocks.values.filter(!_.storageLevel.useOffHeap).map(_.memSize).sum + def actualOffHeapMemUsed: Long = + status.blocks.values.filter(_.storageLevel.useOffHeap).map(_.memSize).sum + + assert(status.maxMem === status.maxOnHeapMem.get + status.maxOffHeapMem.get) + + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + assert(status.memRemaining === status.maxMem - actualMemUsed) + assert(status.onHeapMemRemaining.get === status.maxOnHeapMem.get - actualOnHeapMemUsed) + assert(status.offHeapMemRemaining.get === status.maxOffHeapMem.get - actualOffHeapMemUsed) + + status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L)) + status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + + status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) + status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) + status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) + assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) + + status.removeBlock(TestBlockId("fire")) + status.removeBlock(TestBlockId("man")) + status.removeBlock(RDDBlockId(2, 2)) + status.removeBlock(RDDBlockId(2, 3)) + assert(status.memUsed === actualMemUsed) + assert(status.diskUsed === actualDiskUsed) + } + + private def storageStatus4: StorageStatus = { + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, None, None) + status + } + test("old SparkListenerBlockManagerAdded event compatible") { + // This scenario will only be happened when replaying old event log. In this scenario there's + // no block add or remove event replayed, so only total amount of memory is valid. 
+ val status = storageStatus4 + assert(status.maxMem === status.maxMemory) + + assert(status.memUsed === 0L) + assert(status.diskUsed === 0L) + assert(status.onHeapMemUsed === None) + assert(status.offHeapMemUsed === None) + + assert(status.memRemaining === status.maxMem) + assert(status.onHeapMemRemaining === None) + assert(status.offHeapMemRemaining === None) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 11482d187aeca..499d47b13d702 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -37,14 +38,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { test("peak execution memory should displayed") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase + val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } @@ -77,7 +78,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markFinished(TaskState.FINISHED) + taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis()) val taskMetrics = TaskMetrics.empty taskMetrics.incPeakExecutionMemory(peakExecutionMemory) jobListener.onTaskEnd( diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 4228373036425..bdd148875e38a 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} +import java.util.Locale import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -39,7 +40,7 @@ import org.apache.spark.LocalSparkContext._ import org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} +import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -103,6 +104,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) + .set("spark.memory.offHeap.size", "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -151,6 +153,39 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedRddJson = getJson(ui, 
"storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) + + val dataDistributions0 = + (updatedRddJson \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions0.length should be (1) + val dist0 = dataDistributions0.head + + dist0.onHeapMemoryUsed should not be (None) + dist0.memoryUsed should be (dist0.onHeapMemoryUsed.get) + dist0.onHeapMemoryRemaining should not be (None) + dist0.offHeapMemoryRemaining should not be (None) + dist0.memoryRemaining should be ( + dist0.onHeapMemoryRemaining.get + dist0.offHeapMemoryRemaining.get) + dist0.onHeapMemoryUsed should not be (Some(0L)) + dist0.offHeapMemoryUsed should be (Some(0L)) + + rdd.unpersist() + rdd.persist(StorageLevels.OFF_HEAP).count() + val updatedStorageJson1 = getJson(ui, "storage/rdd") + updatedStorageJson1.children.length should be (1) + val updatedRddJson1 = getJson(ui, "storage/rdd/0") + val dataDistributions1 = + (updatedRddJson1 \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions1.length should be (1) + val dist1 = dataDistributions1.head + + dist1.offHeapMemoryUsed should not be (None) + dist1.memoryUsed should be (dist1.offHeapMemoryUsed.get) + dist1.onHeapMemoryRemaining should not be (None) + dist1.offHeapMemoryRemaining should not be (None) + dist1.memoryRemaining should be ( + dist1.onHeapMemoryRemaining.get + dist1.offHeapMemoryRemaining.get) + dist1.onHeapMemoryUsed should be (Some(0L)) + dist1.offHeapMemoryUsed should not be (Some(0L)) } } @@ -419,8 +454,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(10 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") findAll(cssSelector("tbody tr a")).foreach { link => - link.text.toLowerCase should include ("count") - link.text.toLowerCase should not include "unknown" + link.text.toLowerCase(Locale.ROOT) should include ("count") + link.text.toLowerCase(Locale.ROOT) should not include "unknown" } } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index f1be0f6de3ce2..0c3d4caeeabf9 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} import java.net.{URI, URL} +import java.util.Locale import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -72,10 +73,10 @@ class UISuite extends SparkFunSuite { eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("storage")) + assert(html.toLowerCase(Locale.ROOT).contains("environment")) + assert(html.toLowerCase(Locale.ROOT).contains("executors")) } } } @@ -85,7 +86,7 @@ class UISuite extends SparkFunSuite { // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString - assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) } 
} } diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 6335d905c0fbf..c770fd5da76f7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div") val expected = Seq(
    ,
    diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index e3127da9a6b24..48be3be81755a 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -274,8 +274,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with // Make sure killed tasks are accounted for correctly. listener.onTaskEnd( - SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics)) - assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1) + SparkListenerTaskEnd( + task.stageId, 0, taskType, TaskKilled("test"), taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 0)).reasonToNumKilled === Map("test" -> 1)) // Make sure we count success as success. listener.onTaskEnd( @@ -292,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics val inputMetrics = taskMetrics.inputMetrics diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 9f76c74bce89e..a77c8e3cab4e8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -164,7 +164,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(fetchMetadataFailed) testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) - testTaskEndReason(TaskKilled) + testTaskEndReason(TaskKilled("test")) testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true, Some("Induced failure"))) testTaskEndReason(UnknownReason) @@ -676,7 +676,8 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.fullStackTrace === r2.fullStackTrace) assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b)) case (TaskResultLost, TaskResultLost) => - case (TaskKilled, TaskKilled) => + case (r1: TaskKilled, r2: TaskKilled) => + assert(r1.reason == r2.reason) case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => assert(jobId1 === jobId2) @@ -829,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { hasHadoopInput: Boolean, hasOutput: Boolean, hasRecords: Boolean = true) = { - val t = TaskMetrics.empty + val t = TaskMetrics.registered // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) t.setExecutorDeserializeCpuTime(a) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala similarity index 96% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index 14adf8c29fc6b..f9e1b791c86ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.utils import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { import PeriodicRDDCheckpointerSuite._ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8ed09749ffd54..3339d5b35d3b2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1010,15 +1010,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD", "spark.my.password", "spark.my.sECreT") - secretKeys.foreach { key => sparkConf.set(key, "secret_password") } + secretKeys.foreach { key => sparkConf.set(key, "sensitive_value") } // Set a non-secret key - sparkConf.set("spark.regular.property", "not_a_secret") + sparkConf.set("spark.regular.property", "regular_value") + // Set a property with a regular key but secret in the value + sparkConf.set("spark.sensitive.property", "has_secret_in_value") // Redact sensitive information val redactedConf = Utils.redact(sparkConf, sparkConf.getAll).toMap // Assert that secret information got redacted while the regular property remained the same secretKeys.foreach { key => assert(redactedConf(key) === Utils.REDACTION_REPLACEMENT_TEXT) } - assert(redactedConf("spark.regular.property") === "not_a_secret") + assert(redactedConf("spark.regular.property") === "regular_value") + assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala new file mode 100644 index 0000000000000..c2a3ee95f1c55 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection + +import java.util.NoSuchElementException + +import org.apache.spark.SparkFunSuite + +class MedianHeapSuite extends SparkFunSuite { + + test("If no numbers in MedianHeap, NoSuchElementException is thrown.") { + val medianHeap = new MedianHeap() + intercept[NoSuchElementException] { + medianHeap.median + } + } + + test("Median should be correct when size of MedianHeap is even") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 10) + assert(medianHeap.median === 4.5) + } + + test("Median should be correct when size of MedianHeap is odd") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 9) + assert(medianHeap.median === 4) + } + + test("Median should be correct though there are duplicated numbers inside.") { + val array = Array(0, 0, 1, 1, 2, 3, 4) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size === 7) + assert(medianHeap.median === 1) + } + + test("Median should be correct when input data is skewed.") { + val medianHeap = new MedianHeap() + (0 until 10).foreach(_ => medianHeap.insert(5)) + assert(medianHeap.median === 5) + (0 until 100).foreach(_ => medianHeap.insert(10)) + assert(medianHeap.median === 10) + (0 until 1000).foreach(_ => medianHeap.insert(0)) + assert(medianHeap.median === 0) + } +} diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 31656ca0e5a60..bb7d31cad7be3 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -44,4 +44,8 @@ files="src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java"/> + + diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index e1db997a7d410..7976d8a039544 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -246,7 +246,7 @@ if [[ "$1" == "package" ]]; then dest_dir="$REMOTE_PARENT_DIR/${DEST_DIR_NAME}-bin" echo "Copying release tarballs to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -254,7 +254,7 @@ if [[ "$1" == "package" ]]; then LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' LFTP mput -O $dest_dir 'SparkR_*' @@ -271,13 +271,13 @@ if [[ "$1" == "docs" ]]; then PRODUCTION=1 RELEASE_VERSION="$SPARK_VERSION" jekyll build echo "Copying release documentation to $dest_dir" # Put to new directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir # Delete /latest directory and rename new upload to /latest LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" # Re-upload a second time and leave the files in the timestamped upload directory: - LFTP mkdir -p $dest_dir + LFTP mkdir -p $dest_dir || true LFTP mirror -R _site $dest_dir cd .. 
exit 0 diff --git a/dev/deps/spark-deps-hadoop-palantir b/dev/deps/spark-deps-hadoop-palantir index 3f2985dd2167c..33ae78ff3bee5 100644 --- a/dev/deps/spark-deps-hadoop-palantir +++ b/dev/deps/spark-deps-hadoop-palantir @@ -25,8 +25,8 @@ base64-2.3.8.jar bcpkix-jdk15on-1.54.jar bcprov-jdk15on-1.54.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -157,6 +157,8 @@ libthrift-0.9.3.jar log4j-1.2.17.jar logging-interceptor-3.6.0.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar @@ -193,14 +195,14 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index c98825462d586..c8f5257d21dda 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -144,7 +144,7 @@ VERSION_SET=$("$MVN" versions:set -DnewVersion=$VERSION | tail -n 1) if [ "$MAKE_TGZ" == "true" ]; then echo "Making spark-dist-$VERSION-$NAME.tgz" else - echo "Making distribution for Spark $VERSION in $DISTDIR..." + echo "Making distribution for Spark $VERSION in '$DISTDIR'..." fi # Build uber fat JAR @@ -187,7 +187,7 @@ cp "$DOCKERFILES_SRC/executor/Dockerfile" "$DISTDIR/dockerfiles/executor/Dockerf # Only create the yarn directory if the yarn artifacts were build. if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then - mkdir "$DISTDIR"/yarn + mkdir "$DISTDIR/yarn" cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" fi @@ -196,7 +196,7 @@ mkdir -p "$DISTDIR/examples/jars" cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" # Deduplicate jars that have already been packaged as part of the main Spark dependencies. 
-for f in "$DISTDIR/examples/jars/"*; do +for f in "$DISTDIR"/examples/jars/*; do name=$(basename "$f") if [ -f "$DISTDIR/jars/$name" ]; then rm "$DISTDIR/examples/jars/$name" @@ -205,14 +205,14 @@ done # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" -cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" +cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files cp "$SPARK_HOME/LICENSE" "$DISTDIR" cp -r "$SPARK_HOME/licenses" "$DISTDIR" cp "$SPARK_HOME/NOTICE" "$DISTDIR" -if [ -e "$SPARK_HOME"/CHANGES.txt ]; then +if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" fi @@ -234,43 +234,43 @@ fi # Make R package - this is used for both CRAN release and packing R layout into distribution if [ "$MAKE_R" == "true" ]; then echo "Building R source package" - R_PACKAGE_VERSION=`grep Version $SPARK_HOME/R/pkg/DESCRIPTION | awk '{print $NF}'` + R_PACKAGE_VERSION=`grep Version "$SPARK_HOME/R/pkg/DESCRIPTION" | awk '{print $NF}'` pushd "$SPARK_HOME/R" > /dev/null # Build source package and run full checks # Do not source the check-cran.sh - it should be run from where it is for it to set SPARK_HOME - NO_TESTS=1 "$SPARK_HOME/"R/check-cran.sh + NO_TESTS=1 "$SPARK_HOME/R/check-cran.sh" # Move R source package to match the Spark release version if the versions are not the same. # NOTE(shivaram): `mv` throws an error on Linux if source and destination are same file if [ "$R_PACKAGE_VERSION" != "$VERSION" ]; then - mv $SPARK_HOME/R/SparkR_"$R_PACKAGE_VERSION".tar.gz $SPARK_HOME/R/SparkR_"$VERSION".tar.gz + mv "$SPARK_HOME/R/SparkR_$R_PACKAGE_VERSION.tar.gz" "$SPARK_HOME/R/SparkR_$VERSION.tar.gz" fi # Install source package to get it to generate vignettes rds files, etc. - VERSION=$VERSION "$SPARK_HOME/"R/install-source-package.sh + VERSION=$VERSION "$SPARK_HOME/R/install-source-package.sh" popd > /dev/null else echo "Skipping building R source package" fi # Copy other things -mkdir "$DISTDIR"/conf -cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf +mkdir "$DISTDIR/conf" +cp "$SPARK_HOME"/conf/*.template "$DISTDIR/conf" cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" # Remove the python distribution from dist/ if we built it if [ "$MAKE_PIP" == "true" ]; then - rm -f $DISTDIR/python/dist/pyspark-*.tar.gz + rm -f "$DISTDIR"/python/dist/pyspark-*.tar.gz fi cp -r "$SPARK_HOME/sbin" "$DISTDIR" # Copy SparkR if it exists -if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then - mkdir -p "$DISTDIR"/R/lib - cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib - cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib +if [ -d "$SPARK_HOME/R/lib/SparkR" ]; then + mkdir -p "$DISTDIR/R/lib" + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR/R/lib" + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR/R/lib" fi if [ "$MAKE_TGZ" == "true" ]; then diff --git a/dev/run-pip-tests b/dev/run-pip-tests index af1b1feb70cd1..d51dde12a03c5 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -35,9 +35,28 @@ function delete_virtualenv() { } trap delete_virtualenv EXIT +PYTHON_EXECS=() # Some systems don't have pip or virtualenv - in those cases our tests won't work. -if ! hash virtualenv 2>/dev/null; then - echo "Missing virtualenv skipping pip installability tests." +if hash virtualenv 2>/dev/null && [ ! -n "$USE_CONDA" ]; then + echo "virtualenv installed - using. 
Note if this is a conda virtual env you may wish to set USE_CONDA" + # Figure out which Python execs we should test pip installation with + if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') + elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') + fi + if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') + fi +elif hash conda 2>/dev/null; then + echo "Using conda virtual enviroments" + PYTHON_EXECS=('3.5') + USE_CONDA=1 +else + echo "Missing virtualenv & conda, skipping pip installability tests" exit 0 fi if ! hash pip 2>/dev/null; then @@ -45,22 +64,8 @@ if ! hash pip 2>/dev/null; then exit 0 fi -# Figure out which Python execs we should test pip installation with -PYTHON_EXECS=() -if hash python2 2>/dev/null; then - # We do this since we are testing with virtualenv and the default virtual env python - # is in /usr/bin/python - PYTHON_EXECS+=('python2') -elif hash python 2>/dev/null; then - # If python2 isn't installed fallback to python if available - PYTHON_EXECS+=('python') -fi -if hash python3 2>/dev/null; then - PYTHON_EXECS+=('python3') -fi - # Determine which version of PySpark we are building for archive name -PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_VERSION=$(python3 -c "exec(open('python/pyspark/version.py').read());print(__version__)") PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" # The pip install options we use for all the pip commands PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " @@ -75,18 +80,24 @@ for python in "${PYTHON_EXECS[@]}"; do echo "Using $VIRTUALENV_BASE for virtualenv" VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python rm -rf "$VIRTUALENV_PATH" - mkdir -p "$VIRTUALENV_PATH" - virtualenv --python=$python "$VIRTUALENV_PATH" - source "$VIRTUALENV_PATH"/bin/activate - # Upgrade pip & friends - pip install --upgrade pip pypandoc wheel - pip install numpy # Needed so we can verify mllib imports + if [ -n "$USE_CONDA" ]; then + conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools + source activate "$VIRTUALENV_PATH" + else + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + fi + # Upgrade pip & friends if using virutal env + if [ ! -n "USE_CONDA" ]; then + pip install --upgrade pip pypandoc wheel numpy + fi echo "Creating pip installable source dist" cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - $python setup.py sdist + python setup.py sdist echo "Installing dist into virtual env" @@ -112,6 +123,13 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR" + # conda / virtualenv enviroments need to be deactivated differently + if [ -n "$USE_CONDA" ]; then + source deactivate + else + deactivate + fi + done done diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index e79accf9e987a..f41f1ac79e381 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -22,7 +22,8 @@ # Environment variables are populated by the code here: #+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 -FWDIR="$(cd "`dirname $0`"/..; pwd)" +FWDIR="$( cd "$( dirname "$0" )/.." 
&& pwd )" cd "$FWDIR" +export PATH=/home/anaconda/bin:$PATH exec python -u ./dev/run-tests-jenkins.py "$@" diff --git a/dev/run-tests.py b/dev/run-tests.py index b18e1d511260e..c292d69e6cf5c 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -346,6 +346,19 @@ def build_spark_sbt(hadoop_version): exec_sbt(profiles_and_goals) +def build_spark_unidoc_sbt(hadoop_version): + set_title_and_block("Building Unidoc API Documentation", "BLOCK_DOCUMENTATION") + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags + sbt_goals = ["unidoc"] + profiles_and_goals = build_profiles + sbt_goals + + print("[info] Building Spark unidoc (w/Hive 1.2.1) using SBT with these arguments: ", + " ".join(profiles_and_goals)) + + exec_sbt(profiles_and_goals) + + def build_spark_assembly_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags @@ -355,6 +368,16 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. + # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the + # documentation build fails on a specific machine & environment in Jenkins but it was unable + # to reproduce. Please see SPARK-20343. This is a band-aid fix that should be removed in + # the future. + is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6" + if not is_hadoop_version_2_6: + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) + def build_apache_spark(build_tool, hadoop_version): """Will build Spark against Hive v1.2.1 given the passed in build tool (either `sbt` or diff --git a/dev/scalastyle b/dev/scalastyle index de7423913fad9..204277c705d65 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -21,6 +21,7 @@ # with failure (either resolution or compilation); the "q" makes SBT quit. 
ERRORS=$(echo -e "q\n" \ | build/sbt \ + -Pcloud \ -Pkinesis-asl \ -Pmesos \ -Pyarn \ diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 10ad1fe3aa2c6..78b5b8b0f4b59 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -340,6 +340,7 @@ def __hash__(self): "pyspark.profiler", "pyspark.shuffle", "pyspark.tests", + "pyspark.util", ] ) @@ -423,15 +424,17 @@ def __hash__(self): "python/pyspark/ml/" ], python_test_goals=[ - "pyspark.ml.feature", "pyspark.ml.classification", "pyspark.ml.clustering", + "pyspark.ml.evaluation", + "pyspark.ml.feature", + "pyspark.ml.fpm", "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", + "pyspark.ml.stat", "pyspark.ml.tuning", "pyspark.ml.tests", - "pyspark.ml.evaluation", ], blacklisted_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there diff --git a/dists/hadoop-palantir/pom.xml b/dists/hadoop-palantir/pom.xml index 484f3e5f1b283..724d0b0348ff3 100644 --- a/dists/hadoop-palantir/pom.xml +++ b/dists/hadoop-palantir/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/dists/without-hadoop/pom.xml b/dists/without-hadoop/pom.xml index 0c183234dfb78..d9bd338f13aae 100644 --- a/dists/without-hadoop/pom.xml +++ b/dists/without-hadoop/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/docs/_config.yml b/docs/_config.yml index 83bb30598d153..21255ef7a5c45 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.2.0 +SPARK_VERSION: 2.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.3.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 0c6b9b20a6e4b..047423f75aec1 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -8,6 +8,8 @@ url: ml-clustering.html - text: Collaborative filtering url: ml-collaborative-filtering.html +- text: Frequent Pattern Mining + url: ml-frequent-pattern-mining.html - text: Model selection and tuning url: ml-tuning.html - text: Advanced topics diff --git a/docs/building-spark.md b/docs/building-spark.md index 8353b7a520b8e..0f551bc66b8c9 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -154,7 +154,7 @@ Developers who compile Spark frequently may want to speed up compilation; e.g., developers who build with SBT). For more information about how to do this, refer to the [Useful Developer Tools page](http://spark.apache.org/developer-tools.html#reducing-build-times). -## Encrypted Filesystems +## Encrypted Filesystems When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: @@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. 
./build/mvn install -DskipTests - ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 or diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 814e4406cf435..a2ad958959a50 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,7 +52,11 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. - +* [Kubernetes (experimental)](https://github.com/apache-spark-on-k8s/spark) -- In addition to the above, +there is experimental support for Kubernetes. Kubernetes is an open-source platform +for providing container-centric infrastructure. Kubernetes support is being actively +developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. +For documentation, refer to that project's README. # Submitting Applications diff --git a/docs/configuration.md b/docs/configuration.md index 63392a741a1f0..1d8d963016c71 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,14 @@ of the most common options to set are: and typically can have up to 50 characters. + + spark.driver.supervise + false + + If true, restarts the driver automatically if it fails with a non-zero exit status. + Only has effect in Spark standalone mode or Mesos cluster deploy mode. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -364,8 +372,8 @@ Apart from these, the following properties are also available, and may be useful (?i)secret|password Regex to decide which Spark configuration properties and environment variables in driver and - executor environments contain sensitive information. When this regex matches a property, its - value is redacted from the environment UI and various logs like YARN and event logs. + executor environments contain sensitive information. When this regex matches a property key or + value, the value is redacted from the environment UI and various logs like YARN and event logs. @@ -639,6 +647,7 @@ Apart from these, the following properties are also available, and may be useful false Whether to compress logged events, if spark.eventLog.enabled is true. + Compression will use spark.io.compression.codec. @@ -773,14 +782,15 @@ Apart from these, the following properties are also available, and may be useful true Whether to compress broadcast variables before sending them. Generally a good idea. + Compression will use spark.io.compression.codec. spark.io.compression.codec lz4 - The codec used to compress internal data such as RDD partitions, broadcast variables and - shuffle outputs. By default, Spark provides three codecs: lz4, lzf, + The codec used to compress internal data such as RDD partitions, event log, broadcast variables + and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, and snappy. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, @@ -881,6 +891,7 @@ Apart from these, the following properties are also available, and may be useful StorageLevel.MEMORY_ONLY_SER in Java and Scala or StorageLevel.MEMORY_ONLY in Python). Can save substantial space at the cost of some extra CPU time. + Compression will use spark.io.compression.codec. 
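To make the compression settings described above concrete, here is a hedged sketch of setting them on a `SparkConf`; the keys come from the documentation text above, while the values are purely illustrative (they are not all the defaults):

```scala
import org.apache.spark.SparkConf

// All three compression settings below share the codec chosen by
// spark.io.compression.codec, as the updated descriptions note.
val conf = new SparkConf()
  .set("spark.io.compression.codec", "lz4")   // codec for RDD partitions, event log, broadcast, shuffle
  .set("spark.eventLog.compress", "true")     // compress logged events
  .set("spark.broadcast.compress", "true")    // compress broadcast variables before sending
  .set("spark.rdd.compress", "true")          // compress serialized RDD partitions
```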
@@ -1137,6 +1148,15 @@ Apart from these, the following properties are also available, and may be useful mapping has high overhead for blocks close to or below the page size of the operating system. + + spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version + 1 + + The file output committer algorithm version, valid algorithm version number: 1 or 2. + Version 2 may have better performance, but version 1 may handle failures better in certain situations, + as per MAPREDUCE-4815. + + ### Networking @@ -1506,6 +1526,11 @@ Apart from these, the following properties are also available, and may be useful of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering an executor unusable. + spark.stage.maxConsecutiveAttempts + 4 + + Number of consecutive stage attempts allowed before a stage is aborted. + @@ -2124,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) +### GraphX + + + + + + + + +
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.graphx.pregel.checkpointInterval</code></td>
+  <td>-1</td>
+  <td>
+    Checkpoint interval for graph and message in Pregel. It is used to avoid a StackOverflowError caused by long lineage chains
+    after many iterations. Checkpointing is disabled by default.
+  </td>
+</tr>
+ ### Deploy @@ -2231,8 +2270,8 @@ should be included on Spark's classpath: * `hdfs-site.xml`, which provides default behaviors for the HDFS client. * `core-site.xml`, which sets the default filesystem name. -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +The location of these configuration files varies across Hadoop versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools create configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh`
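A hedged sketch of how the new GraphX property could be enabled from application code; the app name, local master, and checkpoint directory are placeholders:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Checkpoint Pregel's graph and message RDDs every 10 iterations to cut long
// lineage chains; the property defaults to -1 (disabled).
val conf = new SparkConf()
  .setMaster("local[*]")
  .setAppName("pregel-checkpoint-sketch")
  .set("spark.graphx.pregel.checkpointInterval", "10")
val sc = new SparkContext(conf)

// A checkpoint directory must also be set, as the GraphX guide below notes.
sc.setCheckpointDir("/tmp/spark-checkpoints")  // placeholder path
```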
# Running the Examples and Shell diff --git a/docs/ml-features.md b/docs/ml-features.md index 57605bafbf4c3..e19fba249fb2d 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -503,6 +503,7 @@ for more details on the API. `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -542,12 +543,13 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. -Additionally, there are two strategies regarding how `StringIndexer` will handle +Additionally, there are three strategies regarding how `StringIndexer` will handle unseen labels when you have fit a `StringIndexer` on one dataset and then use it to transform another: - throw an exception (which is the default) - skip the row containing the unseen label entirely +- put unseen labels in a special additional bucket, at index numLabels **Examples** @@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined 1 | b 2 | c 3 | d + 4 | e ~~~~ If you've not set how `StringIndexer` handles unseen labels or set it to @@ -576,7 +579,22 @@ will be generated: 2 | c | 1.0 ~~~~ -Notice that the row containing "d" does not appear. +Notice that the rows containing "d" or "e" do not appear. + +If you call `setHandleInvalid("keep")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | d | 3.0 + 4 | e | 3.0 +~~~~ + +Notice that the rows containing "d" or "e" are mapped to index "3.0"
    @@ -1266,6 +1284,72 @@ for more details on the API.
    + +## Imputer + +The `Imputer` transformer completes missing values in a dataset, either using the mean or the +median of the columns in which the missing values are located. The input columns should be of +`DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly +creates incorrect values for columns containing categorical features. + +**Note** all `null` values in the input columns are treated as missing, and so are also imputed. + +**Examples** + +Suppose that we have a DataFrame with the columns `a` and `b`: + +~~~ + a | b +------------|----------- + 1.0 | Double.NaN + 2.0 | Double.NaN + Double.NaN | 3.0 + 4.0 | 4.0 + 5.0 | 5.0 +~~~ + +In this example, Imputer will replace all occurrences of `Double.NaN` (the default for the missing value) +with the mean (the default imputation strategy) computed from the other values in the corresponding columns. +In this example, the surrogate values for columns `a` and `b` are 3.0 and 4.0 respectively. After +transformation, the missing values in the output columns will be replaced by the surrogate value for +the relevant column. + +~~~ + a | b | out_a | out_b +------------|------------|-------|------- + 1.0 | Double.NaN | 1.0 | 4.0 + 2.0 | Double.NaN | 2.0 | 4.0 + Double.NaN | 3.0 | 3.0 | 3.0 + 4.0 | 4.0 | 4.0 | 4.0 + 5.0 | 5.0 | 5.0 | 5.0 +~~~ + +
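A hedged Scala sketch of the `Imputer` behavior described above; the DataFrame mirrors the example table, and the session setup is illustrative:

```scala
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("imputer-sketch").getOrCreate()
import spark.implicits._

val df = Seq(
  (1.0, Double.NaN),
  (2.0, Double.NaN),
  (Double.NaN, 3.0),
  (4.0, 4.0),
  (5.0, 5.0)
).toDF("a", "b")

val imputer = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("out_a", "out_b"))
  .setStrategy("mean")  // the default; "median" is the other supported strategy

// out_a and out_b replace NaN with the surrogates 3.0 and 4.0, as in the table above.
imputer.fit(df).transform(df).show()
```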
    +
    + +Refer to the [Imputer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Imputer) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/ImputerExample.scala %} +
    + +
    + +Refer to the [Imputer Java docs](api/java/org/apache/spark/ml/feature/Imputer.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaImputerExample.java %} +
    + +
    + +Refer to the [Imputer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Imputer) +for more details on the API. + +{% include_example python/ml/imputer_example.py %} +
    +
    + # Feature Selectors ## VectorSlicer diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md new file mode 100644 index 0000000000000..81634de8aade7 --- /dev/null +++ b/docs/ml-frequent-pattern-mining.md @@ -0,0 +1,87 @@ +--- +layout: global +title: Frequent Pattern Mining +displayTitle: Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## FP-Growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffixes of transactions, +and hence is more scalable than a single-machine implementation. +We refer users to the papers for more details. + +`spark.ml`'s FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `minConfidence`: minimum confidence for generating Association Rule. Confidence is an indication of how often an + association rule has been found to be true. For example, if in the transactions itemset `X` appears 4 times, `X` + and `Y` co-occur only 2 times, the confidence for the rule `X => Y` is then 2/4 = 0.5. The parameter will not + affect the mining for frequent itemsets, but specify the minimum confidence for generating association rules + from frequent itemsets. +* `numPartitions`: the number of partitions used to distribute the work. By default the param is not set, and + number of partitions of the input dataset is used. + +The `FPGrowthModel` provides: + +* `freqItemsets`: frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) +* `associationRules`: association rules generated with confidence above `minConfidence`, in the format of + DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]). +* `transform`: For each transaction in `itemsCol`, the `transform` method will compare its items against the antecedents + of each association rule. If the record contains all the antecedents of a specific association rule, the rule + will be considered as applicable and its consequents will be added to the prediction result. 
The transform + method will summarize the consequents of all the applicable rules as the prediction. The prediction column has + the same data type as `itemsCol` and does not contain items that already exist in `itemsCol`. A minimal inline Scala sketch is also included after the language-specific examples below. + + +**Examples** + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.fpm.FPGrowth) for more details. + +{% include_example scala/org/apache/spark/examples/ml/FPGrowthExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/fpm/FPGrowth.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaFPGrowthExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.fpm.FPGrowth) for more details. + +{% include_example python/ml/fpgrowth_example.py %} +
    + +
    + +Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. + +{% include_example r/ml/fpm.R %} +
    + +
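For readers who prefer an inline walkthrough, the following is a minimal Scala sketch of the API described above. It assumes an existing `SparkSession` named `spark`; the toy transactions are illustrative only.

{% highlight scala %}
import org.apache.spark.ml.fpm.FPGrowth

// Assumes an existing SparkSession named `spark`.
import spark.implicits._

// Toy transactions; each row is an array of item identifiers.
val dataset = spark.createDataset(Seq(
  "1 2 5",
  "1 2 3 5",
  "1 2")
).map(t => t.split(" ")).toDF("items")

val fpgrowth = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.5)
  .setMinConfidence(0.6)
val model = fpgrowth.fit(dataset)

// Frequent itemsets as DataFrame("items", "freq").
model.freqItemsets.show()

// Association rules as DataFrame("antecedent", "consequent", "confidence").
model.associationRules.show()

// Adds a "prediction" column holding the consequents of all applicable rules.
model.transform(dataset).show()
{% endhighlight %}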
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0f891a09a6e61..d1bb6d69f1256 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -20,7 +20,7 @@ algorithm to learn these latent factors. The implementation in `spark.mllib` has following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in the model. +* *rank* is the number of features to use (also referred to as the number of latent factors). * *iterations* is the number of iterations of ALS to run. ALS typically converges to a reasonable solution in 20 iterations or less. * *lambda* specifies the regularization parameter in ALS. diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 539cbc1b3163a..a72680d52a26c 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/ The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. + +
    +Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API. -In order to run the above application, follow the instructions -provided in the [Self-Contained -Applications](quick-start.html#self-contained-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +{% include_example python/mllib/svd_example.py %} +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
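As a rough illustration of the Scala side of the same API, the sketch below computes a truncated SVD of a small `RowMatrix`. It assumes an existing `SparkContext` named `sc`, and the input vectors are made up for the example.

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrix, SingularValueDecomposition, Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Assumes an existing SparkContext named `sc`; the rows are illustrative only.
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0),
  Vectors.dense(2.0, 3.0, 4.0),
  Vectors.dense(4.0, 5.0, 6.0)))
val mat = new RowMatrix(rows)

// Compute the top 2 singular values and the corresponding singular vectors.
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(2, computeU = true)
val U: RowMatrix = svd.U  // U is a distributed RowMatrix.
val s: Vector = svd.s     // Singular values, as a local dense vector.
val V: Matrix = svd.V     // V is a local dense matrix.
{% endhighlight %}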
    @@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. -The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. {% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %} - -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +
    + +The following code demonstrates how to compute principal components on a `RowMatrix` +and use them to project the vectors into a low-dimensional space. + +Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API. + +{% include_example python/mllib/pca_rowmatrix_example.py %} + +
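A comparable Scala sketch of the same computation, assuming an existing `SparkContext` named `sc` and made-up input vectors:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrix, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Assumes an existing SparkContext named `sc`; the rows are illustrative only.
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0),
  Vectors.dense(2.0, 3.0, 4.0),
  Vectors.dense(4.0, 5.0, 6.0)))
val mat = new RowMatrix(rows)

// Principal components are returned as a local matrix, one component per column.
val pc: Matrix = mat.computePrincipalComponents(2)

// Project the rows into the 2-dimensional space spanned by the principal components.
val projected: RowMatrix = mat.multiply(pc)
{% endhighlight %}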
    + diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 93e3f0b2d2267..c9cd7cc85e754 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -24,7 +24,7 @@ explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). -PFP distributes the work of growing FP-trees based on the suffices of transactions, +PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. diff --git a/docs/monitoring.md b/docs/monitoring.md index 80519525af0c3..3e577c5f36778 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -27,8 +27,8 @@ in the UI to persisted storage. ## Viewing After the Fact -If Spark is run on Mesos or YARN, it is still possible to construct the UI of an -application through Spark's history server, provided that the application's event logs exist. +It is still possible to construct the UI of an application through Spark's history server, +provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh @@ -257,7 +257,7 @@ In the API, an application is referenced by its application ID, `[app-id]`. When running on YARN, each application may have multiple attempts, but there are attempt IDs only for applications in cluster mode, not applications in client mode. Applications in YARN cluster mode can be identified by their `[attempt-id]`. In the API listed below, when running in YARN cluster mode, -`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID. +`[app-id]` will actually be `[base-app-id]/[attempt-id]`, where `[base-app-id]` is the YARN application ID.
    @@ -289,7 +289,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running @@ -299,12 +299,12 @@ can be identified by their `[attempt-id]`. In the API listed below, when running +
?status=[active|complete|pending|failed] list only stages in the state. diff --git a/docs/quick-start.md b/docs/quick-start.md index aa4319a23325c..b88ae5f6bb313 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -10,12 +10,13 @@ description: Quick start tutorial for Spark SPARK_VERSION_SHORT This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend switching to Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. + # Interactive Analysis with the Spark Shell ## Basics @@ -29,28 +30,28 @@ or Python. Start it by running the following in the Spark directory: ./bin/spark-shell -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Let's make a new Dataset from the text of the README file in the Spark source directory: {% highlight scala %} -scala> val textFile = sc.textFile("README.md") -textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at :25 +scala> val textFile = spark.read.textFile("README.md") +textFile: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from a Dataset directly, by calling some actions, or transform the Dataset to get a new one. For more details, please read the _[API doc](api/scala/index.html#org.apache.spark.sql.Dataset)_. {% highlight scala %} -scala> textFile.count() // Number of items in this RDD +scala> textFile.count() // Number of items in this Dataset res0: Long = 126 // May be different from yours as README.md will change over time, similar to other outputs -scala> textFile.first() // First item in this RDD +scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this Dataset to a new one.
We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) -linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at :27 +linesWithSpark: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} We can chain together transformations and actions: @@ -65,32 +66,32 @@ res3: Long = 15 ./bin/pyspark -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Due to Python's dynamic nature, we don't need the Dataset to be strongly-typed in Python. As a result, all Datasets in Python are Dataset[Row], and we call it `DataFrame` to be consistent with the data frame concept in Pandas and R. Let's make a new DataFrame from the text of the README file in the Spark source directory: {% highlight python %} ->>> textFile = sc.textFile("README.md") +>>> textFile = spark.read.text("README.md") {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from DataFrame directly, by calling some actions, or transform the DataFrame to get a new one. For more details, please read the _[API doc](api/python/index.html#pyspark.sql.DataFrame)_. {% highlight python %} ->>> textFile.count() # Number of items in this RDD +>>> textFile.count() # Number of rows in this DataFrame 126 ->>> textFile.first() # First item in this RDD -u'# Apache Spark' +>>> textFile.first() # First row in this DataFrame +Row(value=u'# Apache Spark') {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this DataFrame to a new one. We call `filter` to return a new DataFrame with a subset of the lines in the file. {% highlight python %} ->>> linesWithSpark = textFile.filter(lambda line: "Spark" in line) +>>> linesWithSpark = textFile.filter(textFile.value.contains("Spark")) {% endhighlight %} We can chain together transformations and actions: {% highlight python %} ->>> textFile.filter(lambda line: "Spark" in line).count() # How many lines contain "Spark"? +>>> textFile.filter(textFile.value.contains("Spark")).count() # How many lines contain "Spark"? 15 {% endhighlight %} @@ -98,8 +99,8 @@ We can chain together transformations and actions: -## More on RDD Operations -RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words: +## More on Dataset Operations +Dataset actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
    @@ -109,7 +110,7 @@ scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a res4: Long = 15 {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: +This first maps a line to an integer value, creating a new Dataset. `reduce` is called on that Dataset to find the largest word count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: {% highlight scala %} scala> import java.lang.Math @@ -122,11 +123,11 @@ res5: Int = 15 One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight scala %} -scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b) -wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at :28 +scala> val wordCounts = textFile.flatMap(line => line.split(" ")).groupByKey(identity).count() +wordCounts: org.apache.spark.sql.Dataset[(String, Long)] = [value: string, count(1): bigint] {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we call `flatMap` to transform a Dataset of lines to a Dataset of words, and then combine `groupByKey` and `count` to compute the per-word counts in the file as a Dataset of (String, Long) pairs. To collect the word counts in our shell, we can call `collect`: {% highlight scala %} scala> wordCounts.collect() @@ -137,37 +138,24 @@ res6: Array[(String, Int)] = Array((means,1), (under,2), (this,3), (Because,1),
{% highlight python %} ->>> textFile.map(lambda line: len(line.split())).reduce(lambda a, b: a if (a > b) else b) -15 +>>> from pyspark.sql.functions import * +>>> textFile.select(size(split(textFile.value, "\s+")).name("numWords")).agg(max(col("numWords"))).collect() +[Row(max(numWords)=15)] {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Python [anonymous functions (lambdas)](https://docs.python.org/2/reference/expressions.html#lambda), -but we can also pass any top-level Python function we want. -For example, we'll define a `max` function to make this code easier to understand: - -{% highlight python %} ->>> def max(a, b): -... if a > b: -... return a -... else: -... return b -... - ->>> textFile.map(lambda line: len(line.split())).reduce(max) -15 -{% endhighlight %} +This first maps a line to an integer value and aliases it as "numWords", creating a new DataFrame. `agg` is called on that DataFrame to find the largest word count. The arguments to `select` and `agg` are both _[Column](api/python/index.html#pyspark.sql.Column)_; we can use `df.colName` to get a column from a DataFrame. We can also import pyspark.sql.functions, which provides a lot of convenient functions to build a new Column from an old one. One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight python %} ->>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b) +>>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we use the `explode` function in `select` to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() -[(u'and', 9), (u'A', 1), (u'webpage', 1), (u'README', 1), (u'Note', 1), (u'"local"', 1), (u'variable', 1), ...] +[Row(word=u'online', count=1), Row(word=u'graphs', count=1), ...] {% endhighlight %}
    @@ -181,7 +169,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This {% highlight scala %} scala> linesWithSpark.cache() -res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at :27 +res7: linesWithSpark.type = [value: string] scala> linesWithSpark.count() res8: Long = 15 @@ -193,7 +181,7 @@ res9: Long = 15 It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/spark-shell` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
    @@ -211,7 +199,7 @@ a cluster, as described in the [programming guide](programming-guide.html#initia It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/pyspark` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -228,20 +216,17 @@ named `SimpleApp.scala`: {% highlight scala %} /* SimpleApp.scala */ -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession object SimpleApp { def main(args: Array[String]) { val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system - val conf = new SparkConf().setAppName("Simple Application") - val sc = new SparkContext(conf) - val logData = sc.textFile(logFile, 2).cache() + val spark = SparkSession.builder.appName("Simple Application").getOrCreate() + val logData = spark.read.textFile(logFile).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() println(s"Lines with a: $numAs, Lines with b: $numBs") - sc.stop() + spark.stop() } } {% endhighlight %} @@ -251,16 +236,13 @@ Subclasses of `scala.App` may not work correctly. This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is -installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, -we initialize a SparkContext as part of the program. +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. -We pass the SparkContext constructor a -[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) -object which contains information about our -application. +We call `SparkSession.builder` to construct a `SparkSession`, then set the application name, and finally call `getOrCreate` to get the `SparkSession` instance. -Our application depends on the Spark API, so we'll also include an sbt configuration file, -`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that +Our application depends on the Spark API, so we'll also include an sbt configuration file, +`build.sbt`, which explains that Spark is a dependency.
This file also adds a repository that Spark depends on: {% highlight scala %} @@ -270,7 +252,7 @@ version := "1.0" scalaVersion := "{{site.SCALA_VERSION}}" -libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" +libraryDependencies += "org.apache.spark" %% "spark-sql" % "{{site.SPARK_VERSION}}" {% endhighlight %} For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt` @@ -309,34 +291,28 @@ We'll create a very simple Spark application, `SimpleApp.java`: {% highlight java %} /* SimpleApp.java */ -import org.apache.spark.api.java.*; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.Dataset; public class SimpleApp { public static void main(String[] args) { String logFile = "YOUR_SPARK_HOME/README.md"; // Should be some file on your system - SparkConf conf = new SparkConf().setAppName("Simple Application"); - JavaSparkContext sc = new JavaSparkContext(conf); - JavaRDD<String> logData = sc.textFile(logFile).cache(); + SparkSession spark = SparkSession.builder().appName("Simple Application").getOrCreate(); + Dataset<String> logData = spark.read().textFile(logFile).cache(); long numAs = logData.filter(s -> s.contains("a")).count(); long numBs = logData.filter(s -> s.contains("b")).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); - - sc.stop(); + + spark.stop(); } } {% endhighlight %} -This program just counts the number of lines containing 'a' and the number containing 'b' in a text -file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala example, we initialize a SparkContext, though we use the special -`JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by -`JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes -that extend `spark.api.java.function.Function`. The -[Spark programming guide](programming-guide.html) describes these differences in more detail. +This program just counts the number of lines containing 'a' and the number containing 'b' in the +Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. To build the program, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version. @@ -352,7 +328,7 @@
org.apache.spark - spark-core_{{site.SCALA_BINARY_VERSION}} + spark-sql_{{site.SCALA_BINARY_VERSION}} {{site.SPARK_VERSION}} @@ -395,27 +371,25 @@ As an example, we'll create a simple Spark application, `SimpleApp.py`: {% highlight python %} """SimpleApp.py""" -from pyspark import SparkContext +from pyspark.sql import SparkSession logFile = "YOUR_SPARK_HOME/README.md" # Should be some file on your system -sc = SparkContext("local", "Simple App") -logData = sc.textFile(logFile).cache() +spark = SparkSession.builder.appName("SimpleApp").getOrCreate() +logData = spark.read.text(logFile).cache() -numAs = logData.filter(lambda s: 'a' in s).count() -numBs = logData.filter(lambda s: 'b' in s).count() +numAs = logData.filter(logData.value.contains('a')).count() +numBs = logData.filter(logData.value.contains('b')).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) -sc.stop() +spark.stop() {% endhighlight %} This program just counts the number of lines containing 'a' and the number containing 'b' in a text file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala and Java examples, we use a SparkContext to create RDDs. -We can pass Python functions to Spark, which are automatically serialized along with any variables -that they reference. +As with the Scala and Java examples, we use a SparkSession to create Datasets. For applications that use custom classes or third-party libraries, we can also add code dependencies to `spark-submit` through its `--py-files` argument by packaging them into a .zip file (see `spark-submit --help` for details). @@ -438,8 +412,7 @@ Lines with a: 46, Lines with b: 23 # Where to Go from Here Congratulations on running your first Spark application! -* For an in-depth overview of the API, start with the [Spark programming guide](programming-guide.html), - or see "Programming Guides" menu for other components. +* For an in-depth overview of the API, start with the [RDD programming guide](rdd-programming-guide.html) and the [SQL programming guide](sql-programming-guide.html), or see "Programming Guides" menu for other components. * For running applications on a cluster, head to the [deployment overview](cluster-overview.html). * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), diff --git a/docs/programming-guide.md b/docs/rdd-programming-guide.md similarity index 99% rename from docs/programming-guide.md rename to docs/rdd-programming-guide.md index 3bbb1d36e5095..52e59df9990e9 100644 --- a/docs/programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -24,7 +24,7 @@ along with if you launch Spark's interactive shell -- either `bin/spark-shell` f
    -Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} +Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} by default. (Spark can be built to work with other versions of Scala, too.) To write applications in Scala, you will need to use a compatible Scala version (e.g. {{site.SCALA_BINARY_VERSION}}.X). @@ -76,10 +76,10 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency Finally, you need to import some Spark classes into your program. Add the following lines: -{% highlight scala %} -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.SparkConf +{% highlight java %} +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.SparkConf; {% endhighlight %}
    @@ -244,13 +244,13 @@ use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running $ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -To use the Jupyter notebook (previously known as the IPython notebook), +To use the Jupyter notebook (previously known as the IPython notebook), {% highlight bash %} $ PYSPARK_DRIVER_PYTHON=jupyter ./bin/pyspark {% endhighlight %} -You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. +You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of @@ -457,7 +457,7 @@ If required, a Hadoop configuration can be passed in as a Python dict. Here is a Elasticsearch ESInputFormat: {% highlight python %} -$ SPARK_CLASSPATH=/path/to/elasticsearch-hadoop.jar ./bin/pyspark +$ ./bin/pyspark --jars /path/to/elasticsearch-hadoop.jar >>> conf = {"es.resource" : "index/type"} # assume Elasticsearch is running on localhost defaults >>> rdd = sc.newAPIHadoopRDD("org.elasticsearch.hadoop.mr.EsInputFormat", "org.apache.hadoop.io.NullWritable", @@ -811,7 +811,7 @@ The variables within the closure sent to each executor are now copies and thus, In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1230,8 +1230,8 @@ storage levels is:
      <tr><th>Endpoint</th><th>Meaning</th></tr>
      <td>/applications/[app-id]/jobs</td>
      <td>A list of all jobs for a given application.
-        ?status=[complete|succeeded|failed] list only jobs in the specific state.
+        ?status=[running|succeeded|failed|unknown] list only jobs in the specific state.
      </td>
      <td>/applications/[app-id]/stages</td>
      <td>A list of all stages for a given application.</td>
      <td>/applications/[app-id]/stages/[stage-id]</td>
      <td>A list of all attempts for the given stage.
-        ?status=[active|complete|pending|failed] list only stages in the state.
      </td>
    -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, -so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1346,7 +1346,7 @@ As a user, you can create named or unnamed accumulators. As seen in the image be Accumulators in the Spark UI

    -Tracking accumulators in the UI can be useful for understanding the progress of +Tracking accumulators in the UI can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python).
    @@ -1355,7 +1355,7 @@ running stages (NOTE: this is not yet supported in Python). A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: @@ -1409,7 +1409,7 @@ Note that, when programmers define their own type of AccumulatorV2, the resultin A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8d5ad12cb85be..314a806edf39e 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -356,6 +356,16 @@ See the [configuration page](configuration.html) for information on Spark config By default Mesos agents will not pull images they already have cached. + + spark.mesos.executor.docker.parameters + (none) + + Set the list of custom parameters which will be passed into the docker run command when launching the Spark executor on Mesos using the docker containerizer. The format of this property is a comma-separated list of + key/value pairs. Example: + +
    key1=val1,key2=val2,key3=val3
    + + spark.mesos.executor.docker.volumes (none) @@ -367,6 +377,15 @@ See the [configuration page](configuration.html) for information on Spark config
    [host_path:]container_path[:ro|:rw]
    + + spark.mesos.task.labels + (none) + + Set the Mesos labels to add to each task. Labels are free-form key-value pairs. + Key-value pairs should be separated by a colon, and commas used to list more than one. + Ex. key:value,key2:value2. + + spark.mesos.executor.home driver side SPARK_HOME diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1c0b60f7b9346..34ced9ed7b462 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,7 +242,7 @@ SPARK_WORKER_OPTS supports the following system properties: spark.worker.cleanup.appDataTtl - 7 * 24 * 3600 (7 days) + 604800 (7 days, 7 * 24 * 3600) The number of seconds to retain application work directories on each worker. This is a Time To Live and should depend on the amount of available disk space you have. Application logs and jars are diff --git a/docs/sparkr.md b/docs/sparkr.md index d7ffd9b3f1229..569b85e72c3cf 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -264,6 +264,36 @@ head(arrange(waiting_counts, desc(waiting_counts$count))) {% endhighlight %}
    +In addition to standard aggregations, SparkR supports [OLAP cube](https://en.wikipedia.org/wiki/OLAP_cube) operators `cube`: + +
    +{% highlight r %} +head(agg(cube(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 NA 140.8 4 22.8 +##2 4 75.7 4 30.4 +##3 8 400.0 3 19.2 +##4 8 318.0 3 15.5 +##5 NA 351.0 NA 15.8 +##6 NA 275.8 NA 16.3 +{% endhighlight %} +
    + +and `rollup`: + +
    +{% highlight r %} +head(agg(rollup(df, "cyl", "disp", "gear"), avg(df$mpg))) +## cyl disp gear avg(mpg) +##1 4 75.7 4 30.4 +##2 8 400.0 3 19.2 +##3 8 318.0 3 15.5 +##4 4 78.7 NA 32.4 +##5 8 304.0 3 15.2 +##6 4 79.0 NA 27.3 +{% endhighlight %} +
    + ### Operating on Columns SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. @@ -394,75 +424,6 @@ head(result[order(result$max_eruption, decreasing = TRUE), ]) {% endhighlight %}
    -#### Data type mapping between R and Spark - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-  <tr><th>R</th><th>Spark</th></tr>
-  <tr><td>byte</td><td>byte</td></tr>
-  <tr><td>integer</td><td>integer</td></tr>
-  <tr><td>float</td><td>float</td></tr>
-  <tr><td>double</td><td>double</td></tr>
-  <tr><td>numeric</td><td>double</td></tr>
-  <tr><td>character</td><td>string</td></tr>
-  <tr><td>string</td><td>string</td></tr>
-  <tr><td>binary</td><td>binary</td></tr>
-  <tr><td>raw</td><td>binary</td></tr>
-  <tr><td>logical</td><td>boolean</td></tr>
-  <tr><td>POSIXct</td><td>timestamp</td></tr>
-  <tr><td>POSIXlt</td><td>timestamp</td></tr>
-  <tr><td>Date</td><td>date</td></tr>
-  <tr><td>array</td><td>array</td></tr>
-  <tr><td>list</td><td>array</td></tr>
-  <tr><td>env</td><td>map</td></tr>
    - #### Run local R functions distributed using `spark.lapply` ##### spark.lapply @@ -521,6 +482,7 @@ SparkR supports the following machine learning algorithms currently: * [`spark.logit`](api/R/spark.logit.html): [`Logistic Regression`](ml-classification-regression.html#logistic-regression) * [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier) * [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes) +* [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector Machine`](ml-classification-regression.html#linear-support-vector-machine) #### Regression @@ -535,6 +497,7 @@ SparkR supports the following machine learning algorithms currently: #### Clustering +* [`spark.bisectingKmeans`](api/R/spark.bisectingKmeans.html): [`Bisecting k-means`](ml-clustering.html#bisecting-k-means) * [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) * [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) * [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) @@ -543,6 +506,10 @@ SparkR supports the following machine learning algorithms currently: * [`spark.als`](api/R/spark.als.html): [`Alternating Least Squares (ALS)`](ml-collaborative-filtering.html#collaborative-filtering) +#### Frequent Pattern Mining + +* [`spark.fpGrowth`](api/R/spark.fpGrowth.html) : [`FP-growth`](ml-frequent-pattern-mining.html#fp-growth) + #### Statistics * [`spark.kstest`](api/R/spark.kstest.html): `Kolmogorov-Smirnov Test` @@ -557,6 +524,79 @@ SparkR supports a subset of the available R formula operators for model fitting, The following example shows how to save/load a MLlib model by SparkR. {% include_example read_write r/ml/ml.R %} +# Data type mapping between R and Spark + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+  <tr><th>R</th><th>Spark</th></tr>
+  <tr><td>byte</td><td>byte</td></tr>
+  <tr><td>integer</td><td>integer</td></tr>
+  <tr><td>float</td><td>float</td></tr>
+  <tr><td>double</td><td>double</td></tr>
+  <tr><td>numeric</td><td>double</td></tr>
+  <tr><td>character</td><td>string</td></tr>
+  <tr><td>string</td><td>string</td></tr>
+  <tr><td>binary</td><td>binary</td></tr>
+  <tr><td>raw</td><td>binary</td></tr>
+  <tr><td>logical</td><td>boolean</td></tr>
+  <tr><td>POSIXct</td><td>timestamp</td></tr>
+  <tr><td>POSIXlt</td><td>timestamp</td></tr>
+  <tr><td>Date</td><td>date</td></tr>
+  <tr><td>array</td><td>array</td></tr>
+  <tr><td>list</td><td>array</td></tr>
+  <tr><td>env</td><td>map</td></tr>
+ + # Structured Streaming + +SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html). + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a @@ -608,3 +648,11 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.1.0 - `join` no longer performs Cartesian Product by default, use `crossJoin` instead. + +## Upgrading to SparkR 2.2.0 + + - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala. + - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either method can be called to create an external or managed table. Additional catalog methods have also been added. + - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. + - `spark.lda` was not setting the optimizer correctly. It has been corrected. + - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b077575155eb0..490c1ce8a7cc5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -571,7 +571,7 @@ be created by calling the `table` method on a `SparkSession` with the name of th For file-based data source, e.g. text, parquet, json, etc. you can specify a custom table path via the `path` option, e.g. `df.write.option("path", "/some/path").saveAsTable("t")`. When the table is dropped, the custom table path will not be removed and the table data is still there. If no custom table path is -specifed, Spark will write data to a default table path under the warehouse directory. When the table is +specified, Spark will write data to a default table path under the warehouse directory. When the table is dropped, the default table path will be removed too. Starting from Spark 2.1, persistent datasource tables have per-partition metadata stored in the Hive metastore. This brings several benefits: @@ -883,7 +883,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset[Row]`. -This conversion can be done using `SparkSession.read.json()` on either an RDD of String, +This conversion can be done using `SparkSession.read.json()` on either a `Dataset[String]`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each @@ -897,7 +897,7 @@ For a regular multi-line JSON file, set the `wholeFile` option to `true`.
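To make the `Dataset[String]` variant concrete, here is a minimal Scala sketch; it assumes an existing `SparkSession` named `spark`, and the JSON record is invented for illustration.

{% highlight scala %}
// Assumes an existing SparkSession named `spark`.
import spark.implicits._

// A Dataset[String] holding one JSON object per element.
val jsonDataset = Seq(
  """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""").toDS()

// The schema (a `name` field and a nested `address` struct) is inferred automatically.
val people = spark.read.json(jsonDataset)
people.show()
{% endhighlight %}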
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset`. -This conversion can be done using `SparkSession.read().json()` on either an RDD of String, +This conversion can be done using `SparkSession.read().json()` on either a `Dataset`, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each @@ -1223,6 +1223,13 @@ the following case-insensitive options: This is a JDBC writer related option. If specified, this option allows setting of database-specific table and partition options when creating a table (e.g., CREATE TABLE t (name string) ENGINE=InnoDB.). This option applies only to writing. + + + createTableColumnTypes + + The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types. This option applies only to writing. + +
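As a sketch of how the new `createTableColumnTypes` option might be used when writing over JDBC: the URL, table name and credentials below are placeholders, and `jdbcDF` stands for any existing DataFrame with matching columns.

{% highlight scala %}
import java.util.Properties

// Placeholder connection settings; replace with real values.
val connectionProperties = new Properties()
connectionProperties.put("user", "username")
connectionProperties.put("password", "password")

// `jdbcDF` is assumed to be an existing DataFrame with `name` and `comments` columns.
jdbcDF.write
  .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)")
  .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
{% endhighlight %}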
    @@ -1693,7 +1700,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index e3837013168dc..92c296a9e6bd3 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -12,6 +12,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +**Do not** manually add dependencies on `org.apache.kafka` artifacts (e.g. `kafka-clients`). The `spark-streaming-kafka-0-10` artifact has the appropriate transitive dependencies already, and different versions may be incompatible in hard to diagnose ways. + ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 522e669568678..217c1a91a16f3 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -3,9 +3,9 @@ layout: global title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) --- -Structured Streaming integration for Kafka 0.10 to poll data from Kafka. +Structured Streaming integration for Kafka 0.10 to read data from and write data to Kafka. -### Linking +## Linking For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: groupId = org.apache.spark @@ -15,40 +15,42 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -### Creating a Kafka Source Stream +## Reading Data from Kafka + +### Creating a Kafka Source for Streaming Queries
    {% highlight scala %} // Subscribe to 1 topic -val ds1 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics -val ds2 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern -val ds3 = spark +val df = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -57,31 +59,31 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -Dataset ds1 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -Dataset ds2 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1,topic2") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -Dataset ds3 = spark +DataFrame df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribePattern", "topic.*") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
    @@ -89,37 +91,37 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight python %} # Subscribe to 1 topic -ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics -ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1,topic2") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern -ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") +df = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribePattern", "topic.*") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
    -### Creating a Kafka Source Batch +### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, you can create an Dataset/DataFrame for a defined range of offsets. @@ -128,17 +130,17 @@ you can create an Dataset/DataFrame for a defined range of offsets. {% highlight scala %} // Subscribe to 1 topic defaults to the earliest and latest offsets -val ds1 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to multiple topics, specifying explicit Kafka offsets -val ds2 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -146,11 +148,11 @@ val ds2 = spark .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] // Subscribe to a pattern, at the earliest and latest offsets -val ds3 = spark +val df = spark .read .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -158,7 +160,7 @@ val ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] {% endhighlight %} @@ -167,16 +169,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -Dataset ds1 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") .option("subscribe", "topic1") .load(); -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -Dataset ds2 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -184,10 +186,10 @@ Dataset ds2 = spark .option("startingOffsets", "{\"topic1\":{\"0\":23,\"1\":-2},\"topic2\":{\"0\":-2}}") .option("endingOffsets", "{\"topic1\":{\"0\":50,\"1\":-1},\"topic2\":{\"0\":-1}}") .load(); -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -Dataset ds3 = spark +DataFrame df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -195,7 +197,7 @@ Dataset ds3 = spark .option("startingOffsets", "earliest") .option("endingOffsets", "latest") .load(); -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% endhighlight %}
    @@ -203,16 +205,16 @@ ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); {% highlight python %} # Subscribe to 1 topic defaults to the earliest and latest offsets -ds1 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ .option("subscribe", "topic1") \ .load() -ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to multiple topics, specifying explicit Kafka offsets -ds2 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -220,10 +222,10 @@ ds2 = spark \ .option("startingOffsets", """{"topic1":{"0":23,"1":-2},"topic2":{"0":-2}}""") \ .option("endingOffsets", """{"topic1":{"0":50,"1":-1},"topic2":{"0":-1}}""") \ .load() -ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") # Subscribe to a pattern, at the earliest and latest offsets -ds3 = spark \ +df = spark \ .read \ .format("kafka") \ .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ @@ -231,8 +233,7 @@ ds3 = spark \ .option("startingOffsets", "earliest") \ .option("endingOffsets", "latest") \ .load() -ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% endhighlight %}
@@ -373,11 +374,213 @@ The following configurations are optional: + +## Writing Data to Kafka + +Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that +Apache Kafka only supports at-least-once write semantics. Consequently, when writing either Streaming Queries +or Batch Queries to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs +to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record. +Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However, +if writing the query is successful, then you can assume that the query output was written at least once. A possible +solution to remove duplicates when reading the written data could be to introduce a primary (unique) key +that can be used to perform de-duplication when reading. A sketch of this pattern is shown after the configuration tables below. + +The DataFrame being written to Kafka should have the following columns in schema: + + + + + + + + + + + + + + 
+  <tr><th>Column</th><th>Type</th></tr>
+  <tr><td>key (optional)</td><td>string or binary</td></tr>
+  <tr><td>value (required)</td><td>string or binary</td></tr>
+  <tr><td>topic (*optional)</td><td>string</td></tr>
    +\* The topic column is required if the "topic" configuration option is not specified.
+ +The value column is the only required column. If a key column is not specified then +a ```null``` valued key column will be automatically added (see Kafka semantics on +how ```null``` valued key values are handled). If a topic column exists then its value +is used as the topic when writing the given row to Kafka, unless the "topic" configuration +option is set, i.e. the "topic" configuration option overrides the topic column. + +The following options must be set for the Kafka sink +for both batch and streaming queries. + + + + + + + +
+  <tr><th>Option</th><th>value</th><th>meaning</th></tr>
+  <tr><td>kafka.bootstrap.servers</td><td>A comma-separated list of host:port</td><td>The Kafka "bootstrap.servers" configuration.</td></tr>
    + +The following configurations are optional: + + + + + + + + + + +
+  <tr><th>Option</th><th>value</th><th>default</th><th>query type</th><th>meaning</th></tr>
+  <tr><td>topic</td><td>string</td><td>none</td><td>streaming and batch</td><td>Sets the topic that all rows will be written to in Kafka. This option overrides any topic column that may exist in the data.</td></tr>
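As referenced above, one possible de-duplication pattern when reading the written data back is sketched below in Scala; it assumes the producer placed a unique identifier in the Kafka key, and the servers and topic name are placeholders.

{% highlight scala %}
// Read the written data back as a batch query and drop duplicate keys.
val deduped = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .load()
  .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value")
  .dropDuplicates("key")
{% endhighlight %}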
    + +### Creating a Kafka Sink for Streaming Queries + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +val ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +val ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +StreamingQuery ds = df + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .start() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +StreamingQuery ds = df + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .start() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +ds = df \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .start() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +ds = df \ + .selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .start() + +{% endhighlight %} +
    +
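The streaming sink examples above assume the DataFrame already has string key and value columns. In practice you often build the value with `to_json` and, for fault tolerance, set a checkpoint location when starting the query. A hedged Scala sketch, where the input columns `id` and `payload` and the checkpoint path are placeholders:

{% highlight scala %}
import org.apache.spark.sql.functions.{col, struct, to_json}

// Hypothetical input columns `id` and `payload`; to_json packs them into the Kafka value.
val query = df
  .select(
    col("id").cast("string").as("key"),
    to_json(struct(col("id"), col("payload"))).as("value"))
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("topic", "topic1")
  .option("checkpointLocation", "path/to/checkpoint/dir")  // placeholder path
  .start()
{% endhighlight %}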
    + +### Writing the output of Batch Queries to Kafka + +
    +
    +{% highlight scala %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight java %} + +// Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .save() + +// Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") + .write() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .save() + +{% endhighlight %} +
    +
    +{% highlight python %} + +# Write key-value data from a DataFrame to a specific Kafka topic specified in an option +df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .save() + +# Write key-value data from a DataFrame to Kafka using a topic specified in the data +df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \ + .write \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .save() + +{% endhighlight %} +
    +
    + + +## Kafka Specific Configurations + Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g, -`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see -[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for +parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs) +for parameters related to writing data. -Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception: - **group.id**: Kafka source will create a unique group id for each query automatically. - **auto.offset.reset**: Set the source option `startingOffsets` to specify @@ -389,11 +592,15 @@ Note that the following Kafka params cannot be set and the Kafka source will thr DataFrame operations to explicitly deserialize the keys. - **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations to explicitly deserialize the values. +- **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame operations to explicitly serialize the keys into either strings or byte arrays. +- **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use +DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. -### Deploying +## Deploying As with any Spark applications, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages`, such as, diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 6af47b6efba2c..53b3db21da769 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1,6 +1,6 @@ --- layout: global -displayTitle: Structured Streaming Programming Guide [Alpha] +displayTitle: Structured Streaming Programming Guide [Experimental] title: Structured Streaming Programming Guide --- @@ -8,13 +8,13 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data.The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. 
In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. # Quick Example -Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
    @@ -63,6 +63,13 @@ spark = SparkSession \ .getOrCreate() {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(appName = "StructuredNetworkWordCount") +{% endhighlight %} +
    @@ -136,6 +143,22 @@ wordCounts = words.groupBy("word").count() This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. + +
    + +{% highlight r %} +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = "localhost", port = 9999) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(group_by(words, "word")) +{% endhighlight %} + +This `lines` SparkDataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have a SQL expression with two SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we name the new column as "word". Finally, we have defined the `wordCounts` SparkDataFrame by grouping by the unique values in the SparkDataFrame and counting them. Note that this is a streaming SparkDataFrame which represents the running word counts of the stream. +
    @@ -181,10 +204,20 @@ query = wordCounts \ query.awaitTermination() {% endhighlight %} + +
    + +{% highlight r %} +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) +{% endhighlight %} +
    -After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active. +After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `awaitTermination()` to prevent the process from exiting while the query is active. To actually execute this example code, you can either compile the code in your own [Spark application](quick-start.html#self-contained-applications), or simply @@ -211,6 +244,11 @@ $ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetwor $ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 +{% endhighlight %} +
    Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following. @@ -325,6 +363,35 @@ Batch: 0 | spark| 1| +------+-----+ +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING structured_network_wordcount.R + +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + ------------------------------------------- Batch: 1 ------------------------------------------- @@ -362,7 +429,7 @@ A query on the input will generate the "Result Table". Every trigger interval (s ![Model](img/structured-streaming-model.png) -The "Output" is defined as what gets written out to the external storage. The output can be defined in different modes +The "Output" is defined as what gets written out to the external storage. The output can be defined in a different mode: - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table. @@ -409,14 +476,14 @@ to track the read position in the stream. The engine uses checkpointing and writ # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` -([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession)/[R](api/R/sparkR.session.html) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets -Streaming DataFrames can be created through the `DataStreamReader` interface +Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) -returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with the `read.stream()` method. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Input Sources In Spark 2.0, there are a few built-in sources. @@ -445,7 +512,8 @@ Here are the details of all the sources in Spark. path: path to the input directory, and common to all file formats.

    For file-format-specific options, see the related methods in DataStreamReader - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataStreamReader.parquet() Yes Supports glob paths, but does not support multiple comma-separated paths/globs. @@ -483,7 +551,7 @@ Here are some examples. {% highlight scala %} val spark: SparkSession = ... -// Read text from socket +// Read text from socket val socketDF = spark .readStream .format("socket") @@ -493,7 +561,7 @@ val socketDF = spark socketDF.isStreaming // Returns True for DataFrames that have streaming sources -socketDF.printSchema +socketDF.printSchema // Read all the csv files written atomically in a directory val userSchema = new StructType().add("name", "string").add("age", "integer") @@ -510,7 +578,7 @@ val csvDF = spark {% highlight java %} SparkSession spark = ... -// Read text from socket +// Read text from socket Dataset socketDF = spark .readStream() .format("socket") @@ -537,9 +605,9 @@ Dataset csvDF = spark {% highlight python %} spark = SparkSession. ... -# Read text from socket +# Read text from socket socketDF = spark \ - .readStream() \ + .readStream \ .format("socket") \ .option("host", "localhost") \ .option("port", 9999) \ @@ -547,17 +615,36 @@ socketDF = spark \ socketDF.isStreaming() # Returns True for DataFrames that have streaming sources -socketDF.printSchema() +socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") csvDF = spark \ - .readStream() \ + .readStream \ .option("sep", ";") \ .schema(userSchema) \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(...) + +# Read text from socket +socketDF <- read.stream("socket", host = hostname, port = port) + +isStreaming(socketDF) # Returns TRUE for SparkDataFrames that have streaming sources + +printSchema(socketDF) + +# Read all the csv files written atomically in a directory +schema <- structType(structField("name", "string"), + structField("age", "integer")) +csvDF <- read.stream("csv", path = "/path/to/directory", schema = schema, sep = ";") +{% endhighlight %} +
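The file source works the same way for other formats; only the format-specific call changes, and an explicit schema is still required. For example, a small Scala variation of the csv example above for newline-delimited JSON files (the directory path is a placeholder):

{% highlight scala %}
import org.apache.spark.sql.types.StructType

// Streaming file sources require an explicit schema.
val userSchema = new StructType().add("name", "string").add("age", "integer")

val jsonDF = spark
  .readStream
  .schema(userSchema)
  .json("/path/to/directory")   // equivalent to format("json").load("/path/to/directory")
{% endhighlight %}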
    @@ -638,12 +725,24 @@ ds.groupByKey((MapFunction) value -> value.getDeviceType(), df = ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } # Select the devices which have signal more than 10 -df.select("device").where("signal > 10") +df.select("device").where("signal > 10") # Running count of the number of updates for each device type df.groupBy("deviceType").count() {% endhighlight %} +
    + +{% highlight r %} +df <- ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } + +# Select the devices which have signal more than 10 +select(where(df, "signal > 10"), "device") + +# Running count of the number of updates for each device type +count(groupBy(df, "deviceType")) +{% endhighlight %} +
    ### Window Operations on Event Time @@ -717,11 +816,11 @@ However, to run this query for days, it's necessary for the system to bound the intermediate in-memory state it accumulates. This means the system needs to know when an old aggregate can be dropped from the in-memory state because the application is not going to receive late data for that aggregate any more. To enable this, in Spark 2.1, we have introduced -**watermarking**, which let's the engine automatically track the current event time in the data and +**watermarking**, which lets the engine automatically track the current event time in the data and attempt to clean up old state accordingly. You can define the watermark of a query by -specifying the event time column and the threshold on how late the data is expected be in terms of +specifying the event time column and the threshold on how late the data is expected to be in terms of event time. For a specific window starting at time `T`, the engine will maintain state and allow late -data to be update the state until `(max event time seen by the engine - late threshold > T)`. +data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will be dropped. Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. @@ -764,11 +863,11 @@ Dataset windowedCounts = words words = ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } # Group the data by window and word and compute the count of each group -windowedCounts = words - .withWatermark("timestamp", "10 minutes") +windowedCounts = words \ + .withWatermark("timestamp", "10 minutes") \ .groupBy( window(words.timestamp, "10 minutes", "5 minutes"), - words.word) + words.word) \ .count() {% endhighlight %} @@ -778,7 +877,7 @@ windowedCounts = words In this example, we are defining the watermark of the query on the value of the column "timestamp", and also defining "10 minutes" as the threshold of how late is the data allowed to be. If this query is run in Update output mode (discussed later in [Output Modes](#output-modes) section), -the engine will keep updating counts of a window in the Resule Table until the window is older +the engine will keep updating counts of a window in the Result Table until the window is older than the watermark, which lags behind the current event time in column "timestamp" by 10 minutes. Here is an illustration. @@ -792,7 +891,7 @@ This watermark lets the engine maintain intermediate state for additional 10 min data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in windows `12:05 - 12:15` and `12:10 - 12:20`. Since, it is still ahead of the watermark `12:04` in the trigger, the engine still maintains the intermediate counts as state and correctly updates the -counts of the related windows. However, when the watermark is updated to 12:11, the intermediate +counts of the related windows. However, when the watermark is updated to `12:11`, the intermediate state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`) is considered "too late" and therefore ignored. Note that after every trigger, the updated counts (i.e. purple rows) are written to sink as the trigger output, as dictated by @@ -825,7 +924,7 @@ section for detailed explanation of the semantics of each output mode. 
same column as the timestamp column used in the aggregate. For example,
`df.withWatermark("time", "1 min").groupBy("time2").count()` is invalid
in Append output mode, as watermark is defined on a different column
-as the aggregation column.
+from the aggregation column.

- `withWatermark` must be called before the aggregation for the watermark details to be used. For example,
`df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append
@@ -840,7 +939,7 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin

{% highlight scala %}
val staticDf = spark.read. ...
-val streamingDf = spark.readStream. ...
+val streamingDf = spark.readStream. ...

streamingDf.join(staticDf, "type")          // inner equi-join with a static DF
streamingDf.join(staticDf, "type", "right_join")  // right outer join with a static DF
@@ -871,6 +970,65 @@ streamingDf.join(staticDf, "type", "right_join")  # right outer join with a stat

+### Streaming Deduplication
+You can deduplicate records in data streams using a unique identifier in the events. This is exactly the same as de-duplication on static data using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking.
+
+- *With watermark* - If there is an upper bound on how late a duplicate record may arrive, then you can define a watermark on an event time column and deduplicate using both the guid and the event time columns. The query will use the watermark to remove old state data from past records that are not expected to get any duplicates any more. This bounds the amount of state the query has to maintain.
+
+- *Without watermark* - Since there are no bounds on when a duplicate record may arrive, the query stores the data from all the past records as state.
+
    +
    + +{% highlight scala %} +val streamingDf = spark.readStream. ... // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid") + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime") +{% endhighlight %} + +
    +
    + +{% highlight java %} +Dataset streamingDf = spark.readStream. ...; // columns: guid, eventTime, ... + +// Without watermark using guid column +streamingDf.dropDuplicates("guid"); + +// With watermark using guid and eventTime columns +streamingDf + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("guid", "eventTime"); +{% endhighlight %} + + +
    +
+
+{% highlight python %}
+streamingDf = spark.readStream. ...
+
+# Without watermark using guid column
+streamingDf.dropDuplicates("guid")
+
+# With watermark using guid and eventTime columns
+streamingDf \
+  .withWatermark("eventTime", "10 seconds") \
+  .dropDuplicates("guid", "eventTime")
+{% endhighlight %}
+
+
    +
    + +### Arbitrary Stateful Operations +Many uscases require more advanced stateful operations than aggregations. For example, in many usecases, you have to track sessions from data streams of events. For doing such sessionization, you will have to save arbitrary types of data as state, and perform arbitrary operations on the state using the data stream events in every trigger. Since Spark 2.2, this can be done using the operation `mapGroupsWithState` and the more powerful operation `flatMapGroupsWithState`. Both operations allow you to apply user-defined code on grouped Datasets to update user-defined state. For more concrete details, take a look at the API documentation ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.GroupState)/[Java](api/java/org/apache/spark/sql/streaming/GroupState.html)) and the examples ([Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java)). + ### Unsupported Operations There are a few DataFrame/Dataset operations that are not supported with streaming DataFrames/Datasets. Some of them are as follows. @@ -891,7 +1049,7 @@ Some of them are as follows. + Right outer join with a streaming Dataset on the left is not supported -- Any kind of joins between two streaming Datasets are not yet supported. +- Any kind of joins between two streaming Datasets is not yet supported. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -909,11 +1067,11 @@ track of all the data received in the stream. This is therefore fundamentally ha efficiently. ## Starting Streaming Queries -Once you have defined the final result DataFrame/Dataset, all that is left is for you start the streaming computation. To do that, you have to use the `DataStreamWriter` +Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the `DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. -- *Details of the output sink:* Data format, location, etc. +- *Details of the output sink:* Data format, location, etc. - *Output mode:* Specify what gets written to the output sink. @@ -951,13 +1109,6 @@ Here is the compatibility matrix. Supported Output Modes Notes - - Queries without aggregation - Append, Update - - Complete mode not supported as it is infeasible to keep all data in the Result Table. - - Queries with aggregation Aggregation on event-time with watermark @@ -971,7 +1122,7 @@ Here is the compatibility matrix.

    Update mode uses watermark to drop old aggregation state.

    - Complete mode does drop not old aggregation state since by definition this mode + Complete mode does not drop old aggregation state since by definition this mode preserves all data in the Result Table. @@ -986,6 +1137,33 @@ Here is the compatibility matrix. this mode. + + Queries with mapGroupsWithState + Update + + + + Queries with flatMapGroupsWithState + Append operation mode + Append + + Aggregations are allowed after flatMapGroupsWithState. + + + + Update operation mode + Update + + Aggregations not allowed after flatMapGroupsWithState. + + + + Other queries + Append, Update + + Complete mode not supported as it is infeasible to keep all unaggregated data in the Result Table. + + @@ -994,10 +1172,11 @@ Here is the compatibility matrix. + #### Output Sinks There are a few types of built-in output sinks. -- **File sink** - Stores the output to a directory. +- **File sink** - Stores the output to a directory. {% highlight scala %} writeStream @@ -1052,12 +1231,21 @@ Here are the details of all the sinks in Spark. Append path: path to the output directory, must be specified. +
    maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
    - latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files(default: false) -

+ latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false) +
    + fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same: +
    + · "file:///dataset.txt"
    + · "s3://a/dataset.txt"
    + · "s3n://a/b/dataset.txt"
    + · "s3a://a/b/c/dataset.txt"
    +
    For file-format-specific options, see the related methods in DataFrameWriter - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataFrameWriter.parquet() Yes @@ -1120,7 +1308,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start() - + // ========== DF with aggregation ========== val aggDF = df.groupBy("device").count() @@ -1131,7 +1319,7 @@ aggDF .format("console") .start() -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream .queryName("aggregates") // this query name will be the table name @@ -1162,7 +1350,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start(); - + // ========== DF with aggregation ========== Dataset aggDF = df.groupBy("device").count(); @@ -1173,7 +1361,7 @@ aggDF .format("console") .start(); -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream() .queryName("aggregates") // this query name will be the table name @@ -1193,31 +1381,31 @@ noAggDF = deviceDataDf.select("device").where("signal > 10") # Print new data to console noAggDF \ - .writeStream() \ + .writeStream \ .format("console") \ .start() # Write new data to Parquet files noAggDF \ - .writeStream() \ + .writeStream \ .format("parquet") \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ .start() - + # ========== DF with aggregation ========== aggDF = df.groupBy("device").count() # Print updated aggregations to console aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .format("console") \ .start() # Have all the aggregates in an in memory table. The query name will be the table name aggDF \ - .writeStream() \ + .writeStream \ .queryName("aggregates") \ .outputMode("complete") \ .format("memory") \ @@ -1226,6 +1414,35 @@ aggDF \ spark.sql("select * from aggregates").show() # interactively query in-memory table {% endhighlight %} + +
    + +{% highlight r %} +# ========== DF with no aggregations ========== +noAggDF <- select(where(deviceDataDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# ========== DF with aggregation ========== +aggDF <- count(groupBy(df, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +# Interactively query in-memory table +head(sql("select * from aggregates")) +{% endhighlight %} +
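Besides the file, console, and memory sinks shown above, the guide also offers a foreach sink for custom destinations. A minimal Scala sketch of a `ForeachWriter`; the writer here only prints rows and stands in for a real external client:

{% highlight scala %}
import org.apache.spark.sql.{ForeachWriter, Row}

// A trivial writer that just prints each row; a real one would open and close a connection.
class PrintingWriter extends ForeachWriter[Row] {
  def open(partitionId: Long, version: Long): Boolean = true   // accept this partition/epoch
  def process(record: Row): Unit = println(record)             // send the record somewhere
  def close(errorCause: Throwable): Unit = {}                  // release resources
}

val query = aggDF
  .writeStream
  .outputMode("complete")
  .foreach(new PrintingWriter)
  .start()
{% endhighlight %}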
    @@ -1263,7 +1480,7 @@ query.name // get the name of the auto-generated or user-specified name query.explain() // print detailed explanations of the query -query.stop() // stop the query +query.stop() // stop the query query.awaitTermination() // block until query is terminated, with stop() or with error @@ -1305,7 +1522,7 @@ query.lastProgress(); // the most recent progress update of this streaming qu
    {% highlight python %} -query = df.writeStream().format("console").start() # get the query object +query = df.writeStream.format("console").start() # get the query object query.id() # get the unique identifier of the running query that persists across restarts from checkpoint data @@ -1315,7 +1532,7 @@ query.name() # get the name of the auto-generated or user-specified name query.explain() # print detailed explanations of the query -query.stop() # stop the query +query.stop() # stop the query query.awaitTermination() # block until query is terminated, with stop() or with error @@ -1327,6 +1544,24 @@ query.lastProgress() # the most recent progress update of this streaming quer {% endhighlight %} +
    +
    + +{% highlight r %} +query <- write.stream(df, "console") # get the query object + +queryName(query) # get the name of the auto-generated or user-specified name + +explain(query) # print detailed explanations of the query + +stopQuery(query) # stop the query + +awaitTermination(query) # block until query is terminated, with stop() or with error + +lastProgress(query) # the most recent progress update of this streaming query + +{% endhighlight %} +
    @@ -1373,6 +1608,12 @@ spark.streams().get(id) # get a query object by its unique id spark.streams().awaitAnyTermination() # block until any one of them terminates {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    @@ -1388,15 +1629,15 @@ You can directly get the current status and metrics of an active query using `lastProgress()` returns a `StreamingQueryProgress` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryProgress) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryProgress.html) -and an dictionary with the same fields in Python. It has all the information about +and a dictionary with the same fields in Python. It has all the information about the progress made in the last trigger of the stream - what data was processed, what were the processing rates, latencies, etc. There is also `streamingQuery.recentProgress` which returns an array of last few progresses. -In addition, `streamingQuery.status()` returns `StreamingQueryStatus` object +In addition, `streamingQuery.status()` returns a `StreamingQueryStatus` object in [Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus) and [Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html) -and an dictionary with the same fields in Python. It gives information about +and a dictionary with the same fields in Python. It gives information about what the query is immediately doing - is a trigger active, is data being processed, etc. Here are a few examples. @@ -1556,6 +1797,58 @@ Will print something like the following. ''' {% endhighlight %} + +
    + +{% highlight r %} +query <- ... # a StreamingQuery +lastProgress(query) + +''' +Will print something like the following. + +{ + "id" : "8c57e1ec-94b5-4c99-b100-f694162df0b9", + "runId" : "ae505c5a-a64e-4896-8c28-c7cbaf926f16", + "name" : null, + "timestamp" : "2017-04-26T08:27:28.835Z", + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0, + "durationMs" : { + "getOffset" : 0, + "triggerExecution" : 1 + }, + "stateOperators" : [ { + "numRowsTotal" : 4, + "numRowsUpdated" : 0 + } ], + "sources" : [ { + "description" : "TextSocketSource[host: localhost, port: 9999]", + "startOffset" : 1, + "endOffset" : 1, + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0 + } ], + "sink" : { + "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink@76b37531" + } +} +''' + +status(query) +''' +Will print something like the following. + +{ + "message" : "Waiting for data to arrive", + "isDataAvailable" : false, + "isTriggerActive" : false +} +''' +{% endhighlight %} +
    @@ -1615,11 +1908,17 @@ spark.streams().addListener(new StreamingQueryListener() { Not available in Python. {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    @@ -1650,27 +1949,25 @@ aggDF {% highlight python %} aggDF \ - .writeStream() \ + .writeStream \ .outputMode("complete") \ .option("checkpointLocation", "path/to/HDFS/dir") \ .format("memory") \ .start() {% endhighlight %} +
    +
    + +{% highlight r %} +write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "path/to/HDFS/dir") +{% endhighlight %} +
    # Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming) +- Examples: See and run the +[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) examples. - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) - - - - - - - - - diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index d23dbcf10d952..866d6e527549c 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -143,6 +143,9 @@ The master URL passed to Spark can be in one of the following formats: spark://HOST:PORT Connect to the given Spark standalone cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. + spark://HOST1:PORT1,HOST2:PORT2 Connect to the given Spark standalone + cluster with standby masters with Zookeeper. The list must have all the master hosts in the high availability cluster set up with Zookeeper. The port must be whichever each master is configured to use, which is 7077 by default. + mesos://HOST:PORT Connect to the given Mesos cluster. The port must be whichever one your is configured to use, which is 5050 by default. Or, for a Mesos cluster using ZooKeeper, use mesos://zk://.... diff --git a/examples/pom.xml b/examples/pom.xml index 5a869d9ed9219..0608cb89a3b4f 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index cb4b26569088a..37bd8fffbe45a 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -26,7 +26,7 @@ /** * Computes an approximation to pi - * Usage: JavaSparkPi [slices] + * Usage: JavaSparkPi [partitions] */ public final class JavaSparkPi { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index bde30b84d6cf3..c9ca9c9b3a412 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -32,7 +32,7 @@ /** * Transitive closure on a graph, implemented in Java. 
- * Usage: JavaTC [slices] + * Usage: JavaTC [partitions] */ public final class JavaTC { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java new file mode 100644 index 0000000000000..717ec21c8b203 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.fpm.FPGrowth; +import org.apache.spark.ml.fpm.FPGrowthModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example demonstrating FPGrowth. + * Run with + *
    + * bin/run-example ml.JavaFPGrowthExample
    + * 
    + */ +public class JavaFPGrowthExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaFPGrowthExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Arrays.asList("1 2 5".split(" "))), + RowFactory.create(Arrays.asList("1 2 3 5".split(" "))), + RowFactory.create(Arrays.asList("1 2".split(" "))) + ); + StructType schema = new StructType(new StructField[]{ new StructField( + "items", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + Dataset itemsDF = spark.createDataFrame(data, schema); + + FPGrowthModel model = new FPGrowth() + .setItemsCol("items") + .setMinSupport(0.5) + .setMinConfidence(0.6) + .fit(itemsDF); + + // Display frequent itemsets. + model.freqItemsets().show(); + + // Display generated association rules. + model.associationRules().show(); + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(itemsDF).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java new file mode 100644 index 0000000000000..ac40ccd9dbd75 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.Imputer; +import org.apache.spark.ml.feature.ImputerModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +import static org.apache.spark.sql.types.DataTypes.*; + +/** + * An example demonstrating Imputer. 
+ * Run with: + * bin/run-example ml.JavaImputerExample + */ +public class JavaImputerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaImputerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.0, Double.NaN), + RowFactory.create(2.0, Double.NaN), + RowFactory.create(Double.NaN, 3.0), + RowFactory.create(4.0, 4.0), + RowFactory.create(5.0, 5.0) + ); + StructType schema = new StructType(new StructField[]{ + createStructField("a", DoubleType, false), + createStructField("b", DoubleType, false) + }); + Dataset df = spark.createDataFrame(data, schema); + + Imputer imputer = new Imputer() + .setInputCols(new String[]{"a", "b"}) + .setOutputCols(new String[]{"out_a", "out_b"}); + + ImputerModel model = imputer.fit(df); + model.transform(df).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java index 3077f557ef886..0a7dc621e1110 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -39,21 +40,25 @@ public class JavaPCAExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("PCA Example"); SparkContext sc = new SparkContext(conf); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + Matrix pc = mat.computePrincipalComponents(4); + + // Project the rows to the linear space spanned by the top 4 principal components. 
RowMatrix projected = mat.multiply(pc); // $example off$ Vector[] collectPartitions = (Vector[])projected.rows().collect(); @@ -61,6 +66,6 @@ public static void main(String[] args) { for (Vector vector : collectPartitions) { System.out.println("\t" + vector); } - sc.stop(); + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java index 3730e60f68803..802be3960a337 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -43,22 +44,22 @@ public static void main(String[] args) { JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = jsc.parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); + // Compute the top 5 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(5, true, 1.0E-9d); + RowMatrix U = svd.U(); // The U factor is a RowMatrix. + Vector s = svd.s(); // The singular values are stored in a local dense vector. + Matrix V = svd.V(); // The V factor is a local dense matrix. // $example off$ Vector[] collectPartitions = (Vector[]) U.rows().collect(); System.out.println("U factor is:"); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 82bb284ea3e58..b66abaed66000 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -215,7 +215,7 @@ private static void runJsonDatasetExample(SparkSession spark) { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string. + // a Dataset storing one JSON object per string. 
List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); Dataset anotherPeopleDataset = spark.createDataset(jsonData, Encoders.STRING()); @@ -258,6 +258,11 @@ private static void runJdbcDatasetExample(SparkSession spark) { jdbcDF2.write() .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); + + // Specifying create table column data types on write + jdbcDF.write() + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties); // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java index 47638565b1663..575a463e8725f 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -89,7 +89,7 @@ public static void main(String[] args) { // The results of SQL queries are themselves DataFrames and support all normal functions. Dataset sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key"); - // The items in DaraFrames are of type Row, which lets you to access each column by ordinal. + // The items in DataFrames are of type Row, which lets you to access each column by ordinal. Dataset stringsDS = sqlDF.map( (MapFunction) row -> "Key: " + row.get(0) + ", Value: " + row.get(1), Encoders.STRING()); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java new file mode 100644 index 0000000000000..6b8e6554f1bb1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.api.java.function.MapGroupsWithStateFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.GroupState; +import org.apache.spark.sql.streaming.GroupStateTimeout; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.io.Serializable; +import java.sql.Timestamp; +import java.util.*; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + *

    + * Usage: JavaStructuredNetworkWordCount + * and describe the TCP server that Structured Streaming + * would connect to receive data. + *

    + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredSessionization + * localhost 9999` + */ +public final class JavaStructuredSessionization { + + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaStructuredSessionization "); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredSessionization") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load(); + + FlatMapFunction linesToEvents = + new FlatMapFunction() { + @Override + public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exception { + ArrayList eventList = new ArrayList(); + for (String word : lineWithTimestamp.getLine().split(" ")) { + eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); + } + return eventList.iterator(); + } + }; + + // Split the lines into words, treat words as sessionId of events + Dataset events = lines + .withColumnRenamed("value", "line") + .as(Encoders.bean(LineWithTimestamp.class)) + .flatMap(linesToEvents, Encoders.bean(Event.class)); + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. + // + // Step 1: Define the state update function + MapGroupsWithStateFunction stateUpdateFunc = + new MapGroupsWithStateFunction() { + @Override public SessionUpdate call( + String sessionId, Iterator events, GroupState state) + throws Exception { + // If timed out, then remove session and send final update + if (state.hasTimedOut()) { + SessionUpdate finalUpdate = new SessionUpdate( + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true); + state.remove(); + return finalUpdate; + + } else { + // Find max and min timestamps in events + long maxTimestampMs = Long.MIN_VALUE; + long minTimestampMs = Long.MAX_VALUE; + int numNewEvents = 0; + while (events.hasNext()) { + Event e = events.next(); + long timestampMs = e.getTimestamp().getTime(); + maxTimestampMs = Math.max(timestampMs, maxTimestampMs); + minTimestampMs = Math.min(timestampMs, minTimestampMs); + numNewEvents += 1; + } + SessionInfo updatedSession = new SessionInfo(); + + // Update start and end timestamps in session + if (state.exists()) { + SessionInfo oldSession = state.get(); + updatedSession.setNumEvents(oldSession.numEvents + numNewEvents); + updatedSession.setStartTimestampMs(oldSession.startTimestampMs); + updatedSession.setEndTimestampMs(Math.max(oldSession.endTimestampMs, maxTimestampMs)); + } else { + updatedSession.setNumEvents(numNewEvents); + updatedSession.setStartTimestampMs(minTimestampMs); + updatedSession.setEndTimestampMs(maxTimestampMs); + } + state.update(updatedSession); + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds"); + return new SessionUpdate( + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false); + } + } + }; + + // Step 2: Apply the state update function to the events streaming Dataset grouped by sessionId + Dataset sessionUpdates = events + .groupByKey( + new MapFunction() { + @Override public String 
call(Event event) throws Exception { + return event.getSessionId(); + } + }, Encoders.STRING()) + .mapGroupsWithState( + stateUpdateFunc, + Encoders.bean(SessionInfo.class), + Encoders.bean(SessionUpdate.class), + GroupStateTimeout.ProcessingTimeTimeout()); + + // Start running the query that prints the session updates to the console + StreamingQuery query = sessionUpdates + .writeStream() + .outputMode("update") + .format("console") + .start(); + + query.awaitTermination(); + } + + /** + * User-defined data type representing the raw lines with timestamps. + */ + public static class LineWithTimestamp implements Serializable { + private String line; + private Timestamp timestamp; + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getLine() { return line; } + public void setLine(String sessionId) { this.line = sessionId; } + } + + /** + * User-defined data type representing the input events + */ + public static class Event implements Serializable { + private String sessionId; + private Timestamp timestamp; + + public Event() { } + public Event(String sessionId, Timestamp timestamp) { + this.sessionId = sessionId; + this.timestamp = timestamp; + } + + public Timestamp getTimestamp() { return timestamp; } + public void setTimestamp(Timestamp timestamp) { this.timestamp = timestamp; } + + public String getSessionId() { return sessionId; } + public void setSessionId(String sessionId) { this.sessionId = sessionId; } + } + + /** + * User-defined data type for storing a session information as state in mapGroupsWithState. + */ + public static class SessionInfo implements Serializable { + private int numEvents = 0; + private long startTimestampMs = -1; + private long endTimestampMs = -1; + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public long getStartTimestampMs() { return startTimestampMs; } + public void setStartTimestampMs(long startTimestampMs) { + this.startTimestampMs = startTimestampMs; + } + + public long getEndTimestampMs() { return endTimestampMs; } + public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } + + public long calculateDuration() { return endTimestampMs - startTimestampMs; } + + @Override public String toString() { + return "SessionInfo(numEvents = " + numEvents + + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; + } + } + + /** + * User-defined data type representing the update information returned by mapGroupsWithState. 
+ */ + public static class SessionUpdate implements Serializable { + private String id; + private long durationMs; + private int numEvents; + private boolean expired; + + public SessionUpdate() { } + + public SessionUpdate(String id, long durationMs, int numEvents, boolean expired) { + this.id = id; + this.durationMs = durationMs; + this.numEvents = numEvents; + this.expired = expired; + } + + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public long getDurationMs() { return durationMs; } + public void setDurationMs(long durationMs) { this.durationMs = durationMs; } + + public int getNumEvents() { return numEvents; } + public void setNumEvents(int numEvents) { this.numEvents = numEvents; } + + public boolean isExpired() { return expired; } + public void setExpired(boolean expired) { this.expired = expired; } + } +} diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py new file mode 100644 index 0000000000000..c92c3c27abb21 --- /dev/null +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating FPGrowth. +Run with: + bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("FPGrowthExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (0, [1, 2, 5]), + (1, [1, 2, 3, 5]), + (2, [1, 2]) + ], ["id", "items"]) + + fpGrowth = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6) + model = fpGrowth.fit(df) + + # Display frequent itemsets. + model.freqItemsets.show() + + # Display generated association rules. + model.associationRules.show() + + # transform examines the input items against all the association rules and summarize the + # consequents as prediction + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py new file mode 100644 index 0000000000000..b8437f827e56d --- /dev/null +++ b/examples/src/main/python/ml/imputer_example.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating Imputer. +Run with: + bin/spark-submit examples/src/main/python/ml/imputer_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("ImputerExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (1.0, float("nan")), + (2.0, float("nan")), + (float("nan"), 3.0), + (4.0, 4.0), + (5.0, 5.0) + ], ["a", "b"]) + + imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + model = imputer.fit(df) + + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py new file mode 100644 index 0000000000000..49b9b1bbe08e9 --- /dev/null +++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonPCAOnRowMatrixExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + # Compute the top 4 principal components. + # Principal components are stored in a local dense matrix. + pc = mat.computePrincipalComponents(4) + + # Project the rows to the linear space spanned by the top 4 principal components. + projected = mat.multiply(pc) + # $example off$ + collected = projected.rows.collect() + print("Projected Row Matrix of principal component:") + for vector in collected: + print(vector) + sc.stop() diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py new file mode 100644 index 0000000000000..5b220fdb3fd67 --- /dev/null +++ b/examples/src/main/python/mllib/svd_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVDExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + + # Compute the top 5 singular values and corresponding singular vectors. + svd = mat.computeSVD(5, computeU=True) + U = svd.U # The U factor is a RowMatrix. + s = svd.s # The singular values are stored in a local dense vector. + V = svd.V # The V factor is a local dense matrix. + # $example off$ + collected = U.rows.collect() + print("U factor is:") + for vector in collected: + print(vector) + print("Singular values are: %s" % s) + print("V factor is:\n%s" % V) + sc.stop() diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index e9aa9d9ac2583..e4abb0933345d 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -169,6 +169,12 @@ def jdbc_dataset_example(spark): jdbcDF2.write \ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", properties={"user": "username", "password": "password"}) + + # Specifying create table column data types on write + jdbcDF.write \ + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") \ + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", + properties={"user": "username", "password": "password"}) # $example off:jdbc_dataset$ diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f175d725800f..1f83a6fb48b97 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -68,7 +68,7 @@ # The results of SQL queries are themselves DataFrames and support all normal functions. sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - # The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + # The items in DataFrames are of type Row, which allows you to access each column by ordinal. stringsDS = sqlDF.rdd.map(lambda row: "Key: %d, Value: %s" % (row.key, row.value)) for record in stringsDS.collect(): print(record) diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index e647f0e1e9f17..3734568d872d0 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -15,6 +15,9 @@ # limitations under the License. 
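Editor's note on the createTableColumnTypes option added in the datasource.py hunk above (and in its Java/Scala counterparts elsewhere in this diff): it only changes the column types in the CREATE TABLE statement Spark issues when the target table does not yet exist. A hedged Scala sketch, with placeholder URL, table name, and credentials:

import java.util.Properties
import org.apache.spark.sql.SparkSession

object JdbcColumnTypesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("JdbcColumnTypesSketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("alice", "first comment"), ("bob", "second comment")).toDF("name", "comments")

    val props = new Properties()
    props.put("user", "username")     // placeholder
    props.put("password", "password") // placeholder

    // Without the option, string columns get the dialect's default type (e.g. TEXT on
    // PostgreSQL); with it, the listed columns use the requested SQL types on table creation.
    df.write
      .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)")
      .jdbc("jdbc:postgresql:dbserver", "schema.tablename", props) // placeholder URL/table

    spark.stop()
  }
}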
# +# To run this example use +# ./bin/spark-submit examples/src/main/r/RSparkSQLExample.R + library(SparkR) # $example on:init_session$ diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 82b85f2f590f6..311350497f873 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -15,6 +15,9 @@ # limitations under the License. # +# To run this example use +# ./bin/spark-submit examples/src/main/r/dataframe.R + library(SparkR) # Initialize SparkSession diff --git a/examples/src/main/r/ml/fpm.R b/examples/src/main/r/ml/fpm.R new file mode 100644 index 0000000000000..89c4564457d9e --- /dev/null +++ b/examples/src/main/r/ml/fpm.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/fpm.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-fpm-example") + +# $example on$ +# Load training data + +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "1,2,5", "1,2,3,5", "1,2" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, itemsCol="items", minSupport=0.5, minConfidence=0.6) + +# Extracting frequent itemsets + +spark.freqItemsets(fpm) + +# Extracting association rules + +spark.associationRules(fpm) + +# Predict uses association rules to and combines possible consequents + +predict(fpm, df) + +# $example off$ + +sparkR.session.stop() diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index ee13910382c58..68787f9aa9dca 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -27,7 +27,7 @@ sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -df_list <- randomSplit(training, c(7,3), 2) +df_list <- randomSplit(training, c(7, 3), 2) gaussianDF <- df_list[[1]] gaussianTestDF <- df_list[[2]] gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") @@ -44,8 +44,9 @@ gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized linear model of family "binomial" with spark.glm -training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") -df_list2 <- randomSplit(training2, c(7,3), 2) +training2 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +training2 <- transform(training2, label = cast(training2$label > 1, "integer")) +df_list2 <- randomSplit(training2, c(7, 3), 2) binomialDF <- df_list2[[1]] binomialTestDF <- df_list2[[2]] binomialGLM <- spark.glm(binomialDF, label ~ features, 
family = "binomial") @@ -56,6 +57,15 @@ summary(binomialGLM) # Prediction binomialPredictions <- predict(binomialGLM, binomialTestDF) head(binomialPredictions) + +# Fit a generalized linear model of family "tweedie" with spark.glm +training3 <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +tweedieDF <- transform(training3, label = training3$label * exp(randn(10))) +tweedieGLM <- spark.glm(tweedieDF, label ~ features, family = "tweedie", + var.power = 1.2, link.power = 0) + +# Model summary +summary(tweedieGLM) # $example off$ sparkR.session.stop() diff --git a/examples/src/main/r/streaming/structured_network_wordcount.R b/examples/src/main/r/streaming/structured_network_wordcount.R new file mode 100644 index 0000000000000..cda18ebc072ee --- /dev/null +++ b/examples/src/main/r/streaming/structured_network_wordcount.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Counts words in UTF8 encoded, '\n' delimited text received from the network. + +# To run this on your local machine, you need to first run a Netcat server +# $ nc -lk 9999 +# and then run the example +# ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-Streaming-structured-network-wordcount-example") + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 2) { + print("Usage: structured_network_wordcount.R ") + print(" and describe the TCP server that Structured Streaming") + print("would connect to receive data.") + q("no") +} + +hostname <- args[[1]] +port <- as.integer(args[[2]]) + +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) + +sparkR.session.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 86eed3867c539..25718f904cc49 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import org.apache.spark.sql.SparkSession /** - * Usage: BroadcastTest [slices] [numElem] [blockSize] + * Usage: BroadcastTest [partitions] [numElem] [blockSize] */ object 
BroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 6495a86fcd77c..e6f33b7adf5d1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession /** - * Usage: MultiBroadcastTest [slices] [numElem] + * Usage: MultiBroadcastTest [partitions] [numElem] */ object MultiBroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 8a3d08f459783..a99ddd9fd37db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -100,7 +100,7 @@ object SparkALS { ITERATIONS = iters.getOrElse("5").toInt slices = slices_.getOrElse("2").toInt case _ => - System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") + System.err.println("Usage: SparkALS [M] [U] [F] [iters] [partitions]") System.exit(1) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index afa8f58c96e59..cb2be091ffcf3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. - * Usage: SparkLR [slices] + * Usage: SparkLR [partitions] * * This is an example implementation for learning how to use Spark. For more conventional use, * please refer to org.apache.spark.ml.classification.LogisticRegression. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index e07c9a4717c3a..0658bddf16961 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.util.Utils /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with + * An example of how to use [[DataFrame]] for ML. 
Run with * {{{ * ./bin/run-example ml.DataFrameExample [options] * }}} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 1745281c266cc..f736ceed4436f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -203,7 +205,7 @@ object DecisionTreeExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"DecisionTreeExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala new file mode 100644 index 0000000000000..59110d70de550 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.fpm.FPGrowth +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating FP-Growth. + * Run with + * {{{ + * bin/run-example ml.FPGrowthExample + * }}} + */ +object FPGrowthExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val dataset = spark.createDataset(Seq( + "1 2 5", + "1 2 3 5", + "1 2") + ).map(t => t.split(" ")).toDF("items") + + val fpgrowth = new FPGrowth().setItemsCol("items").setMinSupport(0.5).setMinConfidence(0.6) + val model = fpgrowth.fit(dataset) + + // Display frequent itemsets. + model.freqItemsets.show() + + // Display generated association rules. 
+ model.associationRules.show() + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(dataset).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index db55298d8ea10..ed598d0d7dfae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -140,7 +142,7 @@ object GBTExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"GBTExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala new file mode 100644 index 0000000000000..49e98d0c622ca --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Imputer +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating Imputer. 
+ * Run with: + * bin/run-example ml.ImputerExample + */ +object ImputerExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession.builder + .appName("ImputerExample") + .getOrCreate() + + // $example on$ + val df = spark.createDataFrame(Seq( + (1.0, Double.NaN), + (2.0, Double.NaN), + (Double.NaN, 3.0), + (4.0, 4.0), + (5.0, 5.0) + )).toDF("a", "b") + + val imputer = new Imputer() + .setInputCols(Array("a", "b")) + .setOutputCols(Array("out_a", "out_b")) + + val model = imputer.fit(df) + model.transform(df).show() + // $example off$ + + spark.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index a9e07c0705c92..8fd46c37e2987 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.ml +import java.util.Locale + import scala.collection.mutable import scala.language.reflectiveCalls @@ -146,7 +148,7 @@ object RandomForestExample { .getOrCreate() params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) - val algo = params.algo.toLowerCase + val algo = params.algo.toLowerCase(Locale.ROOT) println(s"RandomForestExample with parameters:\n$params") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index b923e627f2095..cd77ecf990b3b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,6 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import java.util.Locale + import org.apache.log4j.{Level, Logger} import scopt.OptionParser @@ -131,7 +133,7 @@ object LDAExample { // Run LDA. val lda = new LDA() - val optimizer = params.algorithm.toLowerCase match { + val optimizer = params.algorithm.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets. case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala index a137ba2a2f9d3..da43a8d9c7e80 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -39,9 +39,9 @@ object PCAOnRowMatrixExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 4 principal components. // Principal components are stored in a local dense matrix. 
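Editor's note on the PCAOnRowMatrixExample hunk above: computePrincipalComponents(4) returns a local n x 4 dense matrix whose columns are the top principal components, and multiply projects the distributed rows onto them. A hedged, self-contained Scala sketch of the full flow (local master and tiny data are for illustration only):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Matrix, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

object PcaRowMatrixSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("PcaRowMatrixSketch").setMaster("local[*]")
    val sc = new SparkContext(conf)

    val rows = sc.parallelize(Seq(
      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)))

    val mat = new RowMatrix(rows)

    // pc is a local 5 x 4 dense matrix; its columns are the top 4 principal components.
    val pc: Matrix = mat.computePrincipalComponents(4)

    // Multiplying by pc projects each distributed row into the 4-dimensional subspace.
    val projected: RowMatrix = mat.multiply(pc)
    projected.rows.collect().foreach(println)

    sc.stop()
  }
}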
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala index b286a3f7b9096..769ae2a3a88b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix // $example off$ +/** + * Example for SingularValueDecomposition. + */ object SVDExample { def main(args: Array[String]): Unit = { @@ -41,15 +44,15 @@ object SVDExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 5 singular values and corresponding singular vectors. val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) val U: RowMatrix = svd.U // The U factor is a RowMatrix. - val s: Vector = svd.s // The singular values are stored in a local dense vector. - val V: Matrix = svd.V // The V factor is a local dense matrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. // $example off$ val collect = U.rows.collect() println("U factor is:") diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 381e69cda841c..ad74da72bd5e6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -139,7 +139,7 @@ object SQLDataSourceExample { // +------+ // Alternatively, a DataFrame can be created for a JSON dataset represented by - // an Dataset[String] storing one JSON object per string + // a Dataset[String] storing one JSON object per string val otherPeopleDataset = spark.createDataset( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) val otherPeople = spark.read.json(otherPeopleDataset) @@ -181,6 +181,11 @@ object SQLDataSourceExample { jdbcDF2.write .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) + + // Specifying create table column data types on write + jdbcDF.write + .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") + .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties) // $example off:jdbc_dataset$ } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala index 3de26364b5288..e5f75d53edc86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -76,7 +76,7 @@ object SparkHiveExample { // The results of SQL queries are themselves DataFrames and support all normal functions. val sqlDF = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - // The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + // The items in DataFrames are of type Row, which allows you to access each column by ordinal. 
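Editor's note: the corrected comment above refers to positional access on Row; the pattern-matching form continues in the hunk just below. A hedged Scala sketch of the common access patterns (the DataFrame here is made up for illustration):

import org.apache.spark.sql.{Row, SparkSession}

object RowAccessSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("RowAccessSketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "one"), (2, "two")).toDF("key", "value")

    df.collect().foreach { row =>
      val byOrdinal = s"${row.get(0)} -> ${row.get(1)}"                            // positional
      val byName    = s"${row.getAs[Int]("key")} -> ${row.getAs[String]("value")}" // by field name
      val byPattern = row match { case Row(k: Int, v: String) => s"$k -> $v" }     // extractor
      println(s"$byOrdinal | $byName | $byPattern")
    }

    spark.stop()
  }
}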
val stringsDS = sqlDF.map { case Row(key: Int, value: String) => s"Key: $key, Value: $value" } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala new file mode 100644 index 0000000000000..2ce792c00849c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.streaming._ + + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: MapGroupsWithState + * and describe the TCP server that Structured Streaming + * would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredNetworkWordCount + * localhost 9999` + */ +object StructuredSessionization { + + def main(args: Array[String]): Unit = { + if (args.length < 2) { + System.err.println("Usage: StructuredNetworkWordCount ") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + + val spark = SparkSession + .builder + .appName("StructuredSessionization") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load() + + // Split the lines into words, treat words as sessionId of events + val events = lines + .as[(String, Timestamp)] + .flatMap { case (line, timestamp) => + line.split(" ").map(word => Event(sessionId = word, timestamp)) + } + + // Sessionize the events. Track number of events, start and end timestamps of session, and + // and report session updates. 
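Editor's note before the sessionization logic continues below: a hedged, self-contained skeleton of the mapGroupsWithState contract it relies on. Per key and per trigger, Spark passes any new events plus the previous state; the function updates the state, may arm a processing-time timeout, and returns one row per key. CountState and the running word count here are illustrative, not part of the example above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

case class CountState(total: Long)

object MapGroupsWithStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("MapGroupsWithStateSketch").getOrCreate()
    import spark.implicits._

    // Socket source, as in the surrounding example: run `nc -lk 9999` first.
    val words = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999).load()
      .as[String]
      .flatMap(_.split(" "))

    val counts = words
      .groupByKey(identity)
      .mapGroupsWithState[CountState, (String, Long)](GroupStateTimeout.ProcessingTimeTimeout) {
        case (word, events, state) =>
          if (state.hasTimedOut) {
            // No new data for this key since the timeout was armed: emit and drop the state.
            val finalCount = state.get.total
            state.remove()
            (word, finalCount)
          } else {
            val updated = CountState(state.getOption.map(_.total).getOrElse(0L) + events.size)
            state.update(updated)
            state.setTimeoutDuration("10 seconds")
            (word, updated.total)
          }
      }

    val query = counts.writeStream.outputMode("update").format("console").start()
    query.awaitTermination()
  }
}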
+ val sessionUpdates = events + .groupByKey(event => event.sessionId) + .mapGroupsWithState[SessionInfo, SessionUpdate](GroupStateTimeout.ProcessingTimeTimeout) { + + case (sessionId: String, events: Iterator[Event], state: GroupState[SessionInfo]) => + + // If timed out, then remove session and send final update + if (state.hasTimedOut) { + val finalUpdate = + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = true) + state.remove() + finalUpdate + } else { + // Update start and end timestamps in session + val timestamps = events.map(_.timestamp.getTime).toSeq + val updatedSession = if (state.exists) { + val oldSession = state.get + SessionInfo( + oldSession.numEvents + timestamps.size, + oldSession.startTimestampMs, + math.max(oldSession.endTimestampMs, timestamps.max)) + } else { + SessionInfo(timestamps.size, timestamps.min, timestamps.max) + } + state.update(updatedSession) + + // Set timeout such that the session will be expired if no data received for 10 seconds + state.setTimeoutDuration("10 seconds") + SessionUpdate(sessionId, state.get.durationMs, state.get.numEvents, expired = false) + } + } + + // Start running the query that prints the session updates to the console + val query = sessionUpdates + .writeStream + .outputMode("update") + .format("console") + .start() + + query.awaitTermination() + } +} +/** User-defined data type representing the input events */ +case class Event(sessionId: String, timestamp: Timestamp) + +/** + * User-defined data type for storing a session information as state in mapGroupsWithState. + * + * @param numEvents total number of events received in the session + * @param startTimestampMs timestamp of first event received in the session when it started + * @param endTimestampMs timestamp of last event received in the session before it expired + */ +case class SessionInfo( + numEvents: Int, + startTimestampMs: Long, + endTimestampMs: Long) { + + /** Duration of the session, between the first and last events */ + def durationMs: Long = endTimestampMs - startTimestampMs +} + +/** + * User-defined data type representing the update information returned by mapGroupsWithState. 
+ * + * @param id Id of the session + * @param durationMs Duration the session was active, that is, from first event to its expiry + * @param numEvents Number of events received by the session while it was active + * @param expired Is the session active or expired + */ +case class SessionUpdate( + id: String, + durationMs: Long, + numEvents: Int, + expired: Boolean) + +// scalastyle:on println + diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8948df2da89e2..0fa87a697454b 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index f8ef8a991316d..71016bc645ca7 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 6d547c46d6a2d..12630840e79dc 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 46901d64eda97..87a09642405a7 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 295142cbfdff9..75df886ca44f6 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6cf448e65e8b4..557d27296345f 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 15b28256e825e..7c4f38e02fb2a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.util.UninterruptibleThread /** @@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private( case class AvailableOffsetRange(earliest: Long, latest: Long) + private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly(body) + case _ => + logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " + + "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894") + body + } + /** * Return the available offset range of the current partition. It's a pair of the earliest offset * and the latest offset. 
*/ - def getAvailableOffsetRange(): AvailableOffsetRange = { + def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { consumer.seekToBeginning(Set(topicPartition).asJava) val earliestOffset = consumer.position(topicPartition) consumer.seekToEnd(Set(topicPartition).asJava) @@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + failOnDataLoss: Boolean): + ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") @@ -273,22 +284,10 @@ private[kafka010] case class CachedKafkaConsumer private( message: String, cause: Throwable = null): Unit = { val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}" - if (failOnDataLoss) { - if (cause != null) { - throw new IllegalStateException(finalMessage) - } else { - throw new IllegalStateException(finalMessage, cause) - } - } else { - if (cause != null) { - logWarning(finalMessage) - } else { - logWarning(finalMessage, cause) - } - } + reportDataLoss0(failOnDataLoss, finalMessage, cause) } - private def close(): Unit = consumer.close() + def close(): Unit = consumer.close() private def seek(offset: Long): Unit = { logDebug(s"Seeking to $groupId $topicPartition $offset") @@ -383,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { removeKafkaConsumer(topic, partition, kafkaParams) val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) consumer.inuse = true @@ -398,4 +397,31 @@ private[kafka010] object CachedKafkaConsumer extends Logging { consumer } } + + /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ + def createUncached( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { + new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) + } + + private def reportDataLoss0( + failOnDataLoss: Boolean, + finalMessage: String, + cause: Throwable = null): Unit = { + if (failOnDataLoss) { + if (cause != null) { + throw new IllegalStateException(finalMessage, cause) + } else { + throw new IllegalStateException(finalMessage) + } + } else { + if (cause != null) { + logWarning(finalMessage, cause) + } else { + logWarning(finalMessage) + } + } + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 2696d6f089d2f..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. 
*/ def close(): Unit = { - consumer.close() - kafkaReaderThread.shutdownNow() + runUninterruptibly { + consumer.close() + } + kafkaReaderThread.shutdown() } /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index f180bbad6e363..97bd283169323 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.UUID import org.apache.kafka.common.TopicPartition @@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( override val sqlContext: SQLContext, - kafkaReader: KafkaOffsetReader, - executorKafkaParams: ju.Map[String, Object], + strategy: ConsumerStrategy, sourceOptions: Map[String, String], + specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) @@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation( override def schema: StructType = KafkaOffsetReader.kafkaSchema override def buildScan(): RDD[Row] = { + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy, + KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams), + sourceOptions, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + // Leverage the KafkaReader to obtain the relevant partition offsets - val fromPartitionOffsets = getPartitionOffsets(startingOffsets) - val untilPartitionOffsets = getPartitionOffsets(endingOffsets) + val (fromPartitionOffsets, untilPartitionOffsets) = { + try { + (getPartitionOffsets(kafkaOffsetReader, startingOffsets), + getPartitionOffsets(kafkaOffsetReader, endingOffsets)) + } finally { + kafkaOffsetReader.close() + } + } + // Obtain topicPartitions in both from and until partition offset, ignoring // topic partitions that were added and/or deleted between the two above calls. if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) { @@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation( offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. 
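Editor's note for orientation: the buildScan refactored above serves the user-facing batch read path sketched below (hedged; broker address and topic are placeholders). startingOffsets/endingOffsets accept "earliest", "latest", or a JSON map of per-partition offsets, matching the range limits resolved above. The hunk then continues with the executor-side Kafka parameters.

import org.apache.spark.sql.SparkSession

object KafkaBatchReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("KafkaBatchReadSketch").getOrCreate()

    val df = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092") // placeholder
      .option("subscribe", "topic1")                   // placeholder
      .option("startingOffsets", "earliest")
      .option("endingOffsets", "latest")
      .load()

    // Keys and values arrive as binary columns; cast them for display.
    df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)").show()

    spark.stop()
  }
}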
+ val executorKafkaParams = + KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr => @@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation( } private def getPartitionOffsets( + kafkaReader: KafkaOffsetReader, kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = { def validateTopicPartitions(partitions: Set[TopicPartition], partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala new file mode 100644 index 0000000000000..08914d82fffdd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.streaming.Sink + +private[kafka010] class KafkaSink( + sqlContext: SQLContext, + executorKafkaParams: ju.Map[String, Object], + topic: Option[String]) extends Sink with Logging { + @volatile private var latestBatchId = -1L + + override def toString(): String = "KafkaSink" + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId <= latestBatchId) { + logInfo(s"Skipping already committed batch $batchId") + } else { + KafkaWriter.write(sqlContext.sparkSession, + data.queryExecution, executorKafkaParams, topic) + latestBatchId = batchId + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 92b5d91ba435e..1fb0a338299b7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -100,7 +100,7 @@ private[kafka010] class KafkaSource( override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write(VERSION) + writer.write("v" + VERSION + "\n") writer.write(metadata.json) writer.flush } @@ -111,13 +111,13 @@ private[kafka010] class KafkaSource( // HDFSMetadataLog guarantees that it never creates a partial file. 
assert(content.length != 0) if (content(0) == 'v') { - if (content.startsWith(VERSION)) { - KafkaSourceOffset(SerializedOffset(content.substring(VERSION.length))) + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) } else { - val versionInFile = content.substring(0, content.indexOf("\n")) throw new IllegalStateException( - s"Unsupported format. Expected version is ${VERSION.stripLineEnd} " + - s"but was $versionInFile. Please upgrade your Spark.") + s"Log file was malformed: failed to detect the log file version line.") } } else { // The log was generated by Spark 2.1.0 @@ -351,7 +351,7 @@ private[kafka010] object KafkaSource { | source option "failOnDataLoss" to "false". """.stripMargin - private val VERSION = "v1\n" + private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { val bm = sc.env.blockManager diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a7456719875f..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,17 +18,19 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.UUID +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig -import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -36,8 +38,12 @@ import org.apache.spark.sql.types.StructType * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ -private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider - with RelationProvider with Logging { +private[kafka010] class KafkaSourceProvider extends DataSourceRegister + with StreamSourceProvider + with StreamSinkProvider + with RelationProvider + with CreatableRelationProvider + with Logging { import KafkaSourceProvider._ override def shortName(): String = "kafka" @@ -68,21 +74,16 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre // id. Hence, we should generate a unique id for each query. 
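Editor's note on the option handling in this provider (above and just below): options prefixed with "kafka." are stripped of the prefix (the k.drop(6) call) and passed straight to the Kafka consumer, while unprefixed options configure the source itself, and the consumer group id is generated by Spark rather than supplied by the user. A hedged Scala sketch with placeholder endpoints:

import org.apache.spark.sql.SparkSession

object KafkaOptionPrefixSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("KafkaOptionPrefixSketch").getOrCreate()

    val stream = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host1:9092") // becomes bootstrap.servers
      .option("kafka.security.protocol", "PLAINTEXT")  // becomes security.protocol
      .option("subscribe", "topic1")                   // interpreted by the source
      .option("failOnDataLoss", "false")               // interpreted by the source
      .load()

    val query = stream.selectExpr("CAST(value AS STRING)")
      .writeStream
      .format("console")
      .outputMode("append")
      .start()

    query.awaitTermination()
  }
}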
val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap - val startingStreamOffsets = - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("latest") => LatestOffsetRangeLimit - case Some("earliest") => EarliestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => LatestOffsetRangeLimit - } + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) val kafkaOffsetReader = new KafkaOffsetReader( strategy(caseInsensitiveParams), @@ -110,87 +111,97 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { validateBatchOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters .keySet - .filter(_.toLowerCase.startsWith("kafka.")) + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) .map { k => k.drop(6).toString -> parameters(k) } .toMap - val startingRelationOffsets = - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => EarliestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => EarliestOffsetRangeLimit - } + val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) + assert(startingRelationOffsets != LatestOffsetRangeLimit) - val endingRelationOffsets = - caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("latest") => LatestOffsetRangeLimit - case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) - case None => LatestOffsetRangeLimit - } - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") + val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + assert(endingRelationOffsets != EarliestOffsetRangeLimit) new KafkaRelation( sqlContext, - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - failOnDataLoss(caseInsensitiveParams), - startingRelationOffsets, - endingRelationOffsets) + strategy(caseInsensitiveParams), + sourceOptions = parameters, + specifiedKafkaParams = specifiedKafkaParams, + failOnDataLoss = 
failOnDataLoss(caseInsensitiveParams), + startingOffsets = startingRelationOffsets, + endingOffsets = endingRelationOffsets) } - private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = - ConfigUpdater("source", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial - // offsets by itself instead of counting on KafkaConsumer. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") - - // So that consumers in the driver does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - - private def kafkaParamsForExecutors( - specifiedKafkaParams: Map[String, String], uniqueGroupId: String) = - ConfigUpdater("executor", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Make sure executors do only what the driver tells them. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + new KafkaSink(sqlContext, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic) + } - // So that consumers in executors do not mess with any existing group id - .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + override def createRelation( + outerSQLContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + mode match { + case SaveMode.Overwrite | SaveMode.Ignore => + throw new AnalysisException(s"Save mode $mode not allowed for Kafka. " + + s"Allowed save modes are ${SaveMode.Append} and " + + s"${SaveMode.ErrorIfExists} (default).") + case _ => // good + } + val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic) + + /* This method is suppose to return a relation that reads the data that was written. + * We cannot support this for Kafka. Therefore, in order to make things consistent, + * we return an empty base relation. 
+ */ + new BaseRelation { + override def sqlContext: SQLContext = unsupportedException + override def schema: StructType = unsupportedException + override def needConversion: Boolean = unsupportedException + override def sizeInBytes: Long = unsupportedException + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException + private def unsupportedException = + throw new UnsupportedOperationException("BaseRelation from Kafka write " + + "operation is not usable.") + } + } - // So that consumers in executors does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { @@ -211,7 +222,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre private def validateGeneralOptions(parameters: Map[String, String]): Unit = { // Validate source options - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedStrategies = caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq @@ -316,34 +327,34 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = { // Batch specific options - caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => // good to go - case Some("latest") => + KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit) match { + case EarliestOffsetRangeLimit => // good to go + case LatestOffsetRangeLimit => throw new IllegalArgumentException("starting offset can't be latest " + "for batch queries on Kafka") - case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))) - .partitionOffsets.foreach { + case SpecificOffsetRangeLimit(partitionOffsets) => + partitionOffsets.foreach { case (tp, off) if off == KafkaOffsetRangeLimit.LATEST => throw new IllegalArgumentException(s"startingOffsets 
for $tp can't " + "be latest for batch queries on Kafka") case _ => // ignore } - case _ => // default to earliest } - caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { - case Some("earliest") => + KafkaSourceProvider.getKafkaOffsetRangeLimit( + caseInsensitiveParams, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) match { + case EarliestOffsetRangeLimit => throw new IllegalArgumentException("ending offset can't be earliest " + "for batch queries on Kafka") - case Some("latest") => // good to go - case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))) - .partitionOffsets.foreach { + case LatestOffsetRangeLimit => // good to go + case SpecificOffsetRangeLimit(partitionOffsets) => + partitionOffsets.foreach { case (tp, off) if off == KafkaOffsetRangeLimit.EARLIEST => throw new IllegalArgumentException(s"ending offset for $tp can't be " + "earliest for batch queries on Kafka") case _ => // ignore } - case _ => // default to latest } validateGeneralOptions(caseInsensitiveParams) @@ -353,6 +364,71 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre logWarning("maxOffsetsPerTrigger option ignored in batch queries") } } +} + +private[kafka010] object KafkaSourceProvider extends Logging { + private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") + private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" + private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" + private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" + + private val deserClassName = classOf[ByteArrayDeserializer].getName + + def getKafkaOffsetRangeLimit( + params: Map[String, String], + offsetOptionKey: String, + defaultOffsets: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit = { + params.get(offsetOptionKey).map(_.trim) match { + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "latest" => + LatestOffsetRangeLimit + case Some(offset) if offset.toLowerCase(Locale.ROOT) == "earliest" => + EarliestOffsetRangeLimit + case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)) + case None => defaultOffsets + } + } + + def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + def kafkaParamsForExecutors( + specifiedKafkaParams: Map[String, String], + uniqueGroupId: String): ju.Map[String, Object] = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. 
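      // [Editor's note, not part of the original patch] With auto.offset.reset = "none" the
      // executor-side consumer throws if it is asked to start without an explicit position,
      // instead of silently resetting, so executors only ever read the exact offset ranges the
      // driver resolved.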
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() /** Class to conveniently update Kafka config params, while logging the changes */ private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { @@ -360,14 +436,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre def set(key: String, value: Object): this.type = { map.put(key, value) - logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.get(key).getOrElse("")}") + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") this } def setIfUnset(key: String, value: Object): ConfigUpdater = { if (!map.containsKey(key)) { map.put(key, value) - logInfo(s"$module: Set $key to $value") + logDebug(s"$module: Set $key to $value") } this } @@ -375,12 +451,3 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre def build(): ju.Map[String, Object] = map } } - -private[kafka010] object KafkaSourceProvider { - private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") - private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" - private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" - private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" - - private val deserClassName = classOf[ByteArrayDeserializer].getName -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 6fb3473eb75f5..9d9e2aaba8079 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD( context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic - if (!reuseKafkaConsumer) { - // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique - // to each task (i.e., append the task's unique partition id), because we will have - // multiple tasks (e.g., in the case of union) reading from the same topic partitions - val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val id = TaskContext.getPartitionId() - executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id) - } val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + val consumer = + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we + // uses `assign`, we don't need to worry about the "group.id" conflicts. 
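        // [Editor's note, not part of the original patch] The uncached consumer created here is
        // closed directly via consumer.close() in the close() change below, whereas cached
        // consumers are only handed back to the pool through
        // CachedKafkaConsumer.releaseKafkaConsumer.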
+ CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + } val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD( override protected def close(): Unit = { if (!reuseKafkaConsumer) { // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams) + consumer.close() } else { // Indicate that we're no longer using this consumer CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala new file mode 100644 index 0000000000000..6e160cbe2db52 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.producer.{KafkaProducer, _} +import org.apache.kafka.common.serialization.ByteArraySerializer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.types.{BinaryType, StringType} + +/** + * A simple trait for writing out data in a single Spark task, without any concerns about how + * to commit or abort tasks. Exceptions thrown by the implementation of this class will + * automatically trigger task aborts. + */ +private[kafka010] class KafkaWriteTask( + producerConfiguration: ju.Map[String, Object], + inputSchema: Seq[Attribute], + topic: Option[String]) { + // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection + private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + + /** + * Writes key value data out to topics. + */ + def execute(iterator: Iterator[InternalRow]): Unit = { + producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration) + while (iterator.hasNext && failedWrite == null) { + val currentRow = iterator.next() + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. 
Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) + } + } + + def close(): Unit = { + if (producer != null) { + checkForErrors + producer.close() + checkForErrors + producer = null + } + } + + private def createProjection: UnsafeProjection = { + val topicExpression = topic.map(Literal(_)).orElse { + inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) + }.getOrElse { + throw new IllegalStateException(s"topic option required when no " + + s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present") + } + topicExpression.dataType match { + case StringType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"must be a ${StringType}") + } + val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) + .getOrElse(Literal(null, BinaryType)) + keyExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + val valueExpression = inputSchema + .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( + throw new IllegalStateException(s"Required attribute " + + s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") + ) + valueExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + UnsafeProjection.create( + Seq(topicExpression, Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType)), inputSchema) + } + + private def checkForErrors: Unit = { + if (failedWrite != null) { + throw failedWrite + } + } +} + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala new file mode 100644 index 0000000000000..61936e32fd837 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.util.Utils + +/** + * The [[KafkaWriter]] class is used to write data from a batch query + * or structured streaming query, given by a [[QueryExecution]], to Kafka. + * The data is assumed to have a value column, and an optional topic and key + * columns. If the topic column is missing, then the topic must come from + * the 'topic' configuration option. If the key column is missing, then a + * null valued key field will be added to the + * [[org.apache.kafka.clients.producer.ProducerRecord]]. + */ +private[kafka010] object KafkaWriter extends Logging { + val TOPIC_ATTRIBUTE_NAME: String = "topic" + val KEY_ATTRIBUTE_NAME: String = "key" + val VALUE_ATTRIBUTE_NAME: String = "value" + + override def toString: String = "KafkaWriter" + + def validateQuery( + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output + schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( + if (topic == None) { + throw new AnalysisException(s"topic option required when no " + + s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") + } else { + Literal(topic.get, StringType) + } + ).dataType match { + case StringType => // good + case _ => + throw new AnalysisException(s"Topic type must be a String") + } + schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( + Literal(null, StringType) + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( + throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + } + + def write( + sparkSession: SparkSession, + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output + validateQuery(queryExecution, kafkaParameters, topic) + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) + } + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala new file mode 100644 index 0000000000000..7aa7dd096c07b --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.PrivateMethodTester + +import org.apache.spark.sql.test.SharedSQLContext + +class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala index 68bc3e3e2e9a8..91893df4ec32f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.kafka010 +import java.util.Locale import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.common.TopicPartition @@ -195,7 +196,7 @@ class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLCon reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala new file mode 100644 index 0000000000000..2ab336c7ac476 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, DataType} + +class KafkaSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("batch - write to kafka") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + checkAnswer( + createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + } + + test("batch - null topic field value, and no topic option") { + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + val ex = intercept[SparkException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .save() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "null topic present in the data")) + } + + test("batch - unsupported save modes") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + + // Test bad save mode Ignore + var ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Ignore) + .save() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + s"save mode ignore not allowed for kafka")) + + // Test bad save mode Overwrite + ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Overwrite) + .save() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + s"save mode overwrite not allowed for kafka")) + } + + test("SPARK-20496: batch - enforce analyzed plans") { + val inputEvents = + spark.range(1, 1000) + .select(to_json(struct("*")) as 'value) + + val topic = newTopic() + testUtils.createTopic(topic) + // used to throw UnresolvedException + inputEvents.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + } + + test("streaming - write to kafka with topic field") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( 
+ withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + input.addData("1", "2", "3", "4", "5") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + input.addData("6", "7", "8", "9", "10") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write aggregation w/o topic field, with topic option") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + test("streaming - aggregation with topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. 
Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "'foo' as topic", + "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + + test("streaming - write data with bad schema") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key 
attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val input = MemoryStream[String] + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val input = MemoryStream[String] + var writer: StreamingQuery = null + var ex: Exception = null + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, Object] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaWriteTask(options, inputSchema, Some(topic)) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + writeTask.execute(iter) + } finally { + writeTask.close() + } + } + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + withTempDir { checkpointDir => + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + 
.option("kafka.bootstrap.servers", testUtils.brokerAddress) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + } + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 1b35c2243a691..dcee3c15b91f6 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger @@ -37,7 +37,9 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -204,7 +206,7 @@ class KafkaSourceSuite extends KafkaSourceTest { override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { out.write(0) val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v0\n${metadata.json}") + writer.write(s"v99999\n${metadata.json}") writer.flush } } @@ -226,7 +228,12 @@ class KafkaSourceSuite extends KafkaSourceTest { source.getOffset.get // Read initial offset } - assert(e.getMessage.contains("Please upgrade your Spark")) + Seq( + s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } } } @@ -295,8 +302,6 @@ class KafkaSourceSuite extends KafkaSourceTest { StopStream, StartStream(ProcessingTime(100), clock), waitUntilBatchProcessed, - AdvanceManualClock(100), - waitUntilBatchProcessed, // smallest now empty, 1 more from middle, 9 more from biggest CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, @@ -486,7 +491,7 @@ class KafkaSourceSuite extends KafkaSourceTest { reader.load() } expectedMsgs.foreach { m => - assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } @@ -519,7 +524,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .option(s"$key", value) reader.load() } - assert(ex.getMessage.toLowerCase.contains("not supported")) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("not supported")) } testUnsupportedConfig("kafka.group.id") @@ -606,6 +611,24 @@ class KafkaSourceSuite extends KafkaSourceTest { assert(query.exception.isEmpty) } + test("get offsets from case insensitive parameters") { + for ((optionKey, optionValue, answer) <- Seq( + (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), + (ENDING_OFFSETS_OPTION_KEY, "laTest", LatestOffsetRangeLimit), + 
(STARTING_OFFSETS_OPTION_KEY, """{"topic-A":{"0":23}}""", + SpecificOffsetRangeLimit(Map(new TopicPartition("topic-A", 0) -> 23))))) { + val offset = getKafkaOffsetRangeLimit(Map(optionKey -> optionValue), optionKey, answer) + assert(offset === answer) + } + + for ((optionKey, answer) <- Seq( + (STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit), + (ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit))) { + val offset = getKafkaOffsetRangeLimit(Map.empty, optionKey, answer) + assert(offset === answer) + } + } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" private def assignString(topic: String, partitions: Iterable[Int]): String = { diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 88499240cd569..6c98cb04fcfa6 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala index 778c06ea16a2b..d2100fc5a4aba 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -17,7 +17,8 @@ package org.apache.spark.streaming.kafka010 -import java.{ lang => jl, util => ju } +import java.{lang => jl, util => ju} +import java.util.Locale import scala.collection.JavaConverters._ @@ -93,7 +94,8 @@ private case class Subscribe[K, V]( // but cant seek to a position before poll, because poll is what gets subscription partitions // So, poll, suppress the first exception, then seek val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { @@ -145,7 +147,8 @@ private case class SubscribePattern[K, V]( if (!toSeek.isEmpty) { // work around KAFKA-3370 when reset is none, see explanation in Subscribe above val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) - val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + val shouldSuppress = + aor != null && aor.asInstanceOf[String].toUpperCase(Locale.ROOT) == "NONE" try { consumer.poll(0) } catch { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 4c6e2ce87e295..62cdf5b1134e4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V]( val consumer = if (useConsumerCache) { CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber > 1) { + if (context.attemptNumber >= 1) { // just in case the prior attempt failures were cache related CachedKafkaConsumer.remove(groupId, part.topic, part.partition) } diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 3fedd9eda1959..f9c2dcb38dc0e 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 
+21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 8368a1f12218d..849c8b465f99e 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d5aef8184fc87..78230725f322e 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream import java.lang.{Integer => JInt, Long => JLong, Number => JNumber} import java.nio.charset.StandardCharsets -import java.util.{List => JList, Map => JMap, Set => JSet} +import java.util.{List => JList, Locale, Map => JMap, Set => JSet} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -206,7 +206,7 @@ object KafkaUtils { kafkaParams: Map[String, String], topics: Set[String] ): Map[TopicAndPartition, Long] = { - val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) + val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase(Locale.ROOT)) val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 90bb0e4987c82..48783d65826aa 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index daa79e79163b9..40a751a652fa9 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 23c4d99e50f51..f31ebf1ec8da0 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -36,7 +36,11 @@ import org.apache.spark.util.NextIterator /** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. 
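 * [Editor's note, not part of the original patch] The recordCount field added below carries a
 * record count for the range, which the iterator uses to cap each GetRecords request at
 * min(recordCount, maxGetRecordsLimit), i.e. the 10k-records-per-call AWS limit noted further
 * down in this file.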
*/ private[kinesis] case class SequenceNumberRange( - streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + streamName: String, + shardId: String, + fromSeqNumber: String, + toSeqNumber: String, + recordCount: Int) /** Class representing an array of Kinesis sequence number ranges */ private[kinesis] @@ -78,8 +82,8 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, - val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider + val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, + val kinesisCreds: SparkAWSCredentials = DefaultCredentials ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -105,7 +109,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = kinesisCredsProvider.provider.getCredentials + val credentials = kinesisCreds.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) @@ -136,6 +140,8 @@ class KinesisSequenceRangeIterator( private val client = new AmazonKinesisClient(credentials) private val streamName = range.streamName private val shardId = range.shardId + // AWS limits to maximum of 10k records per get call + private val maxGetRecordsLimit = 10000 private var toSeqNumberReceived = false private var lastSeqNumber: String = null @@ -153,12 +159,14 @@ class KinesisSequenceRangeIterator( // If the internal iterator has not been initialized, // then fetch records from starting sequence number - internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber, + range.recordCount) } else if (!internalIterator.hasNext) { // If the internal iterator does not have any more records, // then fetch more records after the last consumed sequence number - internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber, + range.recordCount) } if (!internalIterator.hasNext) { @@ -191,9 +199,12 @@ class KinesisSequenceRangeIterator( /** * Get records starting from or after the given sequence number. */ - private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + private def getRecords( + iteratorType: ShardIteratorType, + seqNum: String, + recordCount: Int): Iterator[Record] = { val shardIterator = getKinesisIterator(iteratorType, seqNum) - val result = getRecordsAndNextKinesisIterator(shardIterator) + val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount) result._1 } @@ -202,10 +213,12 @@ class KinesisSequenceRangeIterator( * to get records from Kinesis), and get the next shard iterator for next consumption. 
*/ private def getRecordsAndNextKinesisIterator( - shardIterator: String): (Iterator[Record], String) = { + shardIterator: String, + recordCount: Int): (Iterator[Record], String) = { val getRecordsRequest = new GetRecordsRequest getRecordsRequest.setRequestCredentials(credentials) getRecordsRequest.setShardIterator(shardIterator) + getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit)) val getRecordsResult = retryOrTimeout[GetRecordsResult]( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index fbc6b99443ed7..77553412eda56 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -22,24 +22,28 @@ import scala.reflect.ClassTag import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.api.java.JavaStreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler.ReceivedBlockInfo private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, - checkpointInterval: Duration, - storageLevel: StorageLevel, - messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider + val streamName: String, + val endpointUrl: String, + val regionName: String, + val initialPositionInStream: InitialPositionInStream, + val checkpointAppName: String, + val checkpointInterval: Duration, + val _storageLevel: StorageLevel, + val messageHandler: Record => T, + val kinesisCreds: SparkAWSCredentials, + val dynamoDBCreds: Option[SparkAWSCredentials], + val cloudWatchCreds: Option[SparkAWSCredentials] ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +65,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - kinesisCredsProvider = kinesisCredsProvider) + kinesisCreds = kinesisCreds) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,7 +75,238 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, - kinesisCredsProvider) + checkpointAppName, checkpointInterval, _storageLevel, messageHandler, + kinesisCreds, dynamoDBCreds, cloudWatchCreds) } } + +@InterfaceStability.Evolving +object KinesisInputDStream { + /** + * Builder for [[KinesisInputDStream]] instances. 
+ * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + // Required params + private var streamingContext: Option[StreamingContext] = None + private var streamName: Option[String] = None + private var checkpointAppName: Option[String] = None + + // Params with defaults + private var endpointUrl: Option[String] = None + private var regionName: Option[String] = None + private var initialPositionInStream: Option[InitialPositionInStream] = None + private var checkpointInterval: Option[Duration] = None + private var storageLevel: Option[StorageLevel] = None + private var kinesisCredsProvider: Option[SparkAWSCredentials] = None + private var dynamoDBCredsProvider: Option[SparkAWSCredentials] = None + private var cloudWatchCredsProvider: Option[SparkAWSCredentials] = None + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param ssc [[StreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(ssc: StreamingContext): Builder = { + streamingContext = Option(ssc) + this + } + + /** + * Sets the StreamingContext that will be used to construct the Kinesis DStream. This is a + * required parameter. + * + * @param jssc [[JavaStreamingContext]] used to construct Kinesis DStreams + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamingContext(jssc: JavaStreamingContext): Builder = { + streamingContext = Option(jssc.ssc) + this + } + + /** + * Sets the name of the Kinesis stream that the DStream will read from. This is a required + * parameter. + * + * @param streamName Name of Kinesis stream that the DStream will read from + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def streamName(streamName: String): Builder = { + this.streamName = Option(streamName) + this + } + + /** + * Sets the KCL application name to use when checkpointing state to DynamoDB. This is a + * required parameter. + * + * @param appName Value to use for the KCL app name (used when creating the DynamoDB checkpoint + * table and when writing metrics to CloudWatch) + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointAppName(appName: String): Builder = { + checkpointAppName = Option(appName) + this + } + + /** + * Sets the AWS Kinesis endpoint URL. Defaults to "https://kinesis.us-east-1.amazonaws.com" if + * no custom value is specified + * + * @param url Kinesis endpoint URL to use + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def endpointUrl(url: String): Builder = { + endpointUrl = Option(url) + this + } + + /** + * Sets the AWS region to construct clients for. Defaults to "us-east-1" if no custom value + * is specified. + * + * @param regionName Name of AWS region to use (e.g. "us-west-2") + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def regionName(regionName: String): Builder = { + this.regionName = Option(regionName) + this + } + + /** + * Sets the initial position data is read from in the Kinesis stream. Defaults to + * [[InitialPositionInStream.LATEST]] if no custom value is specified. 
+ * + * @param initialPosition InitialPositionInStream value specifying the position in the + * Kinesis stream where Spark Streaming will start reading records + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def initialPositionInStream(initialPosition: InitialPositionInStream): Builder = { + initialPositionInStream = Option(initialPosition) + this + } + + /** + * Sets how often the KCL application state is checkpointed to DynamoDB. Defaults to the Spark + * Streaming batch interval if no custom value is specified. + * + * @param interval [[Duration]] specifying how often the KCL state should be checkpointed to + * DynamoDB. + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def checkpointInterval(interval: Duration): Builder = { + checkpointInterval = Option(interval) + this + } + + /** + * Sets the storage level of the blocks for the DStream created. Defaults to + * [[StorageLevel.MEMORY_AND_DISK_2]] if no custom value is specified. + * + * @param storageLevel [[StorageLevel]] to use for the DStream data blocks + * @return Reference to this [[KinesisInputDStream.Builder]] + */ + def storageLevel(storageLevel: StorageLevel): Builder = { + this.storageLevel = Option(storageLevel) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS Kinesis + * endpoint. Defaults to [[DefaultCredentials]] if no custom value is specified. + * + * @param credentials [[SparkAWSCredentials]] to use for Kinesis authentication + */ + def kinesisCredentials(credentials: SparkAWSCredentials): Builder = { + kinesisCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS DynamoDB + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for DynamoDB authentication + */ + def dynamoDBCredentials(credentials: SparkAWSCredentials): Builder = { + dynamoDBCredsProvider = Option(credentials) + this + } + + /** + * Sets the [[SparkAWSCredentials]] to use for authenticating to the AWS CloudWatch + * endpoint. Will use the same credentials used for AWS Kinesis if no custom value is set. + * + * @param credentials [[SparkAWSCredentials]] to use for CloudWatch authentication + */ + def cloudWatchCredentials(credentials: SparkAWSCredentials): Builder = { + cloudWatchCredsProvider = Option(credentials) + this + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided + * message handler.
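+ *
+ * A short sketch, assuming a `builder` whose required parameters (streamingContext, streamName,
+ * checkpointAppName) have already been set; the handler here just extracts each record's
+ * partition key:
+ * {{{
+ *   // `builder` is assumed to already carry the required parameters described above
+ *   val keyStream = builder.buildWithMessageHandler(record => record.getPartitionKey)
+ * }}}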
+ * + * @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]] + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def buildWithMessageHandler[T: ClassTag]( + handler: Record => T): KinesisInputDStream[T] = { + val ssc = getRequiredParam(streamingContext, "streamingContext") + new KinesisInputDStream( + ssc, + getRequiredParam(streamName, "streamName"), + endpointUrl.getOrElse(DEFAULT_KINESIS_ENDPOINT_URL), + regionName.getOrElse(DEFAULT_KINESIS_REGION_NAME), + initialPositionInStream.getOrElse(DEFAULT_INITIAL_POSITION_IN_STREAM), + getRequiredParam(checkpointAppName, "checkpointAppName"), + checkpointInterval.getOrElse(ssc.graph.batchDuration), + storageLevel.getOrElse(DEFAULT_STORAGE_LEVEL), + ssc.sc.clean(handler), + kinesisCredsProvider.getOrElse(DefaultCredentials), + dynamoDBCredsProvider, + cloudWatchCredsProvider) + } + + /** + * Create a new instance of [[KinesisInputDStream]] with configured parameters and using the + * default message handler, which returns [[Array[Byte]]]. + * + * @return Instance of [[KinesisInputDStream]] constructed with configured parameters + */ + def build(): KinesisInputDStream[Array[Byte]] = buildWithMessageHandler(defaultMessageHandler) + + private def getRequiredParam[T](param: Option[T], paramName: String): T = param.getOrElse { + throw new IllegalArgumentException(s"No value provided for required parameter $paramName") + } + } + + /** + * Creates a [[KinesisInputDStream.Builder]] for constructing [[KinesisInputDStream]] instances. + * + * @since 2.2.0 + * + * @return [[KinesisInputDStream.Builder]] instance + */ + def builder: Builder = new Builder + + private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + if (record == null) return null + val byteBuffer = record.getData() + val byteArray = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(byteArray) + byteArray + } + + private[kinesis] val DEFAULT_KINESIS_ENDPOINT_URL: String = + "https://kinesis.us-east-1.amazonaws.com" + private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1" + private[kinesis] val DEFAULT_INITIAL_POSITION_IN_STREAM: InitialPositionInStream = + InitialPositionInStream.LATEST + private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 13fc54e531dda..1026d0fcb59bd 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -70,9 +70,14 @@ import org.apache.spark.util.Utils * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to - * generate the AWSCredentialsProvider instance used for KCL - * authorization. + * @param kinesisCreds SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize Kinesis API calls. + * @param cloudWatchCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize CloudWatch API + * calls. 
Will use kinesisCreds if value is None. + * @param dynamoDBCreds Optional SparkAWSCredentials instance that will be used to generate the + * AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls. + * Will use kinesisCreds if value is None. */ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -83,7 +88,9 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - kinesisCredsProvider: SerializableCredentialsProvider) + kinesisCreds: SparkAWSCredentials, + dynamoDBCreds: Option[SparkAWSCredentials], + cloudWatchCreds: Option[SparkAWSCredentials]) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -140,10 +147,13 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) + val kinesisProvider = kinesisCreds.provider val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( checkpointAppName, streamName, - kinesisCredsProvider.provider, + kinesisProvider, + dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), + cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) @@ -210,7 +220,8 @@ private[kinesis] class KinesisReceiver[T]( if (records.size > 0) { val dataIterator = records.iterator().asScala.map(messageHandler) val metadata = SequenceNumberRange(streamName, shardId, - records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber(), + records.size()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 2d777982e760c..1298463bfba1e 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -58,6 +58,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -73,7 +74,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, DefaultCredentialsProvider) + cleanedHandler, DefaultCredentials, None, None) } } @@ -108,6 +109,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. 
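+ *
+ * A rough builder-API equivalent, sketched purely in terms of this method's own parameters:
+ * {{{
+ *   // all names below refer to this method's own parameters
+ *   KinesisInputDStream.builder
+ *     .streamingContext(ssc)
+ *     .streamName(streamName)
+ *     .checkpointAppName(kinesisAppName)
+ *     .endpointUrl(endpointUrl)
+ *     .regionName(regionName)
+ *     .initialPositionInStream(initialPositionInStream)
+ *     .checkpointInterval(checkpointInterval)
+ *     .storageLevel(storageLevel)
+ *     .kinesisCredentials(
+ *       SparkAWSCredentials.builder.basicCredentials(awsAccessKeyId, awsSecretKey).build())
+ *     .buildWithMessageHandler(messageHandler)
+ * }}}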
*/ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -123,12 +125,12 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -169,6 +171,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T: ClassTag]( ssc: StreamingContext, kinesisAppName: String, @@ -187,16 +190,16 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = STSCredentialsProvider( + val kinesisCredsProvider = STSCredentials( stsRoleArn = stsAssumeRoleArn, stsSessionName = stsSessionName, stsExternalId = Option(stsExternalId), - longLivedCredsProvider = BasicCredentialsProvider( + longLivedCreds = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey)) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, kinesisCredsProvider) + cleanedHandler, kinesisCredsProvider, None, None) } } @@ -227,6 +230,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -240,7 +244,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, DefaultCredentialsProvider) + KinesisInputDStream.defaultMessageHandler, DefaultCredentials, None, None) } } @@ -272,6 +276,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( ssc: StreamingContext, kinesisAppName: String, @@ -284,12 +289,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { - val kinesisCredsProvider = BasicCredentialsProvider( + val kinesisCredsProvider = BasicCredentials( awsAccessKeyId = awsAccessKeyId, awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, kinesisCredsProvider) + KinesisInputDStream.defaultMessageHandler, kinesisCredsProvider, None, None) } } @@ -323,6 +328,7 @@ object KinesisUtils { * on the workers. 
See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -372,6 +378,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -431,6 +438,7 @@ object KinesisUtils { * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream[T]( jssc: JavaStreamingContext, kinesisAppName: String, @@ -482,6 +490,7 @@ object KinesisUtils { * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain * gets the AWS credentials. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -493,7 +502,8 @@ object KinesisUtils { storageLevel: StorageLevel ): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, - initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_)) + initialPositionInStream, checkpointInterval, storageLevel, + KinesisInputDStream.defaultMessageHandler(_)) } /** @@ -524,6 +534,7 @@ object KinesisUtils { * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing * is enabled. Make sure that your checkpoint directory is secure. */ + @deprecated("Use KinesisInputDStream.builder instead", "2.2.0") def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -537,7 +548,7 @@ object KinesisUtils { awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = { createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel, - defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } private def validateRegion(regionName: String): String = { @@ -545,14 +556,6 @@ object KinesisUtils { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") } } - - private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { - if (record == null) return null - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - byteArray - } } /** @@ -597,7 +600,7 @@ private class KinesisUtilsPythonHelper { validateAwsCreds(awsAccessKeyId, awsSecretKey) KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + KinesisInputDStream.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) } else { validateAwsCreds(awsAccessKeyId, awsSecretKey) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala deleted file mode 100644 index aa6fe12edf74e..0000000000000 --- 
a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import scala.collection.JavaConverters._ - -import com.amazonaws.auth._ - -import org.apache.spark.internal.Logging - -/** - * Serializable interface providing a method executors can call to obtain an - * AWSCredentialsProvider instance for authenticating to AWS services. - */ -private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { - /** - * Return an AWSCredentialProvider instance that can be used by the Kinesis Client - * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). - */ - def provider: AWSCredentialsProvider -} - -/** Returns DefaultAWSCredentialsProviderChain for authentication. */ -private[kinesis] final case object DefaultCredentialsProvider - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain -} - -/** - * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using - * DefaultAWSCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain - * instance with the provided arguments (e.g. if they are null). - */ -private[kinesis] final case class BasicCredentialsProvider( - awsAccessKeyId: String, - awsSecretKey: String) extends SerializableCredentialsProvider with Logging { - - def provider: AWSCredentialsProvider = try { - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) - } catch { - case e: IllegalArgumentException => - logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + - "falling back to DefaultAWSCredentialsProviderChain.", e) - new DefaultAWSCredentialsProviderChain - } -} - -/** - * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM - * role in order to authenticate against resources in an external account. 
- */ -private[kinesis] final case class STSCredentialsProvider( - stsRoleArn: String, - stsSessionName: String, - stsExternalId: Option[String] = None, - longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) - extends SerializableCredentialsProvider { - - def provider: AWSCredentialsProvider = { - val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) - .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) - stsExternalId match { - case Some(stsExternalId) => - builder.withExternalId(stsExternalId) - .build() - case None => - builder.build() - } - } -} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala new file mode 100644 index 0000000000000..9facfe8ff2b0f --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SparkAWSCredentials extends Serializable { + /** + * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * instance with the provided arguments (e.g. if they are null). 
+ */ +private[kinesis] final case class BasicCredentials( + awsAccessKeyId: String, + awsSecretKey: String) extends SparkAWSCredentials with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentials( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCreds: SparkAWSCredentials = DefaultCredentials) + extends SparkAWSCredentials { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCreds.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} + +@InterfaceStability.Evolving +object SparkAWSCredentials { + /** + * Builder for [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + */ + @InterfaceStability.Evolving + class Builder { + private var basicCreds: Option[BasicCredentials] = None + private var stsCreds: Option[STSCredentials] = None + + // scalastyle:off + /** + * Use a basic AWS keypair for long-lived authorization. + * + * @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is + * enabled. Make sure that your checkpoint directory is secure. Prefer using the + * [[http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default default provider chain]] + * instead if possible. + * + * @param accessKeyId AWS access key ID + * @param secretKey AWS secret key + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + // scalastyle:on + def basicCredentials(accessKeyId: String, secretKey: String): Builder = { + basicCreds = Option(BasicCredentials( + awsAccessKeyId = accessKeyId, + awsSecretKey = secretKey)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). + * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String): Builder = { + stsCreds = Option(STSCredentials(stsRoleArn = roleArn, stsSessionName = sessionName)) + this + } + + /** + * Use STS to assume an IAM role for temporary session-based authentication. Will use configured + * long-lived credentials for authorizing to STS itself (either the default provider chain + * or a configured keypair). STS will validate the provided external ID with the one defined + * in the trust policy of the IAM role to be assumed (if one is present). 
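+ *
+ * A minimal sketch combining this method with a basic keypair (the key values are
+ * placeholders); the keypair becomes the long-lived credentials used to call STS:
+ * {{{
+ *   // "ACCESS_KEY_ID" and "SECRET_KEY" are placeholder values
+ *   val creds = SparkAWSCredentials.builder
+ *     .basicCredentials("ACCESS_KEY_ID", "SECRET_KEY")
+ *     .stsCredentials(roleArn, sessionName, externalId)
+ *     .build()
+ * }}}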
+ * + * @param roleArn ARN of IAM role to assume via STS + * @param sessionName Name to use for the STS session + * @param externalId External ID to validate against assumed IAM role's trust policy + * @return Reference to this [[SparkAWSCredentials.Builder]] + */ + def stsCredentials(roleArn: String, sessionName: String, externalId: String): Builder = { + stsCreds = Option(STSCredentials( + stsRoleArn = roleArn, + stsSessionName = sessionName, + stsExternalId = Option(externalId))) + this + } + + /** + * Returns the appropriate instance of [[SparkAWSCredentials]] given the configured + * parameters. + * + * - The long-lived credentials will either be [[DefaultCredentials]] or [[BasicCredentials]] + * if they were provided. + * + * - If STS credentials were provided, the configured long-lived credentials will be added to + * them and the result will be returned. + * + * - The long-lived credentials will be returned otherwise. + * + * @return [[SparkAWSCredentials]] to use for configured parameters + */ + def build(): SparkAWSCredentials = + stsCreds.map(_.copy(longLivedCreds = longLivedCreds)).getOrElse(longLivedCreds) + + private def longLivedCreds: SparkAWSCredentials = basicCreds.getOrElse(DefaultCredentials) + } + + /** + * Creates a [[SparkAWSCredentials.Builder]] for constructing + * [[SparkAWSCredentials]] instances. + * + * @since 2.2.0 + * + * @return [[SparkAWSCredentials.Builder]] instance + */ + def builder: Builder = new Builder +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java new file mode 100644 index 0000000000000..7205f6e27266c --- /dev/null +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Seconds; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; + +public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { + /** + * Basic test to ensure that the KinesisDStream.Builder interface is accessible from Java. 
+ */ + @Test + public void testJavaKinesisDStreamBuilder() { + String streamName = "a-very-nice-stream-name"; + String endpointUrl = "https://kinesis.us-west-2.amazonaws.com"; + String region = "us-west-2"; + InitialPositionInStream initialPosition = InitialPositionInStream.TRIM_HORIZON; + String appName = "a-very-nice-kinesis-app"; + Duration checkpointInterval = Seconds.apply(30); + StorageLevel storageLevel = StorageLevel.MEMORY_ONLY(); + + KinesisInputDStream kinesisDStream = KinesisInputDStream.builder() + .streamingContext(ssc) + .streamName(streamName) + .endpointUrl(endpointUrl) + .regionName(region) + .initialPositionInStream(initialPosition) + .checkpointAppName(appName) + .checkpointInterval(checkpointInterval) + .storageLevel(storageLevel) + .build(); + assert(kinesisDStream.streamName() == streamName); + assert(kinesisDStream.endpointUrl() == endpointUrl); + assert(kinesisDStream.regionName() == region); + assert(kinesisDStream.initialPositionInStream() == initialPosition); + assert(kinesisDStream.checkpointAppName() == appName); + assert(kinesisDStream.checkpointInterval() == checkpointInterval); + assert(kinesisDStream._storageLevel() == storageLevel); + ssc.stop(); + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 18a5a1509a33a..2c7b9c58e6fa6 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -51,7 +51,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => val seqNumRange = SequenceNumberRange( - testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last, seqNumbers.size) (shardId, seqNumRange) } allRanges = shardIdToRange.values.toSeq @@ -181,7 +181,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) // Create the necessary ranges to use in the RDD val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( - SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 1))) val realRanges = Array.tabulate(numPartitionsInKinesis) { i => val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) SequenceNumberRanges(Array(range)) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala new file mode 100644 index 0000000000000..1c130654f3f95 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.lang.IllegalArgumentException + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase} + +class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterEach + with MockitoSugar { + import KinesisInputDStream._ + + private val ssc = new StreamingContext(conf, batchDuration) + private val streamName = "a-very-nice-kinesis-stream-name" + private val checkpointAppName = "a-very-nice-kcl-app-name" + private def baseBuilder = KinesisInputDStream.builder + private def builder = baseBuilder.streamingContext(ssc) + .streamName(streamName) + .checkpointAppName(checkpointAppName) + + override def afterAll(): Unit = { + ssc.stop() + } + + test("should raise an exception if the StreamingContext is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamName(streamName).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the stream name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).checkpointAppName(checkpointAppName).build() + } + } + + test("should raise an exception if the checkpoint app name is missing") { + intercept[IllegalArgumentException] { + baseBuilder.streamingContext(ssc).streamName(streamName).build() + } + } + + test("should propagate required values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.context == ssc) + assert(dstream.streamName == streamName) + assert(dstream.checkpointAppName == checkpointAppName) + } + + test("should propagate default values to KinesisInputDStream") { + val dstream = builder.build() + assert(dstream.endpointUrl == DEFAULT_KINESIS_ENDPOINT_URL) + assert(dstream.regionName == DEFAULT_KINESIS_REGION_NAME) + assert(dstream.initialPositionInStream == DEFAULT_INITIAL_POSITION_IN_STREAM) + assert(dstream.checkpointInterval == batchDuration) + assert(dstream._storageLevel == DEFAULT_STORAGE_LEVEL) + assert(dstream.kinesisCreds == DefaultCredentials) + assert(dstream.dynamoDBCreds == None) + assert(dstream.cloudWatchCreds == None) + } + + test("should propagate custom non-auth values to KinesisInputDStream") { + val customEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + val customRegion = "us-west-2" + val customInitialPosition = InitialPositionInStream.TRIM_HORIZON + val customAppName = "a-very-nice-kinesis-app" + val customCheckpointInterval = Seconds(30) + val customStorageLevel = StorageLevel.MEMORY_ONLY + val customKinesisCreds = mock[SparkAWSCredentials] + val customDynamoDBCreds = mock[SparkAWSCredentials] + val customCloudWatchCreds = mock[SparkAWSCredentials] + + val dstream = builder + 
.endpointUrl(customEndpointUrl) + .regionName(customRegion) + .initialPositionInStream(customInitialPosition) + .checkpointAppName(customAppName) + .checkpointInterval(customCheckpointInterval) + .storageLevel(customStorageLevel) + .kinesisCredentials(customKinesisCreds) + .dynamoDBCredentials(customDynamoDBCreds) + .cloudWatchCredentials(customCloudWatchCreds) + .build() + assert(dstream.endpointUrl == customEndpointUrl) + assert(dstream.regionName == customRegion) + assert(dstream.initialPositionInStream == customInitialPosition) + assert(dstream.checkpointAppName == customAppName) + assert(dstream.checkpointInterval == customCheckpointInterval) + assert(dstream._storageLevel == customStorageLevel) + assert(dstream.kinesisCreds == customKinesisCreds) + assert(dstream.dynamoDBCreds == Option(customDynamoDBCreds)) + assert(dstream.cloudWatchCreds == Option(customCloudWatchCreds)) + } +} diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index deb411d73e588..3b14c8471e205 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -31,7 +31,6 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.mock.MockitoSugar import org.apache.spark.streaming.{Duration, TestSuiteBase} -import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -62,28 +61,6 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of credential provider classes") { - Utils.deserialize[BasicCredentialsProvider]( - Utils.serialize(BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y"))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId")))) - - Utils.deserialize[STSCredentialsProvider]( - Utils.serialize(STSCredentialsProvider( - stsRoleArn = "fakeArn", - stsSessionName = "fakeSessionName", - stsExternalId = Some("fakeExternalId"), - longLivedCredsProvider = BasicCredentialsProvider( - awsAccessKeyId = "x", - awsSecretKey = "y")))) - } - test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 387a96f26b305..341a6898cbbff 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -22,7 +22,6 @@ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} @@ -119,13 +118,13 @@ abstract class 
KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Generate block info data for testing val seqNumRanges1 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 67)) val blockId1 = StreamBlockId(kinesisStream.id, 123) val blockInfo1 = ReceivedBlockInfo( 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) val seqNumRanges2 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb", 89)) val blockId2 = StreamBlockId(kinesisStream.id, 345) val blockInfo2 = ReceivedBlockInfo( 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) @@ -138,7 +137,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + assert(kinesisRDD.kinesisCreds === BasicCredentials( awsAccessKeyId = dummyAWSAccessKey, awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) @@ -173,11 +172,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ testIfEnabled("basic operation") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -198,12 +201,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } testIfEnabled("custom message handling") { - val awsCredentials = KinesisTestUtils.getAWSCredentials() def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 - val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .buildWithMessageHandler(addFive(_)) stream shouldBe a [ReceiverInputDStream[_]] @@ -233,11 +241,15 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val localTestUtils = new KPLBasedKinesisTestUtils(1) localTestUtils.createStream() try { - val awsCredentials = 
KinesisTestUtils.getAWSCredentials() - val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName, - localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val stream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(localAppName) + .streamName(localTestUtils.streamName) + .endpointUrl(localTestUtils.endpointUrl) + .regionName(localTestUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() val collected = new mutable.HashSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => @@ -303,13 +315,17 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun ssc = new StreamingContext(sc, Milliseconds(1000)) ssc.checkpoint(checkpointDir) - val awsCredentials = KinesisTestUtils.getAWSCredentials() val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])] - val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, - testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, - awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) + val kinesisStream = KinesisInputDStream.builder.streamingContext(ssc) + .checkpointAppName(appName) + .streamName(testUtils.streamName) + .endpointUrl(testUtils.endpointUrl) + .regionName(testUtils.regionName) + .initialPositionInStream(InitialPositionInStream.LATEST) + .checkpointInterval(Seconds(10)) + .storageLevel(StorageLevel.MEMORY_ONLY) + .build() // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala new file mode 100644 index 0000000000000..f579c2c3a6799 --- /dev/null +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentialsBuilderSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.util.Utils + +class SparkAWSCredentialsBuilderSuite extends TestSuiteBase { + private def builder = SparkAWSCredentials.builder + + private val basicCreds = BasicCredentials( + awsAccessKeyId = "a-very-nice-access-key", + awsSecretKey = "a-very-nice-secret-key") + + private val stsCreds = STSCredentials( + stsRoleArn = "a-very-nice-role-arn", + stsSessionName = "a-very-nice-secret-key", + stsExternalId = Option("a-very-nice-external-id"), + longLivedCreds = basicCreds) + + test("should build DefaultCredentials when given no params") { + assert(builder.build() == DefaultCredentials) + } + + test("should build BasicCredentials") { + assertResult(basicCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + } + + test("should build STSCredentials") { + // No external ID, default long-lived creds + assertResult(stsCreds.copy(stsExternalId = None, longLivedCreds = DefaultCredentials)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .build() + } + // Default long-lived creds + assertResult(stsCreds.copy(longLivedCreds = DefaultCredentials)) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + // No external ID, basic keypair for long-lived creds + assertResult(stsCreds.copy(stsExternalId = None)) { + builder.stsCredentials(stsCreds.stsRoleArn, stsCreds.stsSessionName) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Basic keypair for long-lived creds + assertResult(stsCreds) { + builder.stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .build() + } + // Order shouldn't matter + assertResult(stsCreds) { + builder.basicCredentials(basicCreds.awsAccessKeyId, basicCreds.awsSecretKey) + .stsCredentials( + stsCreds.stsRoleArn, + stsCreds.stsSessionName, + stsCreds.stsExternalId.get) + .build() + } + } + + test("SparkAWSCredentials classes should be serializable") { + assertResult(basicCreds) { + Utils.deserialize[BasicCredentials](Utils.serialize(basicCreds)) + } + assertResult(stsCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsCreds)) + } + // Will also test if DefaultCredentials can be serialized + val stsDefaultCreds = stsCreds.copy(longLivedCreds = DefaultCredentials) + assertResult(stsDefaultCreds) { + Utils.deserialize[STSCredentials](Utils.serialize(stsDefaultCreds)) + } + } +} diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 7da27817ebafd..36d555066b181 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8df33660ea9d1..cb30e4a4af4bc 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b4a8350..755c6febc48e6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package 
org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). 
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 37b6e453592e5..fd7b7f7c1c487 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -162,7 +162,8 @@ object PageRank extends Logging { iteration += 1 } - rankGraph + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) } /** @@ -179,7 +180,8 @@ object PageRank extends Logging { * @param resetProb The random reset probability * @param sources The list of sources to compute personalized pagerank from * @return the graph with vertex attributes - * containing the pagerank relative to all starting nodes (as a sparse vector) and + * containing the pagerank relative to all starting nodes (as a sparse vector + * indexed by the position of nodes in the sources list) and * edge attributes the normalized edge weight */ def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], @@ -194,6 +196,8 @@ object PageRank extends Logging { // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector + require(sources.max <= Int.MaxValue.toLong, + s"This implementation currently only works for source vertex ids at most ${Int.MaxValue}") val zero = Vectors.sparse(sources.size, List()).asBreeze val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze @@ -222,18 +226,18 @@ object PageRank extends Logging { // Propagates the message along outbound edges // and adding start nodes back in with activation resetProb val rankUpdates = rankGraph.aggregateMessages[BV[Double]]( - ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr), - (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src) + ctx => ctx.sendToDst(ctx.srcAttr *:* ctx.attr), + (a : BV[Double], b : BV[Double]) => a +:+ b, TripletFields.Src) rankGraph = rankGraph.outerJoinVertices(rankUpdates) { (vid, oldRank, msgSumOpt) => - val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb) + val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) *:* (1.0 - resetProb) val resetActivations = if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) :* resetProb + sourcesInitMapBC.value(vid) *:* resetProb } else { zero } - popActivations :+ resetActivations + popActivations +:+ resetActivations }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -245,8 +249,10 @@ object PageRank extends Logging { i += 1 } + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + val rankSums = rankGraph.vertices.values.fold(zero)(_ +:+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr) + Vectors.fromBreeze(attr /:/ rankSums) } } @@ -307,7 +313,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex 
attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => - if (id == src) (1.0, Double.NegativeInfinity) else (0.0, 0.0) + if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0) } .cache() @@ -322,13 +328,12 @@ object PageRank extends Logging { def personalizedVertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { val (oldPR, lastDelta) = attr - var teleport = oldPR - val delta = if (src==id) resetProb else 0.0 - teleport = oldPR*delta - - val newPR = teleport + (1.0 - resetProb) * msgSum - val newDelta = if (lastDelta == Double.NegativeInfinity) newPR else newPR - oldPR - (newPR, newDelta) + val newPR = if (lastDelta == Double.NegativeInfinity) { + 1.0 + } else { + oldPR + (1.0 - resetProb) * msgSum + } + (newPR, newPR - oldPR) } def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { @@ -353,9 +358,23 @@ object PageRank extends Logging { vertexProgram(id, attr, msgSum) } - Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( + val rankGraph = Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) - } // end of deltaPageRank + // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks + normalizeRankSum(rankGraph, personalized) + } + + // Normalizes the sum of ranks to n (or 1 if personalized) + private def normalizeRankSum(rankGraph: Graph[Double, Double], personalized: Boolean) = { + val rankSum = rankGraph.vertices.values.sum() + if (personalized) { + rankGraph.mapVertices((id, rank) => rank / rankSum) + } else { + val numVertices = rankGraph.numVertices + val correctionFactor = numVertices.toDouble / rankSum + rankGraph.mapVertices((id, rank) => rank * correctionFactor) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 80074897567eb..fda501aa757d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( +private[spark] class PeriodicGraphCheckpointer[VD, ED]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. 
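The normalizeRankSum helper above exists because sink vertices (no outgoing edges) leak rank mass, so the final ranks no longer sum to the number of vertices, or to 1.0 in the personalized case. A standalone sketch of the same correction, independent of the GraphX code path:

    // Rescale ranks so they sum to n (standard PageRank) or to 1.0 (personalized).
    def normalizeRanks(ranks: Seq[Double], personalized: Boolean): Seq[Double] = {
      val rankSum = ranks.sum
      val target = if (personalized) 1.0 else ranks.size.toDouble
      ranks.map(_ * target / rankSum)
    }

    // e.g. normalizeRanks(Seq(0.9, 0.6, 0.5, 0.4), personalized = false) sums to 4.0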
+ */ + data.vertices.cache() } if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() + data.edges.cache() } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index d2ad9be555770..66c4747fec268 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkConf import org.apache.spark.SparkContext /** - * Provides a method to run tests against a {@link SparkContext} variable that is correctly stopped + * Provides a method to run tests against a `SparkContext` variable that is correctly stopped * after each test. */ trait LocalSparkContext { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 6afbb5a959894..9779553ce85d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -50,7 +50,8 @@ object GridPageRank { inNbrs(ind).map( nbr => oldPr(nbr) / outDegree(nbr)).sum } } - (0L until (nRows * nCols)).zip(pr) + val prSum = pr.sum + (0L until (nRows * nCols)).zip(pr.map(_ * pr.length / prSum)) } } @@ -68,26 +69,34 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPageRank(numIter = 2, resetProb).vertices - val staticRanks2 = starGraph.staticPageRank(numIter = 3, resetProb).vertices.cache() + val staticRanks = starGraph.staticPageRank(numIter, resetProb).vertices.cache() + val staticRanks2 = starGraph.staticPageRank(numIter + 1, resetProb).vertices - // Static PageRank should only take 3 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 }.map { case (vid, test) => test }.sum() assert(notMatching === 0) - val staticErrors = staticRanks2.map { case (vid, pr) => - val p = math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) - val correct = (vid > 0 && pr == resetProb) || (vid == 0L && p < 1.0E-5) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) + val dynamicRanks = starGraph.pageRank(tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in")) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)])) + // We multiply by the number of vertices to account for difference in normalization + val centerRank = 0.462394787 * nVertices + val othersRank = 0.005430356 * nVertices + val igraphPR = centerRank +: Seq.fill(nVertices - 1)(othersRank) + val ranks = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) - val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ 
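On the persist-to-cache switch in PeriodicGraphCheckpointer above: GraphX remembers the storage level requested when the graph is built, cache() reuses it, while a bare persist() falls back to MEMORY_ONLY. A sketch, assuming an active SparkContext named sc:

    import org.apache.spark.graphx.{Edge, Graph}
    import org.apache.spark.storage.StorageLevel

    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1.0)))
    val g = Graph.fromEdges(edges, defaultValue = 0.0,
      edgeStorageLevel = StorageLevel.MEMORY_ONLY_SER,
      vertexStorageLevel = StorageLevel.MEMORY_ONLY_SER)
    g.cache()                            // keeps the requested MEMORY_ONLY_SER level
    println(g.vertices.getStorageLevel)  // MEMORY_ONLY_SER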
-96,51 +105,62 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val nVertices = 100 val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val resetProb = 0.15 + val tol = 0.0001 + val numIter = 2 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) - .vertices.cache() - - // Static PageRank should only take 2 iterations to converge - val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => - if (pr1 != pr2) 1 else 0 - }.map { case (vid, test) => test }.sum - assert(notMatching === 0) + val staticRanks = starGraph.staticPersonalizedPageRank(0, numIter, resetProb).vertices.cache() - val staticErrors = staticRanks2.map { case (vid, pr) => - val correct = (vid > 0 && pr == 0.0) || - (vid == 0 && pr == resetProb) - if (!correct) 1 else 0 - } - assert(staticErrors.sum === 0) - - val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() - assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + val dynamicRanks = starGraph.personalizedPageRank(0, tol, resetProb).vertices.cache() + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) - val parallelStaticRanks1 = starGraph - .staticParallelPersonalizedPageRank(Array(0), 1, resetProb).mapVertices { + val parallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(0) }.vertices.cache() - assert(compareRanks(staticRanks1, parallelStaticRanks1) < errorTol) + assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), personalized = c(1, rep(0, 99)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. 
+ // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 0 else 0) for x in range(0,100)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR0 = 1.0 +: Seq.fill(nVertices - 1)(0.0) + val ranks0 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR0)) + assert(compareRanks(staticRanks, ranks0) < errorTol) + assert(compareRanks(dynamicRanks, ranks0) < errorTol) - val parallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { - case (vertexId, vector) => vector(0) - }.vertices.cache() - assert(compareRanks(staticRanks2, parallelStaticRanks2) < errorTol) // We have one outbound edge from 1 to 0 - val otherStaticRanks2 = starGraph.staticPersonalizedPageRank(1, numIter = 2, resetProb) + val otherStaticRanks = starGraph.staticPersonalizedPageRank(1, numIter, resetProb) .vertices.cache() - val otherDynamicRanks = starGraph.personalizedPageRank(1, 0, resetProb).vertices.cache() - val otherParallelStaticRanks2 = starGraph - .staticParallelPersonalizedPageRank(Array(0, 1), 2, resetProb).mapVertices { + val otherDynamicRanks = starGraph.personalizedPageRank(1, tol, resetProb).vertices.cache() + val otherParallelStaticRanks = starGraph + .staticParallelPersonalizedPageRank(Array(0, 1), numIter, resetProb).mapVertices { case (vertexId, vector) => vector(1) }.vertices.cache() - assert(compareRanks(otherDynamicRanks, otherStaticRanks2) < errorTol) - assert(compareRanks(otherStaticRanks2, otherParallelStaticRanks2) < errorTol) - assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks2) < errorTol) + assert(compareRanks(otherDynamicRanks, otherStaticRanks) < errorTol) + assert(compareRanks(otherStaticRanks, otherParallelStaticRanks) < errorTol) + assert(compareRanks(otherDynamicRanks, otherParallelStaticRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(make_star(100, mode = "in"), + // personalized = c(0, 1, rep(0, 98)), algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. 
+ // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(x, 0) for x in range(1,100)]), + // personalization=dict([(x, 1 if x == 1 else 0) for x in range(0,100)])) + val centerRank = 0.4594595 + val sourceRank = 0.5405405 + val igraphPR1 = centerRank +: sourceRank +: Seq.fill(nVertices - 2)(0.0) + val ranks1 = VertexRDD(sc.parallelize(0L until nVertices zip igraphPR1)) + assert(compareRanks(otherStaticRanks, ranks1) < errorTol) + assert(compareRanks(otherDynamicRanks, ranks1) < errorTol) + assert(compareRanks(otherParallelStaticRanks, ranks1) < errorTol) } } // end of test Star PersonalPageRank @@ -229,4 +249,50 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { } } + + test("Loop with sink PageRank") { + withSpark { sc => + val edges = sc.parallelize((1L, 2L) :: (2L, 3L) :: (3L, 1L) :: (1L, 4L) :: Nil) + val g = Graph.fromEdgeTuples(edges, 1) + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 20 + val errorTol = 1.0e-5 + + val staticRanks = g.staticPageRank(numIter, resetProb).vertices.cache() + val dynamicRanks = g.pageRank(tol, resetProb).vertices.cache() + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D)) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR = Seq(0.3078534, 0.2137622, 0.2646223, 0.2137622).map(_ * 4) + val ranks = VertexRDD(sc.parallelize(1L to 4L zip igraphPR)) + assert(compareRanks(staticRanks, ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) + + val p1staticRanks = g.staticPersonalizedPageRank(1, numIter, resetProb).vertices.cache() + val p1dynamicRanks = g.personalizedPageRank(1, tol, resetProb).vertices.cache() + val p1parallelDynamicRanks = + g.staticParallelPersonalizedPageRank(Array(1, 2, 3, 4), numIter, resetProb) + .vertices.mapValues(v => v(0)).cache() + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ A -+ D), personalized = c(1, 0, 0, 0), + // algo = "arpack") + // NOTE: We use the arpack algorithm as prpack (the default) redistributes rank to all + // vertices uniformly instead of just to the personalization source. + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,1),(1,4)]), personalization={1:1, 2:0, 3:0, 4:0}) + val igraphPR2 = Seq(0.4522329, 0.1921990, 0.1633691, 0.1921990) + val ranks2 = VertexRDD(sc.parallelize(1L to 4L zip igraphPR2)) + assert(compareRanks(p1staticRanks, ranks2) < errorTol) + assert(compareRanks(p1dynamicRanks, ranks2) < errorTol) + assert(compareRanks(p1parallelDynamicRanks, ranks2) < errorTol) + + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala similarity index 70% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala index a13e7f63a9296..e0c65e6940f66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -15,77 +15,81 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { import PeriodicGraphCheckpointerSuite._ test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } } } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } + 
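The suite above now borrows its context from LocalSparkContext.withSpark instead of the shared MLlibTestSparkContext fixture. The pattern is a simple loan of a context that is always stopped afterwards; an illustrative sketch (the real trait does a bit more setup):

    import org.apache.spark.{SparkConf, SparkContext}

    // Loan pattern: hand a fresh context to the body and always stop it afterwards.
    def withSpark[T](f: SparkContext => T): T = {
      val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test"))
      try f(sc) finally sc.stop()
    }

    withSpark { sc =>
      assert(sc.parallelize(1 to 3).count() == 3)
    }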
checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(tempDir) + } } } private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) @@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite { Edge[Double](3, 4, 0)) def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) } def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { @@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite { assert(graph.vertices.getStorageLevel == StorageLevel.NONE) assert(graph.edges.getStorageLevel == StorageLevel.NONE) } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) } } catch { case _: AssertionError => diff --git a/launcher/pom.xml b/launcher/pom.xml index 025cd84f20f0e..e9b46c4cf0ffa 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index bc8d6037a367b..6c0c3ebcaebf4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -135,7 +135,6 @@ List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); Set cp = new LinkedHashSet<>(); - addToClassPath(cp, getenv("SPARK_CLASSPATH")); addToClassPath(cp, appClassPath); addToClassPath(cp, getConfDir()); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 81786841de224..7cf5b7379503f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -66,7 +66,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; case "org.apache.spark.executor.CoarseGrainedExecutorBackend": - javaOptsKeys.add("SPARK_JAVA_OPTS"); javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; break; @@ -84,7 +83,6 @@ public List buildCommand(Map env) memKey = "SPARK_DAEMON_MEMORY"; break; default: - javaOptsKeys.add("SPARK_JAVA_OPTS"); memKey = "SPARK_DRIVER_MEMORY"; break; } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5e64fa7ed152c..5f2da036ff9f7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -240,7 +240,6 @@ private List buildSparkSubmitCommand(Map env) addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS")); } addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); - 
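The launcher changes above, together with the spark-submit builder change just below, drop the legacy SPARK_CLASSPATH and SPARK_JAVA_OPTS environment variables. The configuration-based equivalents remain; a sketch with illustrative values:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.driver.extraJavaOptions", "-XX:+UseG1GC")       // driver-side JVM flags
      .set("spark.executor.extraJavaOptions", "-XX:+UseG1GC")     // executor-side JVM flags
      .set("spark.driver.extraClassPath", "/opt/extra/libs.jar")  // replaces SPARK_CLASSPATH
      .set("spark.executor.extraClassPath", "/opt/extra/libs.jar")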
addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); // We don't want the client to specify Xmx. These have to be set by their corresponding // memory flag --driver-memory or configuration entry spark.driver.memory diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 663f7fb0b010d..043d13609fd26 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index d9ffdeb797fb8..07f3bc27280bd 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -44,6 +44,12 @@ sealed trait Matrix extends Serializable { @Since("2.0.0") val isTransposed: Boolean = false + /** Indicates whether the values backing this matrix are arranged in column major order. */ + private[ml] def isColMajor: Boolean = !isTransposed + + /** Indicates whether the values backing this matrix are arranged in row major order. */ + private[ml] def isRowMajor: Boolean = isTransposed + /** Converts to a dense array in column major. */ @Since("2.0.0") def toArray: Array[Double] = { @@ -148,7 +154,8 @@ sealed trait Matrix extends Serializable { * and column indices respectively with the type `Int`, and the final parameter is the * corresponding value in the matrix with type `Double`. */ - private[spark] def foreachActive(f: (Int, Int, Double) => Unit) + @Since("2.2.0") + def foreachActive(f: (Int, Int, Double) => Unit): Unit /** * Find the number of non-zero active values. @@ -161,6 +168,116 @@ sealed trait Matrix extends Serializable { */ @Since("2.0.0") def numActives: Int + + /** + * Converts this matrix to a sparse matrix. + * + * @param colMajor Whether the values of the resulting sparse matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toSparseMatrix(colMajor: Boolean): SparseMatrix + + /** + * Converts this matrix to a sparse matrix in column major order. + */ + @Since("2.2.0") + def toSparseColMajor: SparseMatrix = toSparseMatrix(colMajor = true) + + /** + * Converts this matrix to a sparse matrix in row major order. + */ + @Since("2.2.0") + def toSparseRowMajor: SparseMatrix = toSparseMatrix(colMajor = false) + + /** + * Converts this matrix to a sparse matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toSparse: SparseMatrix = toSparseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix. + * + * @param colMajor Whether the values of the resulting dense matrix should be in column major + * or row major order. If `false`, resulting matrix will be row major. + */ + private[ml] def toDenseMatrix(colMajor: Boolean): DenseMatrix + + /** + * Converts this matrix to a dense matrix while maintaining the layout of the current matrix. + */ + @Since("2.2.0") + def toDense: DenseMatrix = toDenseMatrix(colMajor = isColMajor) + + /** + * Converts this matrix to a dense matrix in row major order. + */ + @Since("2.2.0") + def toDenseRowMajor: DenseMatrix = toDenseMatrix(colMajor = false) + + /** + * Converts this matrix to a dense matrix in column major order. + */ + @Since("2.2.0") + def toDenseColMajor: DenseMatrix = toDenseMatrix(colMajor = true) + + /** + * Returns a matrix in dense or sparse column major format, whichever uses less storage. 
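With foreachActive made public above (Since 2.2.0), callers can visit only the stored entries of a matrix. A small usage sketch:

    import org.apache.spark.ml.linalg.Matrices

    // 2x2 CSC matrix with entries (0,0) -> 3.0 and (1,1) -> 4.0
    val sm = Matrices.sparse(2, 2, Array(0, 1, 2), Array(0, 1), Array(3.0, 4.0))
    sm.foreachActive { (i, j, v) =>
      println(s"($i, $j) -> $v")  // prints only the two stored entries
    }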
+ */ + @Since("2.2.0") + def compressedColMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = true)) { + this.toDenseColMajor + } else { + this.toSparseColMajor + } + } + + /** + * Returns a matrix in dense or sparse row major format, whichever uses less storage. + */ + @Since("2.2.0") + def compressedRowMajor: Matrix = { + if (getDenseSizeInBytes <= getSparseSizeInBytes(colMajor = false)) { + this.toDenseRowMajor + } else { + this.toSparseRowMajor + } + } + + /** + * Returns a matrix in dense column major, dense row major, sparse row major, or sparse column + * major format, whichever uses less storage. When dense representation is optimal, it maintains + * the current layout order. + */ + @Since("2.2.0") + def compressed: Matrix = { + val cscSize = getSparseSizeInBytes(colMajor = true) + val csrSize = getSparseSizeInBytes(colMajor = false) + if (getDenseSizeInBytes <= math.min(cscSize, csrSize)) { + // dense matrix size is the same for column major and row major, so maintain current layout + this.toDense + } else if (cscSize <= csrSize) { + this.toSparseColMajor + } else { + this.toSparseRowMajor + } + } + + /** Gets the size of the dense representation of this `Matrix`. */ + private[ml] def getDenseSizeInBytes: Long = { + Matrices.getDenseSize(numCols, numRows) + } + + /** Gets the size of the minimal sparse representation of this `Matrix`. */ + private[ml] def getSparseSizeInBytes(colMajor: Boolean): Long = { + val nnz = numNonzeros + val numPtrs = if (colMajor) numCols + 1L else numRows + 1L + Matrices.getSparseSize(nnz, numPtrs) + } + + /** Gets the current size in bytes of this `Matrix`. Useful for testing */ + private[ml] def getSizeInBytes: Long } /** @@ -258,7 +375,7 @@ class DenseMatrix @Since("2.0.0") ( override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { // outer loop over columns var j = 0 @@ -291,31 +408,49 @@ class DenseMatrix @Since("2.0.0") ( override def numActives: Int = values.length /** - * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from the given `DenseMatrix`. + * + * @param colMajor Whether the resulting `SparseMatrix` values will be in column major order. 
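A short usage sketch of the compression helpers defined above: compressed picks whichever of dense, CSC, or CSR storage is smallest for the matrix at hand, while the ColMajor/RowMajor variants pin the layout.

    import org.apache.spark.ml.linalg.Matrices

    // 3x4 matrix with only two non-zeros, so sparse storage wins
    val m = Matrices.dense(3, 4, Array(1.0, 1.0) ++ Array.fill(10)(0.0))
    val auto  = m.compressed          // SparseMatrix, row major here since numRows < numCols
    val csc   = m.compressedColMajor  // force a column-major (CSC) sparse layout
    val dense = m.toDense             // keep the dense representation explicitly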
*/ - @Since("2.0.0") - def toSparse: SparseMatrix = { - val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble - val colPtrs: Array[Int] = new Array[Int](numCols + 1) - val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt - var nnz = 0 - var j = 0 - while (j < numCols) { - var i = 0 - while (i < numRows) { - val v = values(index(i, j)) - if (v != 0.0) { - rowIndices += i - spVals += v - nnz += 1 + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (!colMajor) this.transpose.toSparseColMajor.transpose + else { + val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble + val colPtrs: Array[Int] = new Array[Int](numCols + 1) + val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt + var nnz = 0 + var j = 0 + while (j < numCols) { + var i = 0 + while (i < numRows) { + val v = values(index(i, j)) + if (v != 0.0) { + rowIndices += i + spVals += v + nnz += 1 + } + i += 1 } - i += 1 + j += 1 + colPtrs(j) = nnz } - j += 1 - colPtrs(j) = nnz + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) + } + } + + /** + * Generate a `DenseMatrix` from this `DenseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values will be in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (isRowMajor && colMajor) { + new DenseMatrix(numRows, numCols, this.toArray, isTransposed = false) + } else if (isColMajor && !colMajor) { + new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } else { + this } - new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result()) } override def colIter: Iterator[Vector] = { @@ -331,6 +466,8 @@ class DenseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getDenseSize(numCols, numRows) } /** @@ -560,7 +697,7 @@ class SparseMatrix @Since("2.0.0") ( override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) - private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { if (!isTransposed) { var j = 0 while (j < numCols) { @@ -587,18 +724,67 @@ class SparseMatrix @Since("2.0.0") ( } } + override def numNonzeros: Int = values.count(_ != 0) + + override def numActives: Int = values.length + /** - * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed - * set to false. + * Generate a `SparseMatrix` from this `SparseMatrix`, removing explicit zero values if they + * exist. + * + * @param colMajor Whether or not the resulting `SparseMatrix` values are in column major + * order. 
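To make the conversions above concrete, here is the 2x3 example the updated test suite also uses; the same matrix yields differently ordered value arrays depending on the requested layout.

    import org.apache.spark.ml.linalg.DenseMatrix

    // Column-major 2x3 matrix:
    //   0.0 4.0 5.0
    //   0.0 2.0 0.0
    val dm = new DenseMatrix(2, 3, Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0))
    val csc  = dm.toSparseColMajor   // values stored column by column: 4.0, 2.0, 5.0
    val csr  = dm.toSparseRowMajor   // values stored row by row:       4.0, 5.0, 2.0
    val back = csr.toDenseColMajor   // values: 0.0, 0.0, 4.0, 2.0, 5.0, 0.0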
*/ - @Since("2.0.0") - def toDense: DenseMatrix = { - new DenseMatrix(numRows, numCols, toArray) + private[ml] override def toSparseMatrix(colMajor: Boolean): SparseMatrix = { + if (isColMajor && !colMajor) { + // it is col major and we want row major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]].t + Matrices.fromBreeze(breezeTransposed).transpose.asInstanceOf[SparseMatrix] + } else if (isRowMajor && colMajor) { + // it is row major and we want col major, use breeze to remove explicit zeros + val breezeTransposed = asBreeze.asInstanceOf[BSM[Double]] + Matrices.fromBreeze(breezeTransposed).asInstanceOf[SparseMatrix] + } else { + val nnz = numNonzeros + if (nnz != numActives) { + // remove explicit zeros + val rr = new Array[Int](nnz) + val vv = new Array[Double](nnz) + val numPtrs = if (isRowMajor) numRows else numCols + val cc = new Array[Int](numPtrs + 1) + var nzIdx = 0 + var j = 0 + while (j < numPtrs) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + cc(j) = nzIdx + while (idx < idxEnd) { + if (values(idx) != 0.0) { + vv(nzIdx) = values(idx) + rr(nzIdx) = rowIndices(idx) + nzIdx += 1 + } + idx += 1 + } + j += 1 + } + cc(j) = nnz + new SparseMatrix(numRows, numCols, cc, rr, vv, isTransposed = isTransposed) + } else { + this + } + } } - override def numNonzeros: Int = values.count(_ != 0) - - override def numActives: Int = values.length + /** + * Generate a `DenseMatrix` from the given `SparseMatrix`. + * + * @param colMajor Whether the resulting `DenseMatrix` values are in column major order. + */ + private[ml] override def toDenseMatrix(colMajor: Boolean): DenseMatrix = { + if (colMajor) new DenseMatrix(numRows, numCols, this.toArray) + else new DenseMatrix(numRows, numCols, this.transpose.toArray, isTransposed = true) + } override def colIter: Iterator[Vector] = { if (isTransposed) { @@ -631,6 +817,8 @@ class SparseMatrix @Since("2.0.0") ( } } } + + private[ml] def getSizeInBytes: Long = Matrices.getSparseSize(numActives, colPtrs.length) } /** @@ -1079,4 +1267,26 @@ object Matrices { SparseMatrix.fromCOO(numRows, numCols, entries) } } + + private[ml] def getSparseSize(numActives: Long, numPtrs: Long): Long = { + /* + Sparse matrices store two int arrays, one double array, two ints, and one boolean: + 8 * values.length + 4 * rowIndices.length + 4 * colPtrs.length + arrayHeader * 3 + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val intBytes = java.lang.Integer.BYTES + val arrayHeader = 12L + doubleBytes * numActives + intBytes * numActives + intBytes * numPtrs + arrayHeader * 3L + 9L + } + + private[ml] def getDenseSize(numCols: Long, numRows: Long): Long = { + /* + Dense matrices store one double array, two ints, and one boolean: + 8 * values.length + arrayHeader + 2 * 4 + 1 + */ + val doubleBytes = java.lang.Double.BYTES + val arrayHeader = 12L + doubleBytes * numCols * numRows + arrayHeader + 9L + } + } diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 9c0aa73938478..9f8202086817d 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -160,22 +160,416 @@ class MatricesSuite extends SparkMLFunSuite { assert(sparseMat.values(2) === 10.0) } - test("toSparse, toDense") { - val m = 3 - val n = 2 - val values = Array(1.0, 2.0, 4.0, 5.0) - val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 
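Worked numbers for the size estimates above, reproduced outside the class for illustration. For a 3x4 matrix with two non-zeros, dense storage costs 117 bytes, CSC 89 bytes, and CSR 85 bytes, which is why compressed picks sparse row major for that shape.

    // Same formulas as Matrices.getDenseSize / getSparseSize above.
    def denseSize(numRows: Long, numCols: Long): Long =
      8L * numRows * numCols + 12L + 9L                     // values array + header + 2 ints + boolean
    def sparseSize(numActives: Long, numPtrs: Long): Long =
      8L * numActives + 4L * numActives + 4L * numPtrs + 3L * 12L + 9L

    println(denseSize(3, 4))       // 117
    println(sparseSize(2, 4 + 1))  // CSC pointers use numCols + 1: 89
    println(sparseSize(2, 3 + 1))  // CSR pointers use numRows + 1: 85, the smallest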
5.0) - val colPtrs = Array(0, 2, 4) - val rowIndices = Array(0, 1, 1, 2) + test("dense to dense") { + /* + dm1 = 4.0 2.0 -8.0 + -1.0 7.0 4.0 + + dm2 = 5.0 -9.0 4.0 + 1.0 -3.0 -8.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(4.0, -1.0, 2.0, 7.0, -8.0, 4.0)) + val dm2 = new DenseMatrix(2, 3, Array(5.0, -9.0, 4.0, 1.0, -3.0, -8.0), isTransposed = true) + + val dm8 = dm1.toDenseColMajor + assert(dm8 === dm1) + assert(dm8.isColMajor) + assert(dm8.values.equals(dm1.values)) + + val dm5 = dm2.toDenseColMajor + assert(dm5 === dm2) + assert(dm5.isColMajor) + assert(dm5.values === Array(5.0, 1.0, -9.0, -3.0, 4.0, -8.0)) + + val dm4 = dm1.toDenseRowMajor + assert(dm4 === dm1) + assert(dm4.isRowMajor) + assert(dm4.values === Array(4.0, 2.0, -8.0, -1.0, 7.0, 4.0)) + + val dm6 = dm2.toDenseRowMajor + assert(dm6 === dm2) + assert(dm6.isRowMajor) + assert(dm6.values.equals(dm2.values)) + + val dm3 = dm1.toDense + assert(dm3 === dm1) + assert(dm3.isColMajor) + assert(dm3.values.equals(dm1.values)) + + val dm9 = dm2.toDense + assert(dm9 === dm2) + assert(dm9.isRowMajor) + assert(dm9.values.equals(dm2.values)) + } - val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) - val deMat1 = new DenseMatrix(m, n, allValues) + test("dense to sparse") { + /* + dm1 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + dm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 - val spMat2 = deMat1.toSparse - val deMat2 = spMat1.toDense + dm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val dm1 = new DenseMatrix(2, 3, Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + val dm2 = new DenseMatrix(2, 3, Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0), isTransposed = true) + val dm3 = new DenseMatrix(2, 3, Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm1 = dm1.toSparseColMajor + assert(sm1 === dm1) + assert(sm1.isColMajor) + assert(sm1.values === Array(4.0, 2.0, 5.0)) + + val sm3 = dm2.toSparseColMajor + assert(sm3 === dm2) + assert(sm3.isColMajor) + assert(sm3.values === Array(4.0, 2.0, 5.0)) + + val sm5 = dm3.toSparseColMajor + assert(sm5 === dm3) + assert(sm5.values === Array.empty[Double]) + assert(sm5.isColMajor) + + val sm2 = dm1.toSparseRowMajor + assert(sm2 === dm1) + assert(sm2.isRowMajor) + assert(sm2.values === Array(4.0, 5.0, 2.0)) + + val sm4 = dm2.toSparseRowMajor + assert(sm4 === dm2) + assert(sm4.isRowMajor) + assert(sm4.values === Array(4.0, 5.0, 2.0)) + + val sm6 = dm3.toSparseRowMajor + assert(sm6 === dm3) + assert(sm6.values === Array.empty[Double]) + assert(sm6.isRowMajor) + + val sm7 = dm1.toSparse + assert(sm7 === dm1) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + assert(sm7.isColMajor) + + val sm10 = dm2.toSparse + assert(sm10 === dm2) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + } + + test("sparse to sparse") { + /* + sm1 = sm2 = sm3 = sm4 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + smZeros = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 2, 4), Array(0, 1, 0, 1), + Array(4.0, 2.0, 5.0, 0.0)) + val sm4 = new SparseMatrix(2, 3, Array(0, 2, 4), Array(1, 2, 1, 2), + Array(4.0, 5.0, 2.0, 0.0), isTransposed = true) + val smZeros = new SparseMatrix(2, 3, Array(0, 2, 4, 6), Array(0, 1, 0, 1, 0, 1), + Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + + val sm6 = sm1.toSparseColMajor + assert(sm6 === sm1) + assert(sm6.isColMajor) + assert(sm6.values.equals(sm1.values)) + + val sm7 = sm2.toSparseColMajor + assert(sm7 === sm2) + 
assert(sm7.isColMajor) + assert(sm7.values === Array(4.0, 2.0, 5.0)) + + val sm16 = sm3.toSparseColMajor + assert(sm16 === sm3) + assert(sm16.isColMajor) + assert(sm16.values === Array(4.0, 2.0, 5.0)) + + val sm14 = sm4.toSparseColMajor + assert(sm14 === sm4) + assert(sm14.values === Array(4.0, 2.0, 5.0)) + assert(sm14.isColMajor) + + val sm15 = smZeros.toSparseColMajor + assert(sm15 === smZeros) + assert(sm15.values === Array.empty[Double]) + assert(sm15.isColMajor) + + val sm5 = sm1.toSparseRowMajor + assert(sm5 === sm1) + assert(sm5.isRowMajor) + assert(sm5.values === Array(4.0, 5.0, 2.0)) + + val sm8 = sm2.toSparseRowMajor + assert(sm8 === sm2) + assert(sm8.isRowMajor) + assert(sm8.values.equals(sm2.values)) + + val sm10 = sm3.toSparseRowMajor + assert(sm10 === sm3) + assert(sm10.values === Array(4.0, 5.0, 2.0)) + assert(sm10.isRowMajor) + + val sm11 = sm4.toSparseRowMajor + assert(sm11 === sm4) + assert(sm11.values === Array(4.0, 5.0, 2.0)) + assert(sm11.isRowMajor) + + val sm17 = smZeros.toSparseRowMajor + assert(sm17 === smZeros) + assert(sm17.values === Array.empty[Double]) + assert(sm17.isRowMajor) + + val sm9 = sm3.toSparse + assert(sm9 === sm3) + assert(sm9.values === Array(4.0, 2.0, 5.0)) + assert(sm9.isColMajor) + + val sm12 = sm4.toSparse + assert(sm12 === sm4) + assert(sm12.values === Array(4.0, 5.0, 2.0)) + assert(sm12.isRowMajor) + + val sm13 = smZeros.toSparse + assert(sm13 === smZeros) + assert(sm13.values === Array.empty[Double]) + assert(sm13.isColMajor) + } + + test("sparse to dense") { + /* + sm1 = sm2 = 0.0 4.0 5.0 + 0.0 2.0 0.0 + + sm3 = 0.0 0.0 0.0 + 0.0 0.0 0.0 + */ + val sm1 = new SparseMatrix(2, 3, Array(0, 0, 2, 3), Array(0, 1, 0), Array(4.0, 2.0, 5.0)) + val sm2 = new SparseMatrix(2, 3, Array(0, 2, 3), Array(1, 2, 1), Array(4.0, 5.0, 2.0), + isTransposed = true) + val sm3 = new SparseMatrix(2, 3, Array(0, 0, 0, 0), Array.empty[Int], Array.empty[Double]) + + val dm6 = sm1.toDenseColMajor + assert(dm6 === sm1) + assert(dm6.isColMajor) + assert(dm6.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm7 = sm2.toDenseColMajor + assert(dm7 === sm2) + assert(dm7.isColMajor) + assert(dm7.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm2 = sm1.toDenseRowMajor + assert(dm2 === sm1) + assert(dm2.isRowMajor) + assert(dm2.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm4 = sm2.toDenseRowMajor + assert(dm4 === sm2) + assert(dm4.isRowMajor) + assert(dm4.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm1 = sm1.toDense + assert(dm1 === sm1) + assert(dm1.isColMajor) + assert(dm1.values === Array(0.0, 0.0, 4.0, 2.0, 5.0, 0.0)) + + val dm3 = sm2.toDense + assert(dm3 === sm2) + assert(dm3.isRowMajor) + assert(dm3.values === Array(0.0, 4.0, 5.0, 0.0, 2.0, 0.0)) + + val dm5 = sm3.toDense + assert(dm5 === sm3) + assert(dm5.isColMajor) + assert(dm5.values === Array.fill(6)(0.0)) + } + + test("compressed dense") { + /* + dm1 = 1.0 0.0 0.0 0.0 + 1.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + + dm2 = 1.0 1.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a sparse matrix + val dm1 = new DenseMatrix(3, 4, Array.fill(2)(1.0) ++ Array.fill(10)(0.0)) + + // optimal compression layout is row major since numRows < numCols + val cm1 = dm1.compressed.asInstanceOf[SparseMatrix] + assert(cm1 === dm1) + assert(cm1.isRowMajor) + assert(cm1.getSizeInBytes < dm1.getSizeInBytes) + + // force compressed column major + val cm2 = dm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm2 === dm1) + assert(cm2.isColMajor) + assert(cm2.getSizeInBytes 
< dm1.getSizeInBytes) + + // optimal compression layout for transpose is column major + val dm2 = dm1.transpose + val cm3 = dm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === dm2) + assert(cm3.isColMajor) + assert(cm3.getSizeInBytes < dm2.getSizeInBytes) + + /* + dm3 = 1.0 1.0 1.0 0.0 + 1.0 1.0 0.0 0.0 + 1.0 1.0 0.0 0.0 + + dm4 = 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 0.0 + 0.0 0.0 0.0 0.0 + */ + // this should compress to a dense matrix + val dm3 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0)) + val dm4 = new DenseMatrix(3, 4, Array.fill(7)(1.0) ++ Array.fill(5)(0.0), isTransposed = true) + + val cm4 = dm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === dm3) + assert(cm4.isColMajor) + assert(cm4.values.equals(dm3.values)) + assert(cm4.getSizeInBytes === dm3.getSizeInBytes) + + // force compressed row major + val cm5 = dm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === dm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes === dm3.getSizeInBytes) + + val cm6 = dm4.compressed.asInstanceOf[DenseMatrix] + assert(cm6 === dm4) + assert(cm6.isRowMajor) + assert(cm6.values.equals(dm4.values)) + assert(cm6.getSizeInBytes === dm4.getSizeInBytes) + + val cm7 = dm4.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm7 === dm4) + assert(cm7.isColMajor) + assert(cm7.getSizeInBytes === dm4.getSizeInBytes) + + // this has the same size sparse or dense + val dm5 = new DenseMatrix(4, 4, Array.fill(7)(1.0) ++ Array.fill(9)(0.0)) + // should choose dense to break ties + val cm8 = dm5.compressed.asInstanceOf[DenseMatrix] + assert(cm8.getSizeInBytes === dm5.toSparseColMajor.getSizeInBytes) + } - assert(spMat1.asBreeze === spMat2.asBreeze) - assert(deMat1.asBreeze === deMat2.asBreeze) + test("compressed sparse") { + /* + sm1 = 0.0 -1.0 + 0.0 0.0 + 0.0 0.0 + 0.0 0.0 + + sm2 = 0.0 0.0 0.0 0.0 + -1.0 0.0 0.0 0.0 + */ + // these should compress to sparse matrices + val sm1 = new SparseMatrix(4, 2, Array(0, 0, 1), Array(0), Array(-1.0)) + val sm2 = sm1.transpose + + val cm1 = sm1.compressed.asInstanceOf[SparseMatrix] + // optimal is column major + assert(cm1 === sm1) + assert(cm1.isColMajor) + assert(cm1.values.equals(sm1.values)) + assert(cm1.getSizeInBytes === sm1.getSizeInBytes) + + val cm2 = sm1.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm2 === sm1) + assert(cm2.isRowMajor) + // forced to be row major, so we have increased the size + assert(cm2.getSizeInBytes > sm1.getSizeInBytes) + assert(cm2.getSizeInBytes < sm1.toDense.getSizeInBytes) + + val cm9 = sm1.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm9 === sm1) + assert(cm9.values.equals(sm1.values)) + assert(cm9.getSizeInBytes === sm1.getSizeInBytes) + + val cm3 = sm2.compressed.asInstanceOf[SparseMatrix] + assert(cm3 === sm2) + assert(cm3.isRowMajor) + assert(cm3.values.equals(sm2.values)) + assert(cm3.getSizeInBytes === sm2.getSizeInBytes) + + val cm8 = sm2.compressedColMajor.asInstanceOf[SparseMatrix] + assert(cm8 === sm2) + assert(cm8.isColMajor) + // forced to be col major, so we have increased the size + assert(cm8.getSizeInBytes > sm2.getSizeInBytes) + assert(cm8.getSizeInBytes < sm2.toDense.getSizeInBytes) + + val cm10 = sm2.compressedRowMajor.asInstanceOf[SparseMatrix] + assert(cm10 === sm2) + assert(cm10.isRowMajor) + assert(cm10.values.equals(sm2.values)) + assert(cm10.getSizeInBytes === sm2.getSizeInBytes) + + + /* + sm3 = 0.0 -1.0 + 2.0 3.0 + -4.0 9.0 + */ + // this should compress to a dense matrix + val sm3 = new SparseMatrix(3, 2, Array(0, 2, 5), Array(1, 2, 0, 1, 2), + 
Array(2.0, -4.0, -1.0, 3.0, 9.0)) + + // dense is optimal, and maintains column major + val cm4 = sm3.compressed.asInstanceOf[DenseMatrix] + assert(cm4 === sm3) + assert(cm4.isColMajor) + assert(cm4.getSizeInBytes < sm3.getSizeInBytes) + + val cm5 = sm3.compressedRowMajor.asInstanceOf[DenseMatrix] + assert(cm5 === sm3) + assert(cm5.isRowMajor) + assert(cm5.getSizeInBytes < sm3.getSizeInBytes) + + val cm11 = sm3.compressedColMajor.asInstanceOf[DenseMatrix] + assert(cm11 === sm3) + assert(cm11.isColMajor) + assert(cm11.getSizeInBytes < sm3.getSizeInBytes) + + /* + sm4 = 1.0 0.0 0.0 ... + + sm5 = 1.0 + 0.0 + 0.0 + ... + */ + val sm4 = new SparseMatrix(Int.MaxValue, 1, Array(0, 1), Array(0), Array(1.0)) + val cm6 = sm4.compressed.asInstanceOf[SparseMatrix] + assert(cm6 === sm4) + assert(cm6.isColMajor) + assert(cm6.getSizeInBytes <= sm4.getSizeInBytes) + + val sm5 = new SparseMatrix(1, Int.MaxValue, Array(0, 1), Array(0), Array(1.0), + isTransposed = true) + val cm7 = sm5.compressed.asInstanceOf[SparseMatrix] + assert(cm7 === sm5) + assert(cm7.isRowMajor) + assert(cm7.getSizeInBytes <= sm5.getSizeInBytes) + + // this has the same size sparse or dense + val sm6 = new SparseMatrix(4, 4, Array(0, 4, 7, 7, 7), Array(0, 1, 2, 3, 0, 1, 2), + Array.fill(7)(1.0)) + // should choose dense to break ties + val cm12 = sm6.compressed.asInstanceOf[DenseMatrix] + assert(cm12.getSizeInBytes === sm6.getSizeInBytes) } test("map, update") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index ea22c2787fb3c..dfbdaf19d374b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -336,6 +336,11 @@ class VectorsSuite extends SparkMLFunSuite { val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) val sv1c = sv1.compressed.asInstanceOf[DenseVector] assert(sv1 === sv1c) + + val sv2 = Vectors.sparse(Int.MaxValue, Array(0), Array(3.4)) + val sv2c = sv2.compressed.asInstanceOf[SparseVector] + assert(sv2c === sv2) + assert(sv2c.numActives === 1) } test("SparseVector.slice") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 2327917e2cad7..6c79d77f142e5 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -32,6 +32,10 @@ object TestingUtils { * the relative tolerance is meaningless, so the exception will be raised to warn users. */ private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } val absX = math.abs(x) val absY = math.abs(y) val diff = math.abs(x - y) @@ -49,6 +53,10 @@ object TestingUtils { * Private helper function for comparing two values using absolute tolerance. 
*/ private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + // Special case for NaNs + if (x.isNaN && y.isNaN) { + return true + } math.abs(x - y) < eps } @@ -207,7 +215,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/mllib/pom.xml b/mllib/pom.xml index 82f840b0fc269..572670dc11b42 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala index 32d78e9b226eb..3aea568cd6527 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -56,7 +56,7 @@ private[ann] class SigmoidLayerModelWithSquaredError extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - val error = Bsum(delta :* delta) / 2 / output.cols + val error = Bsum(delta *:* delta) / 2 / output.cols ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) error } @@ -119,6 +119,6 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - -Bsum( target :* brzlog(output)) / output.cols + -Bsum( target *:* brzlog(output)) / output.cols } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d8608d885d6f1..bc0b49d48d323 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -74,7 +74,7 @@ abstract class Classifier[ * and features (`Vector`). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). - * @throws SparkException if any label is not an integer >= 0 + * @note Throws `SparkException` if any label is a non-integer or is negative */ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index f76b14eeeb542..7507c7539d4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -458,9 +458,7 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
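On the NaN special cases added to TestingUtils above: NaN never compares equal to itself, so both the relative and the absolute comparison would otherwise report two NaNs as different. A minimal illustration:

    val x = Double.NaN
    println(x == x)                  // false: IEEE 754 NaN is not equal to itself
    println(math.abs(x - x) < 1e-8)  // false: NaN < eps is also false
    println(x.isNaN && x.isNaN)      // true: the added check treats two NaNs as matching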
+ - s" Expecting $numFeatures but got ${features.size}.") + if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -512,6 +510,7 @@ private class LinearSVCAggregator( * @return This LinearSVCAggregator object. */ def merge(other: LinearSVCAggregator): this.type = { + if (other.weightSum != 0.0) { weightSum += other.weightSum lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1a78187d4f8e3..42dc7fbebe4c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.classification +import java.util.Locale + import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -176,11 +178,90 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } + /** + * The lower bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", + "The lower bounds on coefficients if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) + + /** + * The upper bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", + "The upper bounds on coefficients if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) + + /** + * The lower bounds on intercepts if fitting under bound constrained optimization. + * The bounds vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", + "The lower bounds on intercepts if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) + + /** + * The upper bounds on intercepts if fitting under bound constrained optimization. 
+ * The bound vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", + "The upper bounds on intercepts if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) + + protected def usingBoundConstrainedOptimization: Boolean = { + isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) || + isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts) + } + override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { checkThresholdConsistency() + if (usingBoundConstrainedOptimization) { + require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " + + s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } + if (!$(fitIntercept)) { + require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), + "Please don't set bounds on intercepts if fitting without intercept.") + } super.validateAndTransformSchema(schema, fitting, featuresDataType) } } @@ -215,6 +296,9 @@ class LogisticRegression @Since("1.2.0") ( * For alpha in (0,1), the penalty is a combination of L1 and L2. * Default is 0.0 which is an L2 penalty. * + * Note: Fitting under bound constrained optimization only supports L2 regularization, + * so throws exception if this param is non-zero value. + * * @group setParam */ @Since("1.4.0") @@ -310,6 +394,83 @@ class LogisticRegression @Since("1.2.0") ( def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set the lower bounds on coefficients if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) + + /** + * Set the upper bounds on coefficients if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) + + /** + * Set the lower bounds on intercepts if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) + + /** + * Set the upper bounds on intercepts if fitting under bound constrained optimization. 
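The bound parameters above follow a fixed shape convention: coefficient bounds form a (1 x numFeatures) matrix for binomial models or a (numClasses x numFeatures) matrix for multinomial models, and intercept bounds have size 1 or numClasses. A sketch of building non-negativity bounds for a hypothetical 3-class, 4-feature problem:

    import org.apache.spark.ml.linalg.{Matrices, Vectors}

    val numClasses = 3
    val numFeatures = 4
    // Zero lower bounds everywhere force non-negative coefficients and intercepts.
    val coefLowerBounds = Matrices.dense(numClasses, numFeatures,
      Array.fill(numClasses * numFeatures)(0.0))
    val interceptLowerBounds = Vectors.dense(Array.fill(numClasses)(0.0))
    // These would be passed to setLowerBoundsOnCoefficients / setLowerBoundsOnIntercepts.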
+ * + * @group expertSetParam + */ + @Since("2.2.0") + def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) + + private def assertBoundConstrainedOptimizationParamsValid( + numCoefficientSets: Int, + numFeatures: Int): Unit = { + if (isSet(lowerBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && + $(lowerBoundsOnCoefficients).numCols == numFeatures, + "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).") + } + if (isSet(upperBoundsOnCoefficients)) { + require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && + $(upperBoundsOnCoefficients).numCols == numFeatures, + "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).") + } + if (isSet(lowerBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") + } + if (isSet(upperBoundsOnIntercepts)) { + require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") + } + if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " + + "less than or equal to upperBoundsOnCoefficients, but found: " + + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") + } + if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " + + "less than or equal to upperBoundsOnIntercepts, but found: " + + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") + } + } + private var optInitialModel: Option[LogisticRegressionModel] = None private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { @@ -376,6 +537,11 @@ class LogisticRegression @Since("1.2.0") ( } val numCoefficientSets = if (isMultinomial) numClasses else 1 + // Check params interaction is valid if fitting under bound constrained optimization. + if (usingBoundConstrainedOptimization) { + assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures) + } + if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." 
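// Illustrative aside (not part of the patch): a minimal sketch of how a caller could exercise
// the bound constrained optimization params added above. It assumes a binomial problem with
// 4 features and a DataFrame `training` holding "label" and "features" columns; the bound
// shapes follow the rules documented above: (1, numFeatures) for coefficients, size 1 for
// intercepts.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.{Matrices, Vectors}

val numFeatures = 4
val lr = new LogisticRegression()
  .setFitIntercept(true)
  .setElasticNetParam(0.0)  // only L2 regularization is allowed with bound constraints
  .setLowerBoundsOnCoefficients(Matrices.dense(1, numFeatures, Array.fill(numFeatures)(0.0)))
  .setUpperBoundsOnIntercepts(Vectors.dense(1.0))
// val model = lr.fit(training)  // takes the L-BFGS-B path introduced in this change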
+ @@ -395,18 +561,13 @@ class LogisticRegression @Since("1.2.0") ( val isConstantLabel = histogram.count(_ != 0.0) == 1 - if ($(fitIntercept) && isConstantLabel) { + if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax - // TODO: use `compressed` after SPARK-17471 - val coefMatrix = if (numFeatures < numCoefficientSets) { - new SparseMatrix(numCoefficientSets, numFeatures, - Array.fill(numFeatures + 1)(0), Array.empty[Int], Array.empty[Double]) - } else { - new SparseMatrix(numCoefficientSets, numFeatures, Array.fill(numCoefficientSets + 1)(0), - Array.empty[Int], Array.empty[Double], isTransposed = true) - } + val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, + new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], + isTransposed = true).compressed val interceptVec = if (isMultinomial) { Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))) } else { @@ -437,8 +598,53 @@ class LogisticRegression @Since("1.2.0") ( $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial, $(aggregationDepth)) + val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets + + val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = { + if (usingBoundConstrainedOptimization) { + val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity) + val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity) + val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients) + val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients) + val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts) + val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts) + + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (featureIndex < numFeatures) { + if (isSetLowerBoundsOnCoefficients) { + lowerBounds(i) = $(lowerBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + if (isSetUpperBoundsOnCoefficients) { + upperBounds(i) = $(upperBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + } else { + if (isSetLowerBoundsOnIntercepts) { + lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex) + } + if (isSetUpperBoundsOnIntercepts) { + upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex) + } + } + i += 1 + } + (lowerBounds, upperBounds) + } else { + (null, null) + } + } + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + if (lowerBounds != null && upperBounds != null) { + new BreezeLBFGSB( + BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol)) + } else { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } } else { val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { @@ -549,6 +755,26 @@ class LogisticRegression @Since("1.2.0") ( math.log(histogram(1) / histogram(0))) } + if (usingBoundConstrainedOptimization) { + // Make sure all initial values locate in the corresponding bound. 
+ var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, lowerBounds(i)) + } else if ( + initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, upperBounds(i)) + } + i += 1 + } + } + val states = optimizer.iterations(new CachedDiffFunction(costFun), new BDV[Double](initialCoefWithInterceptMatrix.toArray)) @@ -602,7 +828,7 @@ class LogisticRegression @Since("1.2.0") ( if (isIntercept) interceptVec.toArray(classIndex) = value } - if ($(regParam) == 0.0 && isMultinomial) { + if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) { /* When no regularization is applied, the multinomial coefficients lack identifiability because we do not use a pivot class. We can add any constant value to the coefficients @@ -612,31 +838,23 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val denseValues = denseCoefficientMatrix.values - val coefficientMean = denseValues.sum / denseValues.length - denseCoefficientMatrix.update(_ - coefficientMean) - } - - // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 - val compressedCoefficientMatrix = if (isMultinomial) { - denseCoefficientMatrix - } else { - val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed - compressedVector match { - case dv: DenseVector => denseCoefficientMatrix - case sv: SparseVector => - new SparseMatrix(1, numFeatures, Array(0, sv.indices.length), sv.indices, sv.values, - isTransposed = true) + val centers = Array.fill(numFeatures)(0.0) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + centers(j) += v + } + centers.transform(_ / numCoefficientSets) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + denseCoefficientMatrix.update(i, j, v - centers(j)) } } // center the intercepts when using multinomial algorithm - if ($(fitIntercept) && isMultinomial) { + if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) + (denseCoefficientMatrix.compressed, interceptVec.compressed, arrayBuilder.result()) } } @@ -672,7 +890,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { override def load(path: String): LogisticRegression = super.load(path) private[classification] val supportedFamilyNames = - Array("auto", "binomial", "multinomial").map(_.toLowerCase) + Array("auto", "binomial", "multinomial").map(_.toLowerCase(Locale.ROOT)) } /** @@ -713,7 +931,7 @@ class LogisticRegressionModel private[spark] ( // convert to appropriate vector representation without replicating data private lazy val _coefficients: Vector = { require(coefficientMatrix.isTransposed, - "LogisticRegressionModel coefficients should be row major.") + "LogisticRegressionModel coefficients should be row major for binomial model.") coefficientMatrix match { case dm: DenseMatrix => 
Vectors.dense(dm.values) case sm: SparseMatrix => Vectors.sparse(coefficientMatrix.numCols, sm.rowIndices, sm.values) @@ -1582,9 +1800,6 @@ private class LogisticAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1607,8 +1822,6 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + - s"LogisticAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 95c1337ed5608..ec39f964e213a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { - copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) + val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent) + copyValues(copied, extra) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index ce834f1d17e0d..ab4c235209289 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -140,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") ( .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestClassificationModel(trees, numFeatures, numClasses) + val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) instr.logSuccess(m) m } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a9c1a7ba0bc8a..5259ee419445f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -472,7 +472,7 @@ class GaussianMixture @Since("2.0.0") ( */ val cov = { val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze - slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0) + slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0) val diagVec = Vectors.fromBreeze(ss) BLAS.scal(1.0 / numSamples, diagVec) val covVec = new DenseVector(Array.fill[Double]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 55720e2d613d9..e3026c8efa823 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonAST.JObject @@ -34,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -43,9 +44,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { @@ -173,7 +174,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), - (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + (o: String) => + ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index cbac16345a292..36a46ca6ff4b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = { + val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala new file mode 100644 index 0000000000000..a41bd8e689d56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * Params for [[Imputer]] and [[ImputerModel]]. + */ +private[feature] trait ImputerParams extends Params with HasInputCols { + + /** + * The imputation strategy. Currently only "mean" and "median" are supported. + * If "mean", then replace missing values using the mean value of the feature. + * If "median", then replace missing values using the approximate median value of the feature. + * Default: mean + * + * @group param + */ + final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " + + s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " + + s"If ${Imputer.median}, then replace missing values using the median value of the feature.", + ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median))) + + /** @group getParam */ + def getStrategy: String = $(strategy) + + /** + * The placeholder for the missing values. All occurrences of missingValue will be imputed. + * Note that null values are always treated as missing. + * Default: Double.NaN + * + * @group param + */ + final val missingValue: DoubleParam = new DoubleParam(this, "missingValue", + "The placeholder for the missing values. All occurrences of missingValue will be imputed") + + /** @group getParam */ + def getMissingValue: Double = $(missingValue) + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", + "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) + + /** Validates and transforms the input schema. 
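// Illustrative aside (not part of the patch): a minimal usage sketch of the new Imputer,
// assuming a DataFrame `df` with numeric columns "a" and "b" that contain NaN placeholders.
// The column names are invented for the example.
import org.apache.spark.ml.feature.Imputer

val imputer = new Imputer()
  .setInputCols(Array("a", "b"))
  .setOutputCols(Array("a_imputed", "b_imputed"))
  .setStrategy("median")          // or "mean", the default
// val model = imputer.fit(df)    // computes one surrogate value per input column
// model.transform(df)            // replaces NaN (the default missingValue) and nulls with the surrogate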
*/ + protected def validateAndTransformSchema(schema: StructType): StructType = { + require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + + s" duplicates: (${$(inputCols).mkString(", ")})") + require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" + + s" duplicates: (${$(outputCols).mkString(", ")})") + require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" + + s" and outputCols(${$(outputCols).length}) should have the same length") + val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => + val inputField = schema(inputCol) + SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) + StructField(outputCol, inputField.dataType, inputField.nullable) + } + StructType(schema ++ outputFields) + } +} + +/** + * :: Experimental :: + * Imputation estimator for completing missing values, either using the mean or the median + * of the columns in which the missing values are located. The input columns should be of + * DoubleType or FloatType. Currently Imputer does not support categorical features + * (SPARK-15041) and possibly creates incorrect values for a categorical feature. + * + * Note that the mean/median value is computed after filtering out missing values. + * All Null values in the input columns are treated as missing, and so are also imputed. For + * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. + */ +@Experimental +class Imputer @Since("2.2.0")(override val uid: String) + extends Estimator[ImputerModel] with ImputerParams with DefaultParamsWritable { + + @Since("2.2.0") + def this() = this(Identifiable.randomUID("imputer")) + + /** @group setParam */ + @Since("2.2.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.2.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Imputation strategy. Available options are ["mean", "median"]. + * @group setParam + */ + @Since("2.2.0") + def setStrategy(value: String): this.type = set(strategy, value) + + /** @group setParam */ + @Since("2.2.0") + def setMissingValue(value: Double): this.type = set(missingValue, value) + + setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) + + override def fit(dataset: Dataset[_]): ImputerModel = { + transformSchema(dataset.schema, logging = true) + val spark = dataset.sparkSession + import spark.implicits._ + val surrogates = $(inputCols).map { inputCol => + val ic = col(inputCol) + val filtered = dataset.select(ic.cast(DoubleType)) + .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN) + if(filtered.take(1).length == 0) { + throw new SparkException(s"surrogate cannot be computed. 
" + + s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})") + } + val surrogate = $(strategy) match { + case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first() + case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head + } + surrogate + } + + val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates))) + val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) + val surrogateDF = spark.createDataFrame(rows, schema) + copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): Imputer = defaultCopy(extra) +} + +@Since("2.2.0") +object Imputer extends DefaultParamsReadable[Imputer] { + + /** strategy names that Imputer currently supports. */ + private[ml] val mean = "mean" + private[ml] val median = "median" + + @Since("2.2.0") + override def load(path: String): Imputer = super.load(path) +} + +/** + * :: Experimental :: + * Model fitted by [[Imputer]]. + * + * @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates, + * which are used to replace the missing values in the input DataFrame. + */ +@Experimental +class ImputerModel private[ml]( + override val uid: String, + val surrogateDF: DataFrame) + extends Model[ImputerModel] with ImputerParams with MLWritable { + + import ImputerModel._ + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + var outputDF = dataset + val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq + + $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + case ((inputCol, outputCol), surrogate) => + val inputType = dataset.schema(inputCol).dataType + val ic = col(inputCol) + outputDF = outputDF.withColumn(outputCol, + when(ic.isNull, surrogate) + .when(ic === $(missingValue), surrogate) + .otherwise(ic) + .cast(inputType)) + } + outputDF.toDF() + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): ImputerModel = { + val copied = new ImputerModel(uid, surrogateDF) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.2.0") + override def write: MLWriter = new ImputerModelWriter(this) +} + + +@Since("2.2.0") +object ImputerModel extends MLReadable[ImputerModel] { + + private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.surrogateDF.repartition(1).write.parquet(dataPath) + } + } + + private class ImputerReader extends MLReader[ImputerModel] { + + private val className = classOf[ImputerModel].getName + + override def load(path: String): ImputerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val surrogateDF = sqlContext.read.parquet(dataPath) + val model = new ImputerModel(metadata.uid, surrogateDF) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + 
@Since("2.2.0") + override def read: MLReader[ImputerModel] = new ImputerReader + + @Since("2.2.0") + override def load(path: String): ImputerModel = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 620e1fbb09ff7..145422a059196 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -86,7 +86,10 @@ class MinHashLSHModel private[ml]( } @Since("2.1.0") - override def copy(extra: ParamMap): this.type = defaultCopy(extra) + override def copy(extra: ParamMap): MinHashLSHModel = { + val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent) + copyValues(copied, extra) + } @Since("2.1.0") override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 80c7f55e26b84..feceeba866dfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -93,8 +93,8 @@ private[feature] trait QuantileDiscretizerBase extends Params * are too few distinct values of the input to create enough distinct quantiles. * * NaN handling: - * NaN values will be removed from the column during `QuantileDiscretizer` fitting. This will - * produce a `Bucketizer` model for making predictions. During the transformation, + * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This + * will produce a `Bucketizer` model for making predictions. During the transformation, * `Bucketizer` will raise an error when it finds NaN values in the dataset, but the user can * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. 
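// Illustrative aside (not part of the patch): a sketch of the NaN handling described above,
// assuming a DataFrame `df` with a numeric column "hour" that may contain nulls or NaNs.
// Missing values are ignored while the splits are computed; handleInvalid on the resulting
// Bucketizer decides what happens to NaN rows at transform time.
import org.apache.spark.ml.feature.QuantileDiscretizer

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("hourBucket")
  .setNumBuckets(4)
// val bucketizer = discretizer.fit(df)   // null/NaN rows are skipped when computing splits
// bucketizer.setHandleInvalid("keep")    // keep NaN rows in an extra bucket instead of failing
// val bucketed = bucketizer.transform(df)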
* If the user chooses to keep NaN values, they will be handled specially and placed into their own diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 389898666eb8e..5a3e2929f5f52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -268,8 +268,10 @@ class RFormulaModel private[feature]( } @Since("1.5.0") - override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, resolvedFormula, pipelineModel)) + override def copy(extra: ParamMap): RFormulaModel = { + val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent) + copyValues(copied, extra) + } @Since("2.0.0") override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a503411b63612..99321bcc7cf98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -34,8 +36,28 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). + * Default: "error" + * @group param + */ + @Since("1.6.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) + + /** @group getParam */ + @Since("1.6.0") + def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -73,7 +95,6 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -86,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.select(col($(inputCol)).cast(StringType)) + val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() @@ -105,6 +126,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -144,7 +170,6 @@ class StringIndexerModel ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -163,25 +188,43 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) - val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else { - throw new SparkException(s"Unseen label: $label.") - } + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_INVALID => labels :+ "__unknown" + case _ => labels } val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val filteredDataset = getHandleInvalid match { - case "skip" => + val (filteredDataset, keepInvalid) = getHandleInvalid match { + case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) } - dataset.where(filterer(dataset($(inputCol)))) - case _ => dataset + (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) + } + + val indexer = udf { label: String => + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } + } else { + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. 
To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } + } } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 42e8a66a62b61..4ca062c0b5adf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -227,25 +227,50 @@ class Word2VecModel private[ml] ( /** * Find "num" number of words closest in similarity to the given word, not - * including the word itself. Returns a dataframe with the words and the - * cosine similarities between the synonyms and the given word. + * including the word itself. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word vector. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words whose vector representation most similar to the supplied vector. + * Find "num" number of words whose vector representation is most similar to the supplied vector. * If the supplied vector is the vector representation of a word in the model's vocabulary, - * that word will be in the results. Returns a dataframe with the words and the cosine + * that word will be in the results. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine * similarities between the synonyms and the given word vector. */ @Since("2.0.0") def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(vec, num) + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. 
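// Illustrative aside (not part of the patch): the DataFrame-returning findSynonyms methods now
// delegate to the new findSynonymsArray variants, which return local arrays and skip building a
// DataFrame per call. A fitted `model: Word2VecModel` and a vocabulary word "spark" are assumed.
// val asDf = model.findSynonyms("spark", 5)                                  // columns: word, similarity
// val asArray: Array[(String, Double)] = model.findSynonymsArray("spark", 5)
// asArray.foreach { case (word, sim) => println(s"$word -> $sim") }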
+ */ + @Since("2.2.0") + def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(word, num) } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 417968d9b817d..8f00daa59f1a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.fpm -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.hadoop.fs.Path @@ -25,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.HasPredictionCol import org.apache.spark.ml.util._ import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules, FPGrowth => MLlibFPGrowth} @@ -37,11 +36,24 @@ import org.apache.spark.sql.types._ /** * Common params for FPGrowth and FPGrowthModel */ -private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol { +private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { + + /** + * Items column name. + * Default: "items" + * @group param + */ + @Since("2.2.0") + val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name") + setDefault(itemsCol -> "items") + + /** @group getParam */ + @Since("2.2.0") + def getItemsCol: String = $(itemsCol) /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears - * more than (minSupport * size-of-the-dataset) times will be output + * more than (minSupport * size-of-the-dataset) times will be output in the frequent itemsets. * Default: 0.3 * @group param */ @@ -56,8 +68,8 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre def getMinSupport: Double = $(minSupport) /** - * Number of partitions (>=1) used by parallel FP-growth. By default the param is not set, and - * partition number of the input dataset is used. + * Number of partitions (at least 1) used by parallel FP-growth. By default the param is not + * set, and partition number of the input dataset is used. * @group expertParam */ @Since("2.2.0") @@ -69,8 +81,8 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre def getNumPartitions: Int = $(numPartitions) /** - * Minimal confidence for generating Association Rule. - * Note that minConfidence has no effect during fitting. + * Minimal confidence for generating Association Rule. minConfidence will not affect the mining + * for frequent itemsets, but will affect the association rules generation. 
* Default: 0.8 * @group param */ @@ -91,10 +103,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre */ @Since("2.2.0") protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(featuresCol)).dataType + val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") - SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType) + SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } @@ -105,7 +117,7 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in * Han et al., Mining frequent patterns without - * candidate generation. Note null values in the feature column are ignored during fit(). + * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see * Association rule learning (Wikipedia) @@ -133,7 +145,7 @@ class FPGrowth @Since("2.2.0") ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") @@ -146,17 +158,16 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { - val data = dataset.select($(featuresCol)) - val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val data = dataset.select($(itemsCol)) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) - val schema = StructType(Seq( - StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false), + StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) @@ -183,7 +194,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { * :: Experimental :: * Model fitted by FPGrowth. * - * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + * @param freqItemsets frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) */ @Since("2.2.0") @Experimental @@ -198,28 +209,46 @@ class FPGrowthModel private[ml] ( /** @group setParam */ @Since("2.2.0") - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setItemsCol(value: String): this.type = set(itemsCol, value) /** @group setParam */ @Since("2.2.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** - * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * Cache minConfidence and associationRules to avoid redundant computation for association rules + * during transform. The associationRules will only be re-computed when minConfidence changed. 
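// Illustrative aside (not part of the patch): end-to-end use of the renamed itemsCol and of
// associationRules, which (per the caching above) is only recomputed when minConfidence changes.
// A DataFrame `transactions` with an array column "items" is assumed.
import org.apache.spark.ml.fpm.FPGrowth

val fpgrowth = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.3)
  .setMinConfidence(0.8)
// val model = fpgrowth.fit(transactions)
// model.freqItemsets.show()               // DataFrame("items", "freq")
// model.associationRules.show()           // antecedent, consequent, confidence
// model.transform(transactions).show()    // prediction column with the suggested consequents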
+ */ + @transient private var _cachedMinConf: Double = Double.NaN + + @transient private var _cachedRules: DataFrame = _ + + /** + * Get association rules fitted using the minConfidence. Returns a dataframe * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and * "consequent" are Array[T] and "confidence" is Double. */ @Since("2.2.0") - @transient lazy val associationRules: DataFrame = { - AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + @transient def associationRules: DataFrame = { + if ($(minConfidence) == _cachedMinConf) { + _cachedRules + } else { + _cachedRules = AssociationRules + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + _cachedMinConf = $(minConfidence) + _cachedRules + } } /** * The transform method first generates the association rules according to the frequent itemsets. - * Then for each association rule, it will examine the input items against antecedents and - * summarize the consequents as prediction. The prediction column has the same data type as the - * input column(Array[T]) and will not contain existing items in the input column. The null - * values in the feature columns are treated as empty sets. + * Then for each transaction in itemsCol, the transform method will compare its items against the + * antecedents of each association rule. If the record contains all the antecedents of a + * specific association rule, the rule will be considered as applicable and its consequents + * will be added to the prediction result. The transform method will summarize the consequents + * from all the applicable rules as prediction. The prediction column has the same data type as + * the input column(Array[T]) and will not contain existing items in the input column. The null + * values in the itemsCol columns are treated as empty sets. * WARNING: internally it collects association rules to the driver and uses broadcast for * efficiency. This may bring pressure to driver memory for large set of association rules. */ @@ -235,7 +264,7 @@ class FPGrowthModel private[ml] ( .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]] val brRules = dataset.sparkSession.sparkContext.broadcast(rules) - val dt = dataset.schema($(featuresCol)).dataType + val dt = dataset.schema($(itemsCol)).dataType // For each rule, examine the input items and summarize the consequents val predictUDF = udf((items: Seq[_]) => { if (items != null) { @@ -245,11 +274,11 @@ class FPGrowthModel private[ml] ( rule._2.filter(item => !itemset.contains(item)) } else { Seq.empty - }) + }).distinct } else { Seq.empty - }.distinct }, dt) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + }}, dt) + dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol)))) } @Since("2.2.0") @@ -307,13 +336,13 @@ private[fpm] object AssociationRules { /** * Computes the association rules with confidence above minConfidence. - * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from - * algorithms like [[FPGrowth]]. + * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained + * from algorithms like [[FPGrowth]]. * @param itemsCol column name for frequent itemsets - * @param freqCol column name for frequent itemsets count - * @param minConfidence minimum confidence for the result association rules - * @return a DataFrame("antecedent", "consequent", "confidence") containing the association - * rules. 
+ * @param freqCol column name for appearance count of the frequent itemsets + * @param minConfidence minimum confidence for generating the association rules + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) + * containing the association rules. */ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala new file mode 100644 index 0000000000000..b8151d8d90702 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable { + def freqItemsets: DataFrame = fpGrowthModel.freqItemsets + def associationRules: DataFrame = fpGrowthModel.associationRules + + def transform(dataset: Dataset[_]): DataFrame = { + fpGrowthModel.transform(dataset) + } + + override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this) +} + +private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] { + + def fit( + data: DataFrame, + minSupport: Double, + minConfidence: Double, + itemsCol: String, + numPartitions: Integer): FPGrowthWrapper = { + val fpGrowth = new FPGrowth() + .setMinSupport(minSupport) + .setMinConfidence(minConfidence) + .setItemsCol(itemsCol) + + if (numPartitions != null && numPartitions > 0) { + fpGrowth.setNumPartitions(numPartitions) + } + + val fpGrowthModel = fpGrowth.fit(data) + + new FPGrowthWrapper(fpGrowthModel) + } + + override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader + + class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] { + override def load(path: String): FPGrowthWrapper = { + val modelPath = new Path(path, "model").toString + val fPGrowthModel = FPGrowthModel.load(modelPath) + + new FPGrowthWrapper(fPGrowthModel) + } + } + + class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter { + override protected def saveImpl(path: String): Unit = { + val modelPath = new Path(path, "model").toString + val rMetadataPath = new Path(path, "rMetadata").toString + + val rMetadataJson: String = compact(render( + "class" -> instance.getClass.getName + )) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.fpGrowthModel.save(modelPath) + } + } +} diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index aacb41ee2659b..c07eadb30a4d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private ( lazy val featureImportances: Vector = gbtcModel.featureImportances lazy val numTrees: Int = gbtcModel.getNumTrees lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + lazy val maxDepth: Int = gbtcModel.getMaxDepth def summary: String = gbtcModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala index 585077588eb9b..b568d7859221f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private ( lazy val featureImportances: Vector = gbtrModel.featureImportances lazy val numTrees: Int = gbtrModel.getNumTrees lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + lazy val maxDepth: Int = gbtrModel.getMaxDepth def summary: String = gbtrModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index cbd6cd1c7933c..4bd4aa7113f68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.r +import java.util.Locale + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.JsonDSL._ @@ -71,7 +73,9 @@ private[r] object GeneralizedLinearRegressionWrapper tol: Double, maxIter: Int, weightCol: String, - regParam: Double): GeneralizedLinearRegressionWrapper = { + regParam: Double, + variancePower: Double, + linkPower: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -83,13 +87,17 @@ private[r] object GeneralizedLinearRegressionWrapper // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) - .setLink(link) .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - + // set variancePower and linkPower if family is tweedie; otherwise, set link function + if (family.toLowerCase(Locale.ROOT) == "tweedie") { + glr.setVariancePower(variancePower).setLinkPower(linkPower) + } else { + glr.setLink(link) + } if (weightCol != null) glr.setWeightCol(weightCol) val pipeline = new Pipeline() @@ -145,7 +153,12 @@ private[r] object GeneralizedLinearRegressionWrapper val rDeviance: Double = summary.deviance val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom - val rAic: Double = summary.aic + val rAic: Double = if (family.toLowerCase(Locale.ROOT) == "tweedie" && + !Array(0.0, 1.0, 2.0).exists(x => math.abs(x - variancePower) < 1e-8)) { + 0.0 + } else { + summary.aic + } val rNumIterations: Int = summary.numIterations new 
GeneralizedLinearRegressionWrapper(pipeline, rFeatures, rCoefficients, rDispersion, diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index c96f99cb83434..703bcdf4ca725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -40,13 +40,13 @@ private[r] class LogisticRegressionWrapper private ( private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + lazy val rFeatures: Array[String] = if (lrModel.getFitIntercept) { Array("(Intercept)") ++ features } else { features } - val rCoefficients: Array[Double] = { + lazy val rCoefficients: Array[Double] = { val numRows = lrModel.coefficientMatrix.numRows val numCols = lrModel.coefficientMatrix.numCols val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index d34de30931143..48c87743dee60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -36,11 +36,11 @@ private[r] class MultilayerPerceptronClassifierWrapper private ( import MultilayerPerceptronClassifierWrapper._ - val mlpModel: MultilayerPerceptronClassificationModel = + private val mlpModel: MultilayerPerceptronClassificationModel = pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel] - val weights: Array[Double] = mlpModel.weights.toArray - val layers: Array[Int] = mlpModel.layers + lazy val weights: Array[Double] = mlpModel.weights.toArray + lazy val layers: Array[Int] = mlpModel.layers def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 358e522dfe1c8..b30ce12bc6cc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] { BisectingKMeansWrapper.load(path) case "org.apache.spark.ml.r.LinearSVCWrapper" => LinearSVCWrapper.load(path) + case "org.apache.spark.ml.r.FPGrowthWrapper" => + FPGrowthWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 366f375b58582..8a83d4e980f7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private ( lazy val featureImportances: Vector = rfcModel.featureImportances lazy val numTrees: Int = rfcModel.getNumTrees lazy val treeWeights: Array[Double] = rfcModel.treeWeights + lazy val maxDepth: Int = rfcModel.getMaxDepth def summary: String = rfcModel.toDebugString diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index 4b9a3a731da9b..038bd79c7022b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private ( lazy val featureImportances: Vector = rfrModel.featureImportances lazy val numTrees: Int = rfrModel.getNumTrees lazy val treeWeights: Array[Double] = rfrModel.treeWeights + lazy val maxDepth: Int = rfrModel.getMaxDepth def summary: String = rfrModel.toDebugString diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 60dd7367053e2..a20ef72446661 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.recommendation import java.{util => ju} import java.io.IOException +import java.util.Locale import scala.collection.mutable import scala.reflect.ClassTag @@ -40,8 +41,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -118,10 +118,11 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo "useful in cross-validation or production scenarios, for handling user/item ids the model " + "has not seen in the training data. Supported values: " + s"${ALSModel.supportedColdStartStrategies.mkString(",")}.", - (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase)) + (s: String) => + ALSModel.supportedColdStartStrategies.contains(s.toLowerCase(Locale.ROOT))) /** @group expertGetParam */ - def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase + def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase(Locale.ROOT) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 110764dc074f7..bff0d9bbb46ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.regression +import java.util.Locale + import breeze.stats.{distributions => dist} import org.apache.hadoop.fs.Path @@ -57,7 +59,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. 
Supported options: ${supportedFamilyNames.mkString(", ")}.", - (value: String) => supportedFamilyNames.contains(value.toLowerCase)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") @@ -66,7 +68,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for the power in the variance function of the Tweedie distribution which provides * the relationship between the variance and mean of the distribution. - * Only applicable for the Tweedie family. + * Only applicable to the Tweedie family. * (see * Tweedie Distribution (Wikipedia)) * Supported values: 0 and [1, Inf). @@ -79,7 +81,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val variancePower: DoubleParam = new DoubleParam(this, "variancePower", "The power in the variance function of the Tweedie distribution which characterizes " + "the relationship between the variance and mean of the distribution. " + - "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).", + "Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).", (x: Double) => x >= 1.0 || x == 0.0) /** @group getParam */ @@ -99,14 +101,14 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - (value: String) => supportedLinkNames.contains(value.toLowerCase)) + (value: String) => supportedLinkNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.0.0") def getLink: String = $(link) /** - * Param for the index in the power link function. Only applicable for the Tweedie family. + * Param for the index in the power link function. Only applicable to the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" @@ -116,7 +118,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam */ @Since("2.2.0") final val linkPower: DoubleParam = new DoubleParam(this, "linkPower", - "The index in the power link function. Only applicable for the Tweedie family.") + "The index in the power link function. Only applicable to the Tweedie family.") /** @group getParam */ @Since("2.2.0") @@ -148,7 +150,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(family).toLowerCase == "tweedie") { + if ($(family).toLowerCase(Locale.ROOT) == "tweedie") { if (isSet(link)) { logWarning("When family is tweedie, use param linkPower to specify link function. 
" + "Setting param link will take no effect.") @@ -460,13 +462,15 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine */ def apply(params: GeneralizedLinearRegressionBase): FamilyAndLink = { val familyObj = Family.fromParams(params) - val linkObj = if ((params.getFamily.toLowerCase != "tweedie" && - params.isSet(params.link)) || (params.getFamily.toLowerCase == "tweedie" && - params.isSet(params.linkPower))) { - Link.fromParams(params) - } else { - familyObj.defaultLink - } + val linkObj = + if ((params.getFamily.toLowerCase(Locale.ROOT) != "tweedie" && + params.isSet(params.link)) || + (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie" && + params.isSet(params.linkPower))) { + Link.fromParams(params) + } else { + familyObj.defaultLink + } new FamilyAndLink(familyObj, linkObj) } } @@ -519,7 +523,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family name and variance power */ def fromParams(params: GeneralizedLinearRegressionBase): Family = { - params.getFamily.toLowerCase match { + params.getFamily.toLowerCase(Locale.ROOT) match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -795,7 +799,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param params the parameter map containing family, link and linkPower */ def fromParams(params: GeneralizedLinearRegressionBase): Link = { - if (params.getFamily.toLowerCase == "tweedie") { + if (params.getFamily.toLowerCase(Locale.ROOT) == "tweedie") { params.getLinkPower match { case 0.0 => Log case 1.0 => Identity @@ -804,7 +808,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine case others => new Power(others) } } else { - params.getLink.toLowerCase match { + params.getLink.toLowerCase(Locale.ROOT) match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log @@ -890,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Probit extends Link("probit") { - override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu) override def deriv(mu: Double): Double = { - 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu)) } override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) @@ -1129,7 +1133,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( private[regression] lazy val link: Link = familyLink.link /** Number of instances in DataFrame predictions. */ - private[regression] lazy val numInstances: Long = predictions.count() + @Since("2.2.0") + lazy val numInstances: Long = predictions.count() /** The numeric rank of the fitted linear model. 
*/ @Since("2.0.0") @@ -1253,8 +1258,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") lazy val dispersion: Double = if ( - model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { 1.0 } else { val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0) @@ -1357,8 +1362,8 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( @Since("2.0.0") lazy val pValues: Array[Double] = { if (isNormalSolver) { - if (model.getFamily.toLowerCase == Binomial.name || - model.getFamily.toLowerCase == Poisson.name) { + if (model.getFamily.toLowerCase(Locale.ROOT) == Binomial.name || + model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) { tValues.map { x => 2.0 * (1.0 - dist.Gaussian(0.0, 1.0).cdf(math.abs(x))) } } else { tValues.map { x => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 45df1d9be647d..eaad54985229e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -696,7 +696,8 @@ class LinearRegressionSummary private[regression] ( lazy val numInstances: Long = predictions.count() /** Degrees of freedom */ - private val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { + @Since("2.2.0") + val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { numInstances - privateModel.coefficients.size - 1 } else { numInstances - privateModel.coefficients.size @@ -970,9 +971,6 @@ private class LeastSquaresAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1004,8 +1002,6 @@ private class LeastSquaresAggregator( * @return This LeastSquaresAggregator object. */ def merge(other: LeastSquaresAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. 
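// [Editor's sketch, not part of the patch] How the tweedie support plumbed through
// GeneralizedLinearRegression above is meant to be used; a training DataFrame `df` with the
// usual "features"/"label" columns is assumed.
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val tweedieGlr = new GeneralizedLinearRegression()
  .setFamily("tweedie")      // family names are matched case-insensitively via Locale.ROOT
  .setVariancePower(1.6)     // allowed values: 0.0 or [1.0, Inf)
  .setLinkPower(0.0)         // 0.0 selects the log link; when unset it defaults to 1 - variancePower
  .setMaxIter(25)
  .setRegParam(0.3)
// val model = tweedieGlr.fit(df)  // do not call setLink() for tweedie; it is ignored with a warning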
Expecting $dim but got ${other.dim}.") if (other.weightSum != 0) { totalCnt += other.totalCnt diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f524a8c5784d..a58da50fad972 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -131,7 +131,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = oldDataset.first().features.size - val m = new RandomForestRegressionModel(trees, numFeatures) + val m = new RandomForestRegressionModel(uid, trees, numFeatures) instr.logSuccess(m) m } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala new file mode 100644 index 0000000000000..5b38ca73e8014 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col + + +/** + * :: Experimental :: + * + * Chi-square hypothesis testing for categorical data. + * + * See Wikipedia for more information + * on the Chi-squared test. + */ +@Experimental +@Since("2.2.0") +object ChiSquareTest { + + /** Used to construct output schema of tests */ + private case class ChiSquareResult( + pValues: Vector, + degreesOfFreedom: Array[Int], + statistics: Vector) + + /** + * Conduct Pearson's independence test for every feature against the label. For each feature, the + * (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + * statistic is computed. All label and feature values must be categorical. + * + * The null hypothesis is that the occurrence of the outcomes is statistically independent. + * + * @param dataset DataFrame of categorical labels and categorical features. + * Real-valued features will be treated as categorical for each distinct value. 
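// [Editor's sketch, not part of the patch] A minimal call of the ChiSquareTest API introduced in
// this file, assuming an active SparkSession named `spark`.
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.ChiSquareTest

val chiData = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(0.5, 10.0)),
  (0.0, Vectors.dense(1.5, 20.0)),
  (1.0, Vectors.dense(1.5, 30.0)),
  (0.0, Vectors.dense(3.5, 30.0)),
  (1.0, Vectors.dense(3.5, 40.0)))).toDF("label", "features")

val row = ChiSquareTest.test(chiData, "features", "label").head
val pValues = row.getAs[Vector]("pValues")                      // one p-value per feature
val degreesOfFreedom = row.getAs[Seq[Int]]("degreesOfFreedom")  // one entry per feature
val statistics = row.getAs[Vector]("statistics")                // one Chi-squared statistic per feature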
+ * @param featuresCol Name of features column in dataset, of type `Vector` (`VectorUDT`) + * @param labelCol Name of label column in dataset, of any numerical type + * @return DataFrame containing the test result for every feature against the label. + * This DataFrame will contain a single Row with the following fields: + * - `pValues: Vector` + * - `degreesOfFreedom: Array[Int]` + * - `statistics: Vector` + * Each of these fields has one value per feature. + */ + @Since("2.2.0") + def test(dataset: DataFrame, featuresCol: String, labelCol: String): DataFrame = { + val spark = dataset.sparkSession + import spark.implicits._ + + SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT) + SchemaUtils.checkNumericType(dataset.schema, labelCol) + val rdd = dataset.select(col(labelCol).cast("double"), col(featuresCol)).as[(Double, Vector)] + .rdd.map { case (label, features) => OldLabeledPoint(label, OldVectors.fromML(features)) } + val testResults = OldStatistics.chiSqTest(rdd) + val pValues: Vector = Vectors.dense(testResults.map(_.pValue)) + val degreesOfFreedom: Array[Int] = testResults.map(_.degreesOfFreedom) + val statistics: Vector = Vectors.dense(testResults.map(_.statistic)) + spark.createDataFrame(Seq(ChiSquareResult(pValues, degreesOfFreedom, statistics))) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala new file mode 100644 index 0000000000000..e185bc8a6faaa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{SQLDataTypes, Vector} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * API for correlation functions in MLlib, compatible with Dataframes and Datasets. + * + * The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset#stat]] + * to spark.ml's Vector types. + */ +@Since("2.2.0") +@Experimental +object Correlation { + + /** + * :: Experimental :: + * Compute the correlation matrix for the input Dataset of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman`. + * + * @param dataset A dataset or a dataframe + * @param column The name of the column of vectors for which the correlation coefficient needs + * to be computed. 
This must be a column of the dataset, and it must contain + * Vector objects. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return A dataframe that contains the correlation matrix of the column of vectors. This + * dataframe contains a single row and a single column of name + * '$METHODNAME($COLUMN)'. + * @throws IllegalArgumentException if the column is not a valid column in the dataset, or if + * the content of this column is not of type Vector. + * + * Here is how to access the correlation coefficient: + * {{{ + * val data: Dataset[Vector] = ... + * val Row(coeff: Matrix) = Correlation.corr(data, "value").head + * // coeff now contains the Pearson correlation matrix. + * }}} + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` + * to avoid recomputing the common lineage. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String, method: String): DataFrame = { + val rdd = dataset.select(column).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } + val oldM = OldStatistics.corr(rdd, method) + val name = s"$method($column)" + val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false))) + dataset.sparkSession.createDataFrame(Seq(Row(oldM.asML)).asJava, schema) + } + + /** + * Compute the Pearson correlation matrix for the input Dataset of Vectors. + */ + @Since("2.2.0") + def corr(dataset: Dataset[_], column: String): DataFrame = { + corr(dataset, column, "pearson") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 4c525c0714ec5..ce2bd7b430f43 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 008dd19c2498d..82e1ed85a0a14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = if (featureSamples.isEmpty) { + val splits: Array[Double] = if (featureSamples.isEmpty) { 
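// [Editor's sketch, not part of the patch] Complementing the Pearson snippet embedded in the
// Correlation scaladoc above, the Spearman variant; per the @note the input is cached first.
// A DataFrame `df` with a Vector column named "features" is assumed.
import org.apache.spark.ml.linalg.Matrix
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.sql.Row

df.cache()
val Row(spearmanCorr: Matrix) = Correlation.corr(df, "features", "spearman").head
// spearmanCorr holds the Spearman rank-correlation matrix of the "features" column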
Array.empty[Double] } else { val numSplits = metadata.numSplits(featureIndex) @@ -1009,10 +1009,15 @@ private[spark] object RandomForest extends Logging { // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray - // if possible splits is not enough or just enough, just return all possible splits val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.map(_._1).init + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1037,7 +1042,7 @@ private[spark] object RandomForest extends Logging { // makes the gap between currentCount and targetCount smaller, // previous value is a split threshold. if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 targetCount += stride } index += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5eb707dfe7bc3..cd1950bd76c05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import java.util.Locale + import scala.util.Try import org.apache.spark.ml.PredictorParams @@ -218,7 +220,8 @@ private[ml] trait TreeClassifierParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeClassifierParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "gini") @@ -230,7 +233,7 @@ private[ml] trait TreeClassifierParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -247,7 +250,8 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("entropy", "gini").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeClassifierParams @@ -267,7 +271,8 @@ private[ml] trait TreeRegressorParams extends Params { final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). 
Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + (value: String) => + TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT))) setDefault(impurity -> "variance") @@ -279,7 +284,7 @@ private[ml] trait TreeRegressorParams extends Params { def setImpurity(value: String): this.type = set(impurity, value) /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase + final def getImpurity: String = $(impurity).toLowerCase(Locale.ROOT) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -295,7 +300,8 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = + Array("variance").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams @@ -417,7 +423,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + s", (0.0-1.0], [1-n].", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + RandomForestParams.supportedFeatureSubsetStrategies.contains( + value.toLowerCase(Locale.ROOT)) || Try(value.toInt).filter(_ > 0).isSuccess || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) @@ -431,13 +438,13 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase(Locale.ROOT) } private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait RandomForestClassifierParams @@ -509,7 +516,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { private[ml] object GBTClassifierParams { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("logistic").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { @@ -523,12 +531,13 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). 
Supported options:" + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "logistic") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldClassificationLoss = { @@ -544,7 +553,8 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam private[ml] object GBTRegressorParams { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = + Array("squared", "absolute").map(_.toLowerCase(Locale.ROOT)) } private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { @@ -558,12 +568,13 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + " tries to minimize (case-insensitive). Supported options:" + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + (value: String) => + GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase(Locale.ROOT))) setDefault(lossType -> "squared") /** @group getParam */ - def getLossType: String = $(lossType).toLowerCase + def getLossType: String = $(lossType).toLowerCase(Locale.ROOT) /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 09bddcdb810bb..a8b80031faf86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -104,8 +104,9 @@ abstract class MLWriter extends BaseReadWrite with Logging { // TODO: Revert back to the original content if save is not successful. fs.delete(qualifiedOutputPath, true) } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + throw new IOException(s"Path $path already exists. 
To overwrite it, " + + s"please use write.overwrite().save(path) for Scala and use " + + s"write().overwrite().save(path) for Java and Python.") } } saveImpl(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 051ec2404fb6e..4d952ac88c9be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -271,7 +271,7 @@ class GaussianMixture private ( private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) val ss = BDV.zeros[Double](x(0).length) - x.foreach(xi => ss += (xi - mu) :^ 2.0) + x.foreach(xi => ss += (xi - mu) ^:^ 2.0) diag(ss / x.length.toDouble) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6c5f529fb8bfd..4aa647236b31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.clustering +import java.util.Locale + import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.{DeveloperApi, Since} @@ -306,7 +308,7 @@ class LDA private ( @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = - optimizerName.toLowerCase match { + optimizerName.toLowerCase(Locale.ROOT) match { case "em" => new EMLDAOptimizer case "online" => new OnlineLDAOptimizer case other => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7fd722a332923..663f63c25a940 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -314,7 +314,7 @@ class LocalLDAModel private[spark] ( docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)] - docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum((brzAlpha - gammad) *:* Elogthetad) docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) @@ -324,7 +324,7 @@ class LocalLDAModel private[spark] ( // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] val sumEta = eta * vocabSize - val topicsPart = sum((eta - lambda) :* Elogbeta) + + val topicsPart = sum((eta - lambda) *:* Elogbeta) + sum(lgamma(lambda) - lgamma(eta)) + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) @@ -721,7 +721,7 @@ class DistributedLDAModel private[clustering] ( val N_wj = edgeContext.attr val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) edgeContext.sendToDst(tokenLogLikelihood) @@ -748,7 +748,7 @@ class DistributedLDAModel private[clustering] ( if (isTermVertex(vertex)) { val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k 
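// [Editor's note, not part of the patch] The hunks in this file and its neighbours only rename
// Breeze's element-wise operators to the Breeze 0.13 spellings; a tiny sketch of the new syntax
// (assuming breeze 0.13.x on the classpath):
import breeze.linalg.DenseVector

val u = DenseVector(1.0, 2.0, 3.0)
val v = DenseVector(4.0, 5.0, 6.0)
val prod = u *:* v       // element-wise product, previously written u :* v
val quot = u /:/ v       // element-wise division, previously written u :/ v
val shifted = u +:+ 1.0  // element-wise addition, previously written u :+ 1.0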
+ val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 @@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - // TODO: Remove work-around for the breeze bug. - // https://github.com/scalanlp/breeze/issues/561 - val topIndices = if (k == topicCounts.length) { - Seq.range(0, k) - } else { - argtopk(topicCounts, k) - } + val topIndices = argtopk(topicCounts, k) val sumCounts = sum(topicCounts) val weights = if (sumCounts != 0) { - topicCounts(topIndices) / sumCounts + topicCounts(topIndices).toArray.map(_ / sumCounts) } else { - topicCounts(topIndices) + topicCounts(topIndices).toArray } - (docID.toLong, topIndices.toArray, weights.toArray) + (docID.toLong, topIndices.toArray, weights) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae4276c480..d633893e55f55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -482,7 +482,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) stats.unpersist() expElogbetaBc.destroy(false) - val batchResult = statsSum :* expElogbeta.t + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) @@ -522,7 +522,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val dalpha = -(gradf - b) / q - if (all((weight * dalpha + alpha) :> 0D)) { + if (all((weight * dalpha + alpha) >:> 0D)) { alpha :+= weight * dalpha this.alpha = Vectors.dense(alpha.toArray) } @@ -584,7 +584,7 @@ private[clustering] object OnlineLDAOptimizer { val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids @@ -592,14 +592,14 @@ private[clustering] object OnlineLDAOptimizer { while (meanGammaChange > 1e-3) { val lastgamma = gammad.copy // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha + gammad := (expElogthetad *:* (expElogbetad.t * (ctsVector /:/ phiNorm))) +:+ alpha expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) // TODO: Keep more values in log space, and only exponentiate when needed. 
- phiNorm := expElogbetad * expElogthetad :+ 1e-100 + phiNorm := expElogbetad * expElogthetad +:+ 1e-100 meanGammaChange = sum(abs(gammad - lastgamma)) / k } - val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix (gammad, sstatsd, ids) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index 1f6e1a077f923..c4bbe51a46c32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -29,7 +29,7 @@ private[clustering] object LDAUtils { */ private[clustering] def logSumExp(x: BDV[Double]): Double = { val a = max(x) - a + log(sum(exp(x :- a))) + a + log(sum(exp(x -:- a))) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 4d3e265455da6..b2437b845f826 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -259,7 +259,7 @@ object PowerIterationClustering extends Logging { val j = ctx.dstId val s = ctx.attr if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (s > 0.0) { ctx.sendToSrc(s) @@ -283,7 +283,7 @@ object PowerIterationClustering extends Logging { : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { - throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + throw new SparkException(s"Similarity must be nonnegative but found s($i, $j) = $s.") } if (i != j) { Seq(Edge(i, j, s), Edge(j, i, s)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 2364d43aaa0e2..6f96813497b62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) + try { + doFit(dataset, sc, expTable, bcVocab, bcVocabHash) + } finally { + expTable.destroy(blocking = false) + bcVocab.destroy(blocking = false) + bcVocabHash.destroy(blocking = false) + } + } + + private def doFit[S <: Iterable[String]]( + dataset: RDD[S], sc: SparkContext, + expTable: Broadcast[Array[Float]], + bcVocab: Broadcast[Array[VocabWord]], + bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = { // each partition is a collection of sentences, // will be translated into arrays of Index integer val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => @@ -435,9 +450,6 @@ class 
Word2Vec extends Serializable with Logging { bcSyn1Global.destroy(false) } newSentences.unpersist() - expTable.destroy(false) - bcVocab.destroy(false) - bcVocabHash.destroy(false) val wordArray = vocab.map(_.word) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) @@ -479,8 +491,8 @@ class Word2VecModel private[spark] ( // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. - private val wordVecNorms: Array[Double] = { - val wordVecNorms = new Array[Double](numWords) + private val wordVecNorms: Array[Float] = { + val wordVecNorms = new Array[Float](numWords) var i = 0 while (i < numWords) { val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) @@ -558,7 +570,7 @@ class Word2VecModel private[spark] ( require(num > 0, "Number of similar words should > 0") val fVector = vector.toArray.map(_.toFloat) - val cosineVec = Array.fill[Float](numWords)(0) + val cosineVec = new Array[Float](numWords) val alpha: Float = 1 val beta: Float = 0 // Normalize input vector before blas.sgemv to avoid Inf value @@ -569,22 +581,23 @@ class Word2VecModel private[spark] ( blas.sgemv( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) - val cosVec = cosineVec.map(_.toDouble) - var ind = 0 - while (ind < numWords) { - val norm = wordVecNorms(ind) - if (norm == 0.0) { - cosVec(ind) = 0.0 + var i = 0 + while (i < numWords) { + val norm = wordVecNorms(i) + if (norm == 0.0f) { + cosineVec(i) = 0.0f } else { - cosVec(ind) /= norm + cosineVec(i) /= norm } - ind += 1 + i += 1 } - val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) + val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2)) - for(i <- cosVec.indices) { - pq += Tuple2(wordList(i), cosVec(i)) + var j = 0 + while (j < numWords) { + pq += Tuple2(wordList(j), cosineVec(j)) + j += 1 } val scored = pq.toSeq.sortBy(-_._2) @@ -594,7 +607,10 @@ class Word2VecModel private[spark] ( case None => scored } - filtered.take(num).toArray + filtered + .take(num) + .map { case (word, score) => (word, score.toDouble) } + .toArray } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 327cb974ef96c..3f8d65a378e2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -144,45 +144,13 @@ class PrefixSpan private ( logInfo(s"minimum count for a frequent pattern: $minCount") // Find frequent items. - val freqItemAndCounts = data.flatMap { itemsets => - val uniqItems = mutable.Set.empty[Item] - itemsets.foreach { _.foreach { item => - uniqItems += item - }} - uniqItems.toIterator.map((_, 1L)) - }.reduceByKey(_ + _) - .filter { case (_, count) => - count >= minCount - }.collect() - val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + val freqItems = findFrequentItems(data, minCount) logInfo(s"number of frequent items: ${freqItems.length}") // Keep only frequent items from input sequences and convert them to internal storage. 
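// [Editor's sketch, not part of the patch] What the zero-delimited, 1-indexed internal encoding
// built just below looks like for one hypothetical sequence, with frequent items "a" -> 0 and
// "b" -> 1 (the map comes from freqItems.zipWithIndex) and an infrequent item "c" dropped:
val itemToIntExample = Map("a" -> 0, "b" -> 1)
val sequenceExample = Array(Array("a"), Array("a", "b", "c"))   // the sequence <(a), (a b c)>
// Internal form: a leading 0, each surviving itemset stored as sorted (index + 1) values,
// and a 0 delimiter after every itemset:
val internalExample = Array(0, 1, 0, 1, 2, 0)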
val itemToInt = freqItems.zipWithIndex.toMap - val dataInternalRepr = data.flatMap { itemsets => - val allItems = mutable.ArrayBuilder.make[Int] - var containsFreqItems = false - allItems += 0 - itemsets.foreach { itemsets => - val items = mutable.ArrayBuilder.make[Int] - itemsets.foreach { item => - if (itemToInt.contains(item)) { - items += itemToInt(item) + 1 // using 1-indexing in internal format - } - } - val result = items.result() - if (result.nonEmpty) { - containsFreqItems = true - allItems ++= result.sorted - } - allItems += 0 - } - if (containsFreqItems) { - Iterator.single(allItems.result()) - } else { - Iterator.empty - } - }.persist(StorageLevel.MEMORY_AND_DISK) + val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt) + .persist(StorageLevel.MEMORY_AND_DISK) val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) @@ -231,6 +199,67 @@ @Since("1.5.0") object PrefixSpan extends Logging { + /** + * This method finds all frequent items in an input dataset. + * + * @param data Sequences of itemsets. + * @param minCount The minimal number of sequences an item must appear in to be considered frequent + * + * @return An array of Item containing only frequent items. + */ + private[fpm] def findFrequentItems[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + minCount: Long): Array[Item] = { + + data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach(set => uniqItems ++= set) + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _).filter { case (_, count) => + count >= minCount + }.sortBy(-_._2).map(_._1).collect() + } + + /** + * This method removes infrequent items from the input dataset and translates the remaining + * items to their corresponding Int identifiers. + * + * @param data Sequences of itemsets. + * @param itemToInt A map allowing translation of frequent Items to their Int identifiers. + * The map should only contain frequent items. + * + * @return The internal representation of the input dataset, with properly placed zero delimiters. + */ + private[fpm] def toDatabaseInternalRepr[Item: ClassTag]( + data: RDD[Array[Array[Item]]], + itemToInt: Map[Item, Int]): RDD[Array[Int]] = { + + data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + allItems += 0 + } + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + } + } + /** * Find the complete set of frequent sequential patterns in the input sequences. * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 76b1bc13b4b05..14288221b6945 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -301,7 +301,7 @@ object ALS { * level of parallelism. 
* * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -326,7 +326,7 @@ object ALS { * level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -349,7 +349,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter */ @@ -366,7 +366,7 @@ object ALS { * parallelism automatically based on the number of partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.0") @@ -383,7 +383,7 @@ object ALS { * a level of parallelism given by `blocks`. * * @param ratings RDD of (userID, productID, rating) pairs - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -410,7 +410,7 @@ object ALS { * iteratively with a configurable level of parallelism. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param blocks level of parallelism to split computation into @@ -436,7 +436,7 @@ object ALS { * partitions in `ratings`. * * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS * @param lambda regularization parameter * @param alpha confidence parameter @@ -455,7 +455,7 @@ object ALS { * partitions in `ratings`. 
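// [Editor's sketch, not part of the patch] The `rank` documented above is the number of latent
// factors; a minimal mllib ALS.train call, assuming a SparkContext `sc` is available.
import org.apache.spark.mllib.recommendation.{ALS, Rating}

val ratings = sc.parallelize(Seq(
  Rating(1, 10, 4.0), Rating(1, 20, 1.0), Rating(2, 10, 5.0), Rating(2, 30, 3.0)))
val alsModel = ALS.train(ratings, rank = 2, iterations = 5, lambda = 0.01)
// alsModel.predict(1, 30) returns the predicted rating of product 30 by user 1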
* * @param ratings RDD of [[Rating]] objects with userID, productID, and rating - * @param rank number of features to use + * @param rank number of features to use (also referred to as the number of latent factors) * @param iterations number of iterations of ALS */ @Since("0.8.1") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 9a63b8a5d63db..ee51248e53556 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD * * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test */ -private[stat] object ChiSqTest extends Logging { +private[spark] object ChiSqTest extends Logging { /** * @param name String name for the method. @@ -70,6 +70,11 @@ private[stat] object ChiSqTest extends Logging { } } + /** + * Max number of categories when indexing labels and features + */ + private[spark] val maxCategories: Int = 10000 + /** * Conduct Pearson's independence test for each feature against the label across the input RDD. * The contingency table is constructed from the raw (feature, label) pairs and used to conduct @@ -78,7 +83,6 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { - val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index a5bdc2c6d2c94..4c7746869dde1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.impurity +import java.util.Locale + import org.apache.spark.annotation.{DeveloperApi, Since} /** @@ -184,7 +186,7 @@ private[spark] object ImpurityCalculator { * the given stats. */ def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { - impurity match { + impurity.toLowerCase(Locale.ROOT) match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a1562384b0a7e..27618e122aefd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -248,7 +248,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Build node data into a tree. val trees = constructTrees(nodes) assert(trees.length == 1, - "Decision tree should contain exactly one tree but got ${trees.size} trees.") + s"Decision tree should contain exactly one tree but got ${trees.size} trees.") val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $dataPath." 
+ s" Expected $numNodes nodes but found ${model.numNodes}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 95f904dac552c..4fdad05973969 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -119,7 +119,7 @@ object MLUtils extends Logging { while (i < indicesLength) { val current = indices(i) require(current > previous, s"indices should be one-based and in ascending order;" - + " found current=$current, previous=$previous; line=\"$line\"") + + s""" found current=$current, previous=$previous; line="$line"""") previous = current i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index dafc6c200f95f..4a7e4dd80f246 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -79,7 +79,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setStages(Array(estimator0, transformer1, estimator2, transformer3)) val pipelineModel = pipeline.fit(dataset0) - MLTestingUtils.checkCopy(pipelineModel) + MLTestingUtils.checkCopyAndUids(pipeline, pipelineModel) assert(pipelineModel.stages.length === 4) assert(pipelineModel.stages(0).eq(model0)) @@ -230,7 +230,9 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with `MLWritable` stages + */ class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -257,7 +259,9 @@ object WritableStage extends MLReadable[WritableStage] { override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ +/** + * Used to test [[Pipeline]] with non-`MLWritable` stages + */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c711e7fa9dc67..918ab27e2730b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -249,8 +249,7 @@ class DecisionTreeClassifierSuite val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // copied model must have the same parent. 
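// [Editor's note, not part of the patch] The MLUtils, DecisionTreeModel and
// PowerIterationClustering hunks above all fix the same bug: a message written with $placeholders
// but missing the `s` interpolator prefix, so the placeholders were emitted literally.
val current = 3
val previous = 5
val broken = "found current=$current, previous=$previous"  // -> "found current=$current, previous=$previous"
val fixed = s"found current=$current, previous=$previous"  // -> "found current=3, previous=5"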
- MLTestingUtils.checkCopy(newTree) + MLTestingUtils.checkCopyAndUids(dt, newTree) val predictions = newTree.transform(newData) .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) @@ -372,16 +371,32 @@ class DecisionTreeClassifierSuite // Categorical splits with tree depth 2 val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) - testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), - checkModelData) + allParamSettings ++ Map("maxDepth" -> 0), checkModelData) + } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = TreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0598943c3d4be..1f79e0d4e6228 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -97,8 +97,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.getProbabilityCol === "probability") assert(model.hasParent) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) } test("setThreshold, getThreshold") { @@ -261,8 +260,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext .setSeed(123) val model = gbt.fit(df) - // copied model must have the same parent. 
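// The SPARK-20043 test above exercises the Locale.ROOT change made in
// ImpurityCalculator.getCalculator earlier in this patch: the match arms are all lowercase, so an
// impurity name stored as "Gini" previously failed to rebuild a calculator at model load time,
// and a bare toLowerCase() would still be locale sensitive for names containing an uppercase I.
// A small sketch of that pitfall, using only JDK APIs (illustration only, not part of the patch):
import java.util.Locale

val turkishLower = "GINI".toLowerCase(new Locale("tr"))  // "gını": Turkish lowercases I to dotless ı, never matching "gini"
val rootLower    = "GINI".toLowerCase(Locale.ROOT)       // "gini": locale independent, matches the lowercase case patterns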
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) sc.checkpointDir = None Utils.deleteRecursively(tempDir) @@ -374,7 +372,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a165d8a9345cf..2f87afc23fe7e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.udf class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau @transient var smallValidationDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ + @transient var smallSparseBinaryDataset: Dataset[_] = _ + @transient var smallSparseValidationDataset: Dataset[_] = _ + override def beforeAll(): Unit = { super.beforeAll() @@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + + // Dataset for testing SparseVector + val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse + val sparse = udf(toSparse) + smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features)) + smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features)) + } /** @@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("Linear SVC binary classification with regularization") { @@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.setRegParam(0.1).fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("params") { @@ -109,8 +124,7 @@ class LinearSVCSuite extends SparkFunSuite 
with MLlibTestSparkContext with Defau assert(model.hasParent) assert(model.numFeatures === 2) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lsvc, model) } test("linear svc doesn't fit intercept when fitIntercept is off") { @@ -149,7 +163,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC]( - dataset.as[LabeledPoint], estimator, 2, modelEquals) + dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals, 42L) } @@ -217,7 +231,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } val svm = new LinearSVC() testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings, - checkModelData) + LinearSVCSuite.allParamSettings, checkModelData) } } @@ -235,7 +249,7 @@ object LinearSVCSuite { "aggregationDepth" -> 3 ) - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) def generateSVMInput( intercept: Double, weights: Array[Double], @@ -252,5 +266,10 @@ object LinearSVCSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { + assert(model1.intercept == model2.intercept) + assert(model1.coefficients.equals(model2.coefficients)) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index d89a958eed45a..bf6bfe30bfe20 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -142,8 +142,7 @@ class LogisticRegressionSuite assert(model.intercept !== 0.0) assert(model.hasParent) - // copied model must have the same parent. 
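// Throughout these suites the old checkCopy(model) call becomes checkCopyAndUids(estimator, model).
// Only that call signature is visible in this patch, so the helper below is a hypothetical sketch
// of the invariants such a utility plausibly verifies; the real MLTestingUtils implementation may
// check more or differ in detail.
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap

// Hypothetical stand-in for MLTestingUtils.checkCopyAndUids; not the actual Spark helper.
def checkCopyAndUidsSketch[M <: Model[M]](estimator: Estimator[M], model: M): Unit = {
  // The fitted model should report the uid of the estimator that produced it.
  assert(estimator.uid == model.uid, "model uid should match the estimator uid")
  // Copying must preserve the parent (what the old checkCopy asserted) and the uid.
  val copied = model.copy(ParamMap.empty)
  assert(copied.parent == model.parent, "copied model must have the same parent")
  assert(copied.uid == model.uid, "copied model must keep the same uid")
}
// The substitution that follows is one instance of this suite-wide change.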
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -151,6 +150,54 @@ class LogisticRegressionSuite assert(!model.hasSummary) } + test("logistic regression: illegal params") { + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnCoefficients1 = Matrices.dense(1, 4, Array(0.0, 1.0, 1.0, 0.0)) + val upperBoundsOnCoefficients2 = Matrices.dense(1, 3, Array(1.0, 0.0, 1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(1.0) + + // Work well when only set bound in one side. + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .fit(binaryDataset) + + withClue("bound constrained optimization only supports L2 regularization") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setElasticNetParam(1.0) + .fit(binaryDataset) + } + } + + withClue("lowerBoundsOnCoefficients should less than or equal to upperBoundsOnCoefficients") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients1) + .fit(binaryDataset) + } + } + + withClue("the coefficients bound matrix mismatched with shape (1, number of features)") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients2) + .fit(binaryDataset) + } + } + + withClue("bounds on intercepts should not be set if fitting without intercept") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(false) + .fit(binaryDataset) + } + } + } + test("empty probabilityCol") { val lr = new LogisticRegression().setProbabilityCol("") val model = lr.fit(smallBinaryDataset) @@ -611,6 +658,107 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = Vectors.dense(0.06079437, 0.0, -0.26351059, -0.59102199) + val interceptExpected1 = 1.0 + + assert(model1.intercept ~== interceptExpected1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. 
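// A note on the new box-constraint API exercised in this test: for the binomial family the
// coefficient bounds form a (1 x numFeatures) matrix, here (1, 4) for binaryDataset, and the
// intercept bounds a length-1 vector. Per the "illegal params" test above, bound constrained
// optimization only supports L2 regularization, and every lower bound must not exceed its upper
// bound. A minimal sketch (illustration only, not part of the patch):
val boundedLr = new LogisticRegression()
  .setFitIntercept(true)
  .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(-1.0)))
  .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(1.0)))
  .setLowerBoundsOnIntercepts(Vectors.dense(0.0))
// boundedLr.fit(binaryDataset) would keep every coefficient in [-1, 1] and the intercept at or above 0.
// The assertions below verify the standardization-independence claimed in the preceding comment.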
+ assert(model2.intercept ~== interceptExpected1 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Bound constrained optimization with bound on both side. + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(0.0, -1.0, 0.0, -1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(0.0) + + val trainer3 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(binaryDataset) + val model4 = trainer4.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = Vectors.dense(0.0, 0.0, 0.0, -0.71708632) + val interceptExpected3 = 0.58776113 + + assert(model3.intercept ~== interceptExpected3 relTol 1E-3) + assert(model3.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model4.intercept ~== interceptExpected3 relTol 1E-3) + assert(model4.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(binaryDataset) + val model6 = trainer6.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptExpected5 = 2.7355261 + + assert(model5.intercept ~== interceptExpected5 relTol 1E-3) + assert(model5.coefficients ~= coefficientsExpected5 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. 
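// The expected coefficients in these bound tests are checked with relative tolerance (relTol 1E-3)
// rather than exact equality, because the reference solutions are only expected to agree to a few
// significant figures. A rough sketch of what a relative-tolerance comparison means (one common
// definition; the TestingUtils operators ~= and ~== may differ slightly in detail):
def withinRelTol(actual: Double, expected: Double, relTol: Double): Boolean =
  math.abs(actual - expected) <= relTol * math.max(math.abs(actual), math.abs(expected))

val closeEnough = withinRelTol(2.7351, 2.7355261, relTol = 1e-3)  // true: agreement to roughly 3-4 significant figures
// The assertions below apply the same kind of check to model6, which differs from model5 only in standardization.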
+ assert(model6.intercept ~== interceptExpected5 relTol 1E-3) + assert(model6.coefficients ~= coefficientsExpected5 relTol 1E-3) + } + test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) .setWeightCol("weight") @@ -651,6 +799,34 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept without regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)).toSparse + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = Vectors.dense(0.20847553, 0.0, -0.24240289, -0.55568071) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") @@ -713,8 +889,6 @@ class LogisticRegressionSuite assert(model2.intercept ~== interceptR relTol 1E-2) assert(model2.coefficients ~== coefficientsR absTol 1E-3) - // TODO: move this to a standalone test of compression after SPARK-17471 - assert(model2.coefficients.isInstanceOf[SparseVector]) } test("binary logistic regression without intercept with L1 regularization") { @@ -818,6 +992,40 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. 
+ val coefficientsExpectedWithStd = Vectors.dense(-0.06985003, 0.0, -0.04794278, -0.10168595) + val interceptExpectedWithStd = 0.45750141 + val coefficientsExpected = Vectors.dense(-0.0494524, 0.0, -0.11360797, -0.06313577) + val interceptExpected = 0.53722967 + + assert(model1.intercept ~== interceptExpectedWithStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== interceptExpected relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") @@ -867,6 +1075,35 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.00796538, 0.0, -0.0394228, -0.0873314) + val coefficientsExpected = Vectors.dense(0.01105972, 0.0, -0.08574949, -0.05079558) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") @@ -1087,7 +1324,6 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept without regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) @@ -1142,6 +1378,9 @@ class LogisticRegressionSuite 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector ~== interceptsR relTol 0.05) @@ -1152,6 +1391,110 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. 
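// For the multinomial family the bound shapes generalize in the obvious way: the coefficient
// bounds form a (numClasses x numFeatures) matrix, here (3, 4) for multinomialDataset, and the
// intercept bounds a vector of length numClasses. A minimal sketch assuming a 3-class, 4-feature
// training frame (illustration only, not part of the patch):
val mlrBounded = new LogisticRegression()
  .setFamily("multinomial")
  .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(0.0)))  // one row per class, one column per feature
  .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(0.0)))            // one entry per class
// The test below sets the same kind of one-sided (and later two-sided and infinite) bounds before fitting.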
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = new DenseMatrix(3, 4, Array( + 2.52076464, 2.73596057, 1.87984904, 2.73264492, + 1.93302281, 3.71363303, 1.50681746, 1.93398782, + 2.37839917, 1.93601818, 1.81924758, 2.45191255), isTransposed = true) + val interceptsExpected1 = Vectors.dense(1.00010477, 3.44237083, 4.86740286) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected1) + assert(model1.interceptVector ~== interceptsExpected1 relTol 0.01) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected1) + assert(model2.interceptVector ~== interceptsExpected1 relTol 0.01) + + // Bound constrained optimization with bound on both side. + val upperBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(2.0)) + val upperBoundsOnIntercepts = Vectors.dense(Array.fill(3)(2.0)) + + val trainer3 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(multinomialDataset) + val model4 = trainer4.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = new DenseMatrix(3, 4, Array( + 1.61967097, 1.16027835, 1.45131448, 1.97390431, + 1.30529317, 2.0, 1.12985473, 1.26652854, + 1.61647195, 1.0, 1.40642959, 1.72985589), isTransposed = true) + val interceptsExpected3 = Vectors.dense(1.0, 2.0, 2.0) + + checkCoefficientsEquivalent(model3.coefficientMatrix, coefficientsExpected3) + assert(model3.interceptVector ~== interceptsExpected3 relTol 0.01) + checkCoefficientsEquivalent(model4.coefficientMatrix, coefficientsExpected3) + assert(model4.interceptVector ~== interceptsExpected3 relTol 0.01) + + // Bound constrained optimization with infinite bound on both side. 
+ val trainer5 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(multinomialDataset) + val model6 = trainer6.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = new DenseMatrix(3, 4, Array( + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsExpected5 = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + + checkCoefficientsEquivalent(model5.coefficientMatrix, coefficientsExpected5) + assert(model5.interceptVector ~== interceptsExpected5 relTol 0.01) + checkCoefficientsEquivalent(model6.coefficientMatrix, coefficientsExpected5) + assert(model6.interceptVector ~== interceptsExpected5 relTol 0.01) + } + test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) @@ -1207,6 +1550,9 @@ class LogisticRegressionSuite -0.3180040, 0.9679074, -0.2252219, -0.4319914, 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) @@ -1217,6 +1563,35 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept without regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. 
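// These multinomial fits are compared with checkCoefficientsEquivalent (defined in the companion
// object at the bottom of this file) rather than directly, because without regularization the
// coefficients are only identified up to a per-feature additive constant: adding the same constant
// to an entire column of the coefficient matrix shifts every class margin by the same amount,
// leaving the softmax probabilities, and hence the likelihood, unchanged. A small numeric sketch
// of that invariance (illustration only, not part of the patch):
def softmax(margins: Array[Double]): Array[Double] = {
  val exps = margins.map(math.exp)
  val total = exps.sum
  exps.map(_ / total)
}
val margins = Array(0.2, -1.3, 0.7)
val shifted = margins.map(_ + 5.0)  // the same constant added to every class margin
// softmax(margins) and softmax(shifted) agree to floating-point precision, so only column-wise
// differences between two solutions are meaningful to compare.
// The expected matrix below comes from the same external reference solution noted above.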
+ val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.62410051, 1.38219391, 1.34486618, 1.74641729, + 1.23058989, 2.71787825, 1.0, 1.00007073, + 1.79478632, 1.14360459, 1.33011603, 1.55093897), isTransposed = true) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with L1 regularization") { // use tighter constraints because OWL-QN solver takes longer to converge @@ -1515,6 +1890,46 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.0, 1.01647497, + 1.0, 1.44105616, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpectedWithStd = Vectors.dense(2.52055893, 1.0, 2.560682) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03189386, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpected = Vectors.dense(1.06418835, 1.0, 1.20494701) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd relTol 0.01) + assert(model1.interceptVector ~== interceptsExpectedWithStd relTol 0.01) + assert(model2.coefficientMatrix ~== coefficientsExpected relTol 0.01) + assert(model2.interceptVector ~== interceptsExpected relTol 0.01) + } + test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") @@ -1612,6 +2027,41 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated 
by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.01324653, 1.0, 1.0, 1.0415767, + 1.0, 1.0, 1.0, 1.0, + 1.02244888, 1.0, 1.0, 1.0), isTransposed = true) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03932259, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.03274649, 1.0), isTransposed = true) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd absTol 0.01) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + assert(model2.coefficientMatrix ~== coefficientsExpected absTol 0.01) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) @@ -1876,7 +2326,7 @@ class LogisticRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression]( - dataset.as[LabeledPoint], estimator, numClasses, modelEquals) + dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression]( dataset.as[LabeledPoint], estimator, modelEquals, seed) } @@ -2031,29 +2481,61 @@ class LogisticRegressionSuite // TODO: check num iters is zero when it become available in the model } - test("compressed storage") { + test("compressed storage for constant label") { + /* + When the label is constant and fit intercept is true, all the coefficients will be + zeros, and so the model coefficients should be stored as sparse data structures, except + when the matrix dimensions are very small. 
+ */ val moreClassesThanFeatures = Seq( - LabeledPoint(4.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(4.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(4.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() - val mlr = new LogisticRegression().setFamily("multinomial") + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(4.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() + val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true) val model = mlr.fit(moreClassesThanFeatures) assert(model.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 4) + assert(model.coefficientMatrix.isColMajor) + + // in this case, it should be stored as row major val moreFeaturesThanClasses = Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 2.0))).toDF() + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(0.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(1.0))), + LabeledPoint(1.0, Vectors.dense(Array.fill(5)(2.0)))).toDF() val model2 = mlr.fit(moreFeaturesThanClasses) assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) - assert(model2.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 3) + assert(model2.coefficientMatrix.isRowMajor) - val blr = new LogisticRegression().setFamily("binomial") + val blr = new LogisticRegression().setFamily("binomial").setFitIntercept(true) val blrModel = blr.fit(moreFeaturesThanClasses) assert(blrModel.coefficientMatrix.isInstanceOf[SparseMatrix]) assert(blrModel.coefficientMatrix.asInstanceOf[SparseMatrix].colPtrs.length === 2) } + test("compressed coefficients") { + + val trainer1 = new LogisticRegression() + .setRegParam(0.1) + .setElasticNetParam(1.0) + + // compressed row major is optimal + val model1 = trainer1.fit(multinomialDataset.limit(100)) + assert(model1.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model1.coefficientMatrix.isRowMajor) + + // compressed column major is optimal since there are more classes than features + val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata() + val model2 = trainer1.fit(multinomialDataset + .withColumn("label", col("label").as("label", labelMeta)).limit(100)) + assert(model2.coefficientMatrix.isInstanceOf[SparseMatrix]) + assert(model2.coefficientMatrix.isColMajor) + + // coefficients are dense without L1 regularization + val trainer2 = new LogisticRegression() + .setElasticNetParam(0.0) + val model3 = trainer2.fit(multinomialDataset.limit(100)) + assert(model3.coefficientMatrix.isInstanceOf[DenseMatrix]) + } + test("numClasses specified in metadata/inferred") { val lr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") @@ -2089,7 +2571,7 @@ class LogisticRegressionSuite } val lr = new LogisticRegression() testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings, - checkModelData) + LogisticRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { @@ -2238,4 +2720,19 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** + * When no regularization is applied, the multinomial coefficients lack identifiability + * because we do not use a pivot class. 
We can add any constant value to the coefficients + * and get the same likelihood. If fitting under bound constrained optimization, we don't + * choose the mean centered coefficients like what we do for unbound problems, since they + * may out of the bounds. We use this function to check whether two coefficients are equivalent. + */ + def checkCoefficientsEquivalent(coefficients1: Matrix, coefficients2: Matrix): Unit = { + coefficients1.colIter.zip(coefficients2.colIter).foreach { case (col1: Vector, col2: Vector) => + (col1.asBreeze - col2.asBreeze).toArray.toSeq.sliding(2).foreach { + case Seq(v1, v2) => assert(v1 ~= v2 absTol 1E-3) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index dfe2e3ee66593..94d71fd532332 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,6 +76,7 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) val result = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(trainer, model) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => assert(p == l) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 37d7991fe8dd8..3a2be236f1257 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -149,6 +149,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) + MLTestingUtils.checkCopyAndUids(nb, model) val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() @@ -167,7 +168,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(m1.pi ~== m2.pi relTol 0.01) assert(m1.theta ~== m2.theta relTol 0.01) } - val testParams = Seq( + val testParams = Seq[(String, Dataset[_])]( ("bernoulli", bernoulliDataset), ("multinomial", dataset) ) @@ -178,7 +179,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes]( - dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals) + dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes]( dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed) } @@ -280,7 +281,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(model.theta === model2.theta) } val nb = new NaiveBayes() - testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, + NaiveBayesSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, 
and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index f69351a6c76dd..6bfc1646c511a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -76,8 +76,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) - // copied model must have the same parent. - MLTestingUtils.checkCopy(ovaModel) + MLTestingUtils.checkCopyAndUids(ova, ovaModel) assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 44e1585ee514b..ca2954d2f32c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -141,8 +141,7 @@ class RandomForestClassifierSuite val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(rf, model) val predictions = model.transform(df) .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) @@ -218,7 +217,8 @@ class RandomForestClassifierSuite val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 30513c1e276ae..fa7471fa2d658 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -47,8 +47,7 @@ class BisectingKMeansSuite assert(bkm.getMinDivisibleClusterSize === 1.0) val model = bkm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(bkm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -138,8 +137,8 @@ class BisectingKMeansSuite assert(model.clusterCenters === model2.clusterCenters) } val bisectingKMeans = new BisectingKMeans() - testEstimatorAndModelReadWrite( - bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, + BisectingKMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index c500c5b3e365a..08b800b7e4183 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -77,8 
+77,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getTol === 0.01) val model = gm.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -163,7 +162,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov)) } val gm = new GaussianMixture() - testEstimatorAndModelReadWrite(gm, dataset, + testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings, GaussianMixtureSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e10127f7d108f..119fe1dead9a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -52,8 +52,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getTol === 1e-4) val model = kmeans.setMaxIter(1).fit(dataset) - // copied model must have the same parent - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(kmeans, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -150,7 +149,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.clusterCenters === model2.clusterCenters) } val kmeans = new KMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, + KMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 9aa11fbdbe868..b4fe63a89f871 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -176,7 +176,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) val model = lda.fit(dataset) - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lda, model) assert(model.isInstanceOf[LocalLDAModel]) assert(model.vocabSize === vocabSize) @@ -221,7 +221,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) val model_ = lda.fit(dataset) - MLTestingUtils.checkCopy(model_) + MLTestingUtils.checkCopyAndUids(lda, model_) assert(model_.isInstanceOf[DistributedLDAModel]) val model = model_.asInstanceOf[DistributedLDAModel] @@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(model2.getDocConcentration) absTol 1e-6) } val lda = new LDA() - testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, + LDASuite.allParamSettings, checkModelData) } test("read/write DistributedLDAModel") { @@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } val 
lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ab937685a555c..7175c721bff36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite } val mh = new BucketedRandomProjectionLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0) - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } test("hashFunction") { @@ -89,10 +89,13 @@ class BucketedRandomProjectionLSHSuite .setOutputCol("values") .setBucketLength(1.0) .setSeed(12345) - val unitVectors = brp.fit(dataset).randUnitVectors + val brpModel = brp.fit(dataset) + val unitVectors = brpModel.randUnitVectors unitVectors.foreach { v: Vector => assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14) } + + MLTestingUtils.checkCopyAndUids(brp, brpModel) } test("BucketedRandomProjectionLSH: test of LSH property") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 482e5d54260d4..c83909c4498f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -119,7 +119,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - ChiSqSelectorSuite.testSelector(selector, dataset) + val model = ChiSqSelectorSuite.testSelector(selector, dataset) + MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { @@ -151,7 +152,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.selectedFeatures === model2.selectedFeatures) } val nb = new ChiSqSelector - testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, + ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -165,11 +167,13 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext object ChiSqSelectorSuite { - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { - selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + private def 
testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(dataset) + selectorModel.transform(dataset).select("filtered", "topFeature").collect() .foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) } + selectorModel } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 69d3033bb2189..f213145f1ba0a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -68,10 +68,11 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") - .fit(df) - assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) + val cvm = cv.fit(df) + MLTestingUtils.checkCopyAndUids(cv, cvm) + assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cv.transform(df).select("features", "expected").collect().foreach { + cvm.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 5325d95526a50..005edf73d29be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -65,10 +65,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val df = data.zip(expected).toSeq.toDF("features", "expected") - val idfModel = new IDF() + val idfEst = new IDF() .setInputCol("features") .setOutputCol("idfValue") - .fit(df) + val idfModel = idfEst.fit(df) + + MLTestingUtils.checkCopyAndUids(idfEst, idfModel) idfModel.transform(df).select("idfValue", "expected").collect().foreach { case Row(x: Vector, y: Vector) => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala new file mode 100644 index 0000000000000..ee2ba73fa96d5 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} + +class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("Imputer for Double with default missing Value NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 4.0, 1.0, 1.0, 4.0, 4.0), + (1, 11.0, 12.0, 11.0, 11.0, 12.0, 12.0), + (2, 3.0, Double.NaN, 3.0, 3.0, 10.0, 12.0), + (3, Double.NaN, 14.0, 5.0, 3.0, 14.0, 14.0) + )).toDF("id", "value1", "value2", "expected_mean_value1", "expected_median_value1", + "expected_mean_value2", "expected_median_value2") + val imputer = new Imputer() + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out2")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 3.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1.0) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer for Float with missing Value -1.0") { + val df = spark.createDataFrame( Seq( + (0, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setMissingValue(-1) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer should impute null as well as 'missingValue'") { + val rawDf = spark.createDataFrame( Seq( + (0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0), + (4, -1.0, 8.0, 10.0) + )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") + val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + ImputerSuite.iterateStrategyTest(imputer, df) + } + + test("Imputer throws exception when surrogate cannot be computed") { + val df = spark.createDataFrame( Seq( + (0, Double.NaN, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + Seq("mean", "median").foreach { strategy => + val imputer = new 
Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) + .setStrategy(strategy) + withClue("Imputer should fail all the values are invalid") { + val e: SparkException = intercept[SparkException] { + val model = imputer.fit(df) + } + assert(e.getMessage.contains("surrogate cannot be computed")) + } + } + } + + test("Imputer input & output column validation") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value1", "value2", "value3") + Seq("mean", "median").foreach { strategy => + withClue("Imputer should fail if inputCols and outputCols are different length") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("should have the same length")) + } + + withClue("Imputer should fail if inputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value1")) + .setOutputCols(Array("out1", "out2")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("inputCols contains duplicates")) + } + + withClue("Imputer should fail if outputCols contains duplicates") { + val e: IllegalArgumentException = intercept[IllegalArgumentException] { + val imputer = new Imputer().setStrategy(strategy) + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("out1", "out1")) + val model = imputer.fit(df) + } + assert(e.getMessage.contains("outputCols contains duplicates")) + } + } + } + + test("Imputer read/write") { + val t = new Imputer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setMissingValue(-1.0) + testDefaultReadWrite(t) + } + + test("ImputerModel read/write") { + val spark = this.spark + import spark.implicits._ + val surrogateDF = Seq(1.234).toDF("myInputCol") + + val instance = new ImputerModel( + "myImputer", surrogateDF) + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns) + assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) + } + +} + +object ImputerSuite { + + /** + * Imputation strategy. Available options are ["mean", "median"]. + * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" + */ + def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { + val inputCols = imputer.getInputCols + + Seq("mean", "median").foreach { strategy => + imputer.setStrategy(strategy) + val model = imputer.fit(df) + val resultDF = model.transform(df) + imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { + case Row(exp: Float, out: Float) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), + s"Imputed values differ. 
Expected: $exp, actual: $out") + } + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index a9b559f7ba648..db4f56ed60d32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.linalg.{Vector, VectorUDT} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils} import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DataTypes @@ -29,8 +29,10 @@ private[ml] object LSHTest { * the following property is satisfied. * * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2, - * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1 - * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2 + * If dist(e1, e2) is less than or equal to dist1, then Pr{h(x) == h(y)} is greater than + * or equal to p1 + * If dist(e1, e2) is greater than or equal to dist2, then Pr{h(x) == h(y)} is less than + * or equal to p2 * * This is called locality sensitive property. This method checks the property on an * existing dataset and calculate the probabilities. @@ -38,8 +40,10 @@ private[ml] object LSHTest { * * This method hashes each elements to hash buckets using LSH, and calculate the false positive * and false negative: - * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP - * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN + * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) is greater + * than distFP + * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) is less + * than distFN * * @param dataset The dataset to verify the locality sensitive hashing property. * @param lsh The lsh instance to perform the hashing @@ -58,6 +62,8 @@ private[ml] object LSHTest { val outputCol = model.getOutputCol val transformedData = model.transform(dataset) + MLTestingUtils.checkCopyAndUids(lsh, model) + // Check output column type SchemaUtils.checkColumnType( transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index a12174493b867..918da4f9388d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -50,8 +50,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") } - // copied model must have the same parent. 
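For orientation alongside the new ImputerSuite above, the following is a minimal, illustrative sketch of the Imputer API that the suite exercises; the data, column names, and object name are invented for the example and are not part of the patch.

    import org.apache.spark.ml.feature.Imputer
    import org.apache.spark.sql.SparkSession

    object ImputerSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("ImputerSketch").getOrCreate()
        import spark.implicits._

        // Two numeric columns with missing entries encoded as NaN.
        val df = Seq(
          (0, 1.0, Double.NaN),
          (1, 3.0, 4.0),
          (2, Double.NaN, 6.0)
        ).toDF("id", "a", "b")

        val imputer = new Imputer()
          .setInputCols(Array("a", "b"))
          .setOutputCols(Array("a_out", "b_out"))
          .setStrategy("median")         // or "mean", the default
          .setMissingValue(Double.NaN)   // default; any sentinel such as -1.0 also works

        // fit() computes one surrogate per input column; transform() fills in the gaps.
        val model = imputer.fit(df)
        model.transform(df).show()

        spark.stop()
      }
    }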
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MaxAbsScaler read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3461cdf82460f..96df68dbdf053 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -54,7 +54,16 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } val mh = new MinHashLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values") - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) + } + + test("Model copy and uid checks") { + val mh = new MinHashLSH() + .setInputCol("keys") + .setOutputCol("values") + val model = mh.fit(dataset) + assert(mh.uid === model.uid) + MLTestingUtils.checkCopyAndUids(mh, model) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index b79eeb2d75ef0..51db74eb739ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -53,8 +53,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De assert(vector1.equals(vector2), "Transformed vector is different with expected.") } - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(scaler, model) } test("MinMaxScaler arguments max must be larger than min") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index a60e87590f060..3067a52a4df76 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -58,12 +58,12 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setInputCol("features") .setOutputCol("pca_features") .setK(3) - .fit(df) - // copied model must have the same parent. 
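A change repeated throughout these suites is replacing MLTestingUtils.checkCopy(model) with MLTestingUtils.checkCopyAndUids(estimator, model), which means keeping a handle on the estimator instead of chaining .fit(df) onto the builder. The sketch below shows the resulting test shape; it assumes Spark's own test sources are on the classpath (SparkFunSuite, MLlibTestSparkContext, MLTestingUtils), and the suite name and data are invented.

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.ml.feature.MinMaxScaler
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.util.MLTestingUtils
    import org.apache.spark.mllib.util.MLlibTestSparkContext

    class CopyAndUidsSketchSuite extends SparkFunSuite with MLlibTestSparkContext {

      test("copied model and uids (sketch)") {
        import testImplicits._
        val df = Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
          .map(Tuple1.apply).toDF("features")

        // Keep the estimator and the fitted model as separate values ...
        val scaler = new MinMaxScaler()
          .setInputCol("features")
          .setOutputCol("scaled")
        val model = scaler.fit(df)

        // ... so both can be handed to the helper, which asserts that the model's uid
        // matches the estimator's and that copy(ParamMap.empty) preserves the same parent.
        MLTestingUtils.checkCopyAndUids(scaler, model)
      }
    }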
- MLTestingUtils.checkCopy(pca) + val pcaModel = pca.fit(df) - pca.transform(df).select("pca_features", "expected").collect().foreach { + MLTestingUtils.checkCopyAndUids(pca, pcaModel) + + pcaModel.transform(df).select("pca_features", "expected").collect().foreach { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index c664460d7d8bb..fbebd75d70ac5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val formula = new RFormula().setFormula("id ~ v1 + v2") val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) + MLTestingUtils.checkCopyAndUids(formula, model) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) val expected = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index a928f93633011..350ba44baa1eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -77,10 +77,11 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("Standardization with default parameter") { val df0 = data.zip(resWithStd).toSeq.toDF("features", "expected") - val standardScaler0 = new StandardScaler() + val standardScalerEst0 = new StandardScaler() .setInputCol("features") .setOutputCol("standardized_features") - .fit(df0) + val standardScaler0 = standardScalerEst0.fit(df0) + MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) assertResult(standardScaler0.transform(df0)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d0e63c9d669c..5634d4210f478 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -45,12 +45,11 @@ class StringIndexerSuite val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") - .fit(df) + val indexerModel = indexer.fit(df) - // copied model must have the same parent. 
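The StringIndexerSuite changes that follow exercise the handleInvalid options: "error" (throw on unseen or null labels), "skip" (drop those rows), and the newly added "keep" (index them into an extra unknown bucket). A hedged usage sketch, with invented data and names:

    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.sql.SparkSession

    object StringIndexerSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("StringIndexerSketch").getOrCreate()
        import spark.implicits._

        val train = Seq((0, "a"), (1, "b"), (2, "b")).toDF("id", "label")
        val test  = Seq((0, "a"), (1, "b"), (2, "c")).toDF("id", "label")  // "c" is unseen

        val model = new StringIndexer()
          .setInputCol("label")
          .setOutputCol("labelIndex")
          .setHandleInvalid("keep")   // "error" throws on unseen labels, "skip" drops those rows
          .fit(train)

        // With "keep", the unseen label "c" is indexed into the extra "__unknown" bucket.
        model.transform(test).show()

        spark.stop()
      }
    }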
- MLTestingUtils.checkCopy(indexer) + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - val transformed = indexer.transform(df) + val transformed = indexerModel.transform(df) val attr = Attribute.fromStructField(transformed.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attr.values.get === Array("a", "c", "b")) @@ -64,7 +63,7 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) - val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) val df = data.toDF("id", "label") val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() @@ -75,22 +74,32 @@ class StringIndexerSuite intercept[SparkException] { indexer.transform(df2).collect() } - val indexerSkipInvalid = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .setHandleInvalid("skip") - .fit(df) + + indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformed = indexerSkipInvalid.transform(df2) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedSkip = indexer.transform(df2) + val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 - val expected = Set((0, 1.0), (1, 0.0)) - assert(output === expected) + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + // Verify that we keep the unseen records + val transformedKeep = indexer.transform(df2) + val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 + val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexer with a numeric input column") { @@ -112,6 +121,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val 
transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index b28ce2ab45b45..f2cca8aa82e85 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -114,8 +114,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer val model = vectorIndexer.fit(densePoints1) // vectors of length 3 - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(vectorIndexer, model) model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 613cc3d60b227..a6a1c2b4f32bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -57,15 +57,14 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val docDF = doc.zip(expected).toDF("text", "expected") - val model = new Word2Vec() + val w2v = new Word2Vec() .setVectorSize(3) .setInputCol("text") .setOutputCol("result") .setSeed(42L) - .fit(docDF) + val model = w2v.fit(docDF) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(w2v, model) // These expectations are just magic values, characterizing the current // behavior. 
The test needs to be updated to be more general, see SPARK-11502 @@ -133,14 +132,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { + val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078)) + val findSynonymsResult = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) - }.collect().unzip + }.collectAsMap() + + expected.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5) + } - assert(synonyms === Array("b", "c")) - expectedSimilarity.zip(similarity).foreach { - case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap + findSynonymsResult.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsArrayResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 74c7461401905..87f8b9034dde8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -34,7 +35,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("FPGrowth fit and transform with different data types") { Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt => - val data = dataset.withColumn("features", col("features").cast(ArrayType(dt))) + val data = dataset.withColumn("items", col("items").cast(ArrayType(dt))) val model = new FPGrowth().setMinSupport(0.5).fit(data) val generatedRules = model.setMinConfidence(0.5).associationRules val expectedRules = spark.createDataFrame(Seq( @@ -52,8 +53,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "2"), Array.emptyIntArray), (0, Array("1", "3"), Array(2)) - )).toDF("id", "features", "prediction") - .withColumn("features", col("features").cast(ArrayType(dt))) + )).toDF("id", "items", "prediction") + .withColumn("items", col("items").cast(ArrayType(dt))) .withColumn("prediction", col("prediction").cast(ArrayType(dt))) assert(expectedTransformed.collect().toSet.equals( transformed.collect().toSet)) @@ -79,30 +80,68 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, Array("1", "2", "3", "5")), (2, Array("1", "2", "3", "4")), (3, null.asInstanceOf[Array[String]]) - )).toDF("id", "features") + )).toDF("id", "items") val model = new 
FPGrowth().setMinSupport(0.7).fit(dataset) val prediction = model.transform(df) assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("items") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } + + test("FPGrowthModel setMinConfidence should affect rules generation and transform") { + val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) + val oldRulesNum = model.associationRules.count() + val oldPredict = model.transform(dataset) + + model.setMinConfidence(0.8765) + assert(oldRulesNum > model.associationRules.count()) + assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + + // association rules should stay the same for same minConfidence + model.setMinConfidence(0.1) + assert(oldRulesNum === model.associationRules.count()) + assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + } + test("FPGrowth parameter check") { val fpGrowth = new FPGrowth().setMinSupport(0.4567) val model = fpGrowth.fit(dataset) .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + // numPartitions should not have default value. + assert(fpGrowth.isDefined(fpGrowth.numPartitions) === false) + MLTestingUtils.checkCopyAndUids(fpGrowth, model) + ParamsSuite.checkParams(fpGrowth) + ParamsSuite.checkParams(model) } test("read/write") { def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { - assert(model.freqItemsets.sort("items").collect() === - model2.freqItemsets.sort("items").collect()) + assert(model.freqItemsets.collect().toSet.equals( + model2.freqItemsets.collect().toSet)) + assert(model.associationRules.collect().toSet.equals( + model2.associationRules.collect().toSet)) + assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( + model2.setMinConfidence(0.9).associationRules.collect().toSet)) } val fPGrowth = new FPGrowth() - testEstimatorAndModelReadWrite( - fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, + FPGrowthSuite.allParamSettings, checkModelData) } - } object FPGrowthSuite { @@ -113,7 +152,7 @@ object FPGrowthSuite { (0, Array("1", "2")), (0, Array("1", "2")), (0, Array("1", "3")) - )).toDF("id", "features") + )).toDF("id", "items") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index aa9c53ca30eee..78a33e05e0e48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -377,7 +377,7 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: + * Checks common requirements for `Params.params`: * - params are ordered by names * - param parent has the same UID as the object's UID * - param name is the same as the param method name diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e494ea89e63bd..7574af3d77ea8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -409,8 +409,7 @@ class ALSSuite logInfo(s"Test RMSE is $rmse.") assert(rmse < targetRMSE) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(als, model) } test("exact rank-1 matrix") { @@ -518,37 +517,26 @@ class ALSSuite } test("read/write") { - import ALSSuite._ - val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - val als = new ALS() - allEstimatorParamSettings.foreach { case (p, v) => - als.set(als.getParam(p), v) - } val spark = this.spark import spark.implicits._ - val model = als.fit(ratings.toDF()) - - // Test Estimator save/load - val als2 = testDefaultReadWrite(als) - allEstimatorParamSettings.foreach { case (p, v) => - val param = als.getParam(p) - assert(als.get(param).get === als2.get(param).get) - } + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - // Test Model save/load - val model2 = testDefaultReadWrite(model) - allModelParamSettings.foreach { case (p, v) => - val param = model.getParam(p) - assert(model.get(param).get === model2.get(param).get) - } - assert(model.rank === model2.rank) def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { df.select("id", "features").collect().map { case r => (r.getInt(0), r.getAs[Array[Float]](1)) }.toSet } - assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) - assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + + def checkModelData(model: ALSModel, model2: ALSModel): Unit = { + assert(model.rank === model2.rank) + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } + + val als = new ALS() + testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings, + allModelParamSettings, checkModelData) } test("input type validation") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 3cd4b0ac308ef..fb39e50a83552 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -83,8 +83,7 @@ class AFTSurvivalRegressionSuite .setQuantilesCol("quantiles") .fit(datasetUnivariate) - // copied model must have the same parent. 
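For the FPGrowthSuite changes a little further up (the input column is now "items", matching the estimator's default, and setMinConfidence can be changed on a fitted model without re-fitting), here is a minimal, illustrative sketch of the spark.ml FPGrowth API with made-up transactions:

    import org.apache.spark.ml.fpm.FPGrowth
    import org.apache.spark.sql.SparkSession

    object FPGrowthSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("FPGrowthSketch").getOrCreate()
        import spark.implicits._

        // Each row is one transaction; "items" is the estimator's default input column name.
        val transactions = Seq(
          Array("1", "2", "5"),
          Array("1", "2", "3", "5"),
          Array("1", "2")
        ).map(Tuple1.apply).toDF("items")

        val model = new FPGrowth()
          .setMinSupport(0.5)
          .setMinConfidence(0.6)
          .fit(transactions)

        model.freqItemsets.show()        // frequent itemsets with their counts
        model.associationRules.show()    // antecedent => consequent rules above minConfidence

        // Rules (and therefore predictions) can be regenerated with a different threshold
        // without re-fitting, which is what the new test above checks.
        model.setMinConfidence(0.9)
        model.transform(transactions).show()

        spark.stop()
      }
    }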
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(aftr, model) model.transform(datasetUnivariate) .select("label", "prediction", "quantiles") @@ -419,7 +418,8 @@ class AFTSurvivalRegressionSuite } val aft = new AFTSurvivalRegression() testEstimatorAndModelReadWrite(aft, datasetMultivariate, - AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings, + checkModelData) } test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 15fa26e8b5272..642f266891b57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -69,11 +69,12 @@ class DecisionTreeRegressorSuite test("copied model must have the same parent") { val categoricalFeatures = Map(0 -> 2, 1 -> 2) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) - val model = new DecisionTreeRegressor() + val dtr = new DecisionTreeRegressor() .setImpurity("variance") .setMaxDepth(2) - .setMaxBins(8).fit(df) - MLTestingUtils.checkCopy(model) + .setMaxBins(8) + val model = dtr.fit(df) + MLTestingUtils.checkCopyAndUids(dtr, model) } test("predictVariance") { @@ -165,16 +166,17 @@ class DecisionTreeRegressorSuite val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) testEstimatorAndModelReadWrite(dt, categoricalData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) testEstimatorAndModelReadWrite(dt, continuousData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dcf3f9a1ea9b2..2da25f7e0100a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -90,8 +90,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext .setMaxIter(2) val model = gbt.fit(df) - // copied model must have the same parent. 
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(gbt, model) val preds = model.transform(df) val predictions = preds.select("prediction").rdd.map(_.getDouble(0)) // Checks based on SPARK-8736 (to ensure it is not doing classification) @@ -184,7 +183,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index add28a72b6808..f7c7c001a36af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -197,8 +197,7 @@ class GeneralizedLinearRegressionSuite val model = glr.setFamily("gaussian").setLink("identity") .fit(datasetGaussianIdentity) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(glr, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -1418,6 +1417,7 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression() testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 8cbb2acad243e..180f5f7ce5ab2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -93,8 +93,7 @@ class IsotonicRegressionSuite val model = ir.fit(dataset) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(ir, model) model.transform(dataset) .select("label", "features", "prediction", "weight") @@ -178,7 +177,7 @@ class IsotonicRegressionSuite val ir = new IsotonicRegression() testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, - checkModelData) + IsotonicRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 584a1b272f6c8..e7bd4eb9e0adf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -148,8 +148,7 @@ class LinearRegressionSuite assert(lir.getSolver == "auto") val model = lir.fit(datasetWithDenseFeature) - // copied model must have the same parent. 
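The weight-handling helpers exercised in the LinearRegressionSuite hunk just below (testArbitrarilyScaledWeights, and testOutliersWithSmallWeights with its new outlierRatio argument) all build on estimators honoring a weight column. As a quick, illustrative reminder of that API, with invented data and names:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.sql.SparkSession

    object WeightedRegressionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("WeightedLR").getOrCreate()
        import spark.implicits._

        // (label, weight, features): the last point is an outlier with a tiny weight,
        // which is the situation testOutliersWithSmallWeights sets up in bulk.
        val df = Seq(
          (1.0, 1.0, Vectors.dense(1.0)),
          (2.0, 1.0, Vectors.dense(2.0)),
          (3.0, 1.0, Vectors.dense(3.0)),
          (100.0, 1e-4, Vectors.dense(4.0))
        ).toDF("label", "weight", "features")

        val lr = new LinearRegression()
          .setWeightCol("weight")   // omit this to weight every instance equally
          .setMaxIter(50)

        val model = lr.fit(df)
        println(s"coefficients=${model.coefficients} intercept=${model.intercept}")

        spark.stop()
      }
    }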
- MLTestingUtils.checkCopy(model) + MLTestingUtils.checkCopyAndUids(lir, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -842,7 +841,8 @@ class LinearRegressionSuite MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( - datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals) + datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals, + outlierRatio = 3) MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } @@ -985,7 +985,7 @@ class LinearRegressionSuite } val lr = new LinearRegression() testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, - checkModelData) + LinearRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c08335f9f84af..8b8e8a655f47b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -90,6 +90,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val model = rf.fit(df) + MLTestingUtils.checkCopyAndUids(rf, model) + val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) @@ -124,7 +126,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala new file mode 100644 index 0000000000000..2d6aad0808bc6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareTestSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.stat + +import java.util.Random + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSquareTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("test DataFrame of labeled points") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val df = spark.createDataFrame(sc.parallelize(data, numParts)) + val chi = ChiSquareTest.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues ~== Vectors.dense(0.6873, 0.6823) relTol 1e-4) + assert(degreesOfFreedom === Array(2, 3)) + assert(statistics ~== Vectors.dense(0.75, 1.5) relTol 1e-4) + } + } + + test("large number of features (SPARK-3087)") { + // Test that the right number of results is returned + val numCols = 1001 + val sparseData = Array( + LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) + val df = spark.createDataFrame(sparseData) + val chi = ChiSquareTest.test(df, "features", "label") + val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) = + chi.select("pValues", "degreesOfFreedom", "statistics") + .as[(Vector, Array[Int], Vector)].head() + assert(pValues.size === numCols) + assert(degreesOfFreedom.length === numCols) + assert(statistics.size === numCols) + assert(pValues(1000) !== null) // SPARK-3087 + } + + test("fail on continuous features or labels") { + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") + + val random = new Random(11L) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + withClue("ChiSquare should throw an exception when given a continuous-valued label") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousLabel) + ChiSquareTest.test(df, "features", "label") + } + } + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + withClue("ChiSquare should throw an exception when given continuous-valued features") { + intercept[SparkException] { + val df = spark.createDataFrame(continuousFeature) + ChiSquareTest.test(df, "features", "label") + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala new file mode 100644 index 0000000000000..7d935e651f220 --- /dev/null +++ 
b/mllib/src/test/scala/org/apache/spark/ml/stat/CorrelationSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{Matrices, Matrix, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + + +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + private def X = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") + + private def extract(df: DataFrame): BDM[Double] = { + val Array(Row(mat: Matrix)) = df.collect() + mat.asBreeze.toDenseMatrix + } + + + test("corr(X) default, pearson") { + val defaultMat = Correlation.corr(X, "features") + val pearsonMat = Correlation.corr(X, "features", "pearson") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN, 1.0000000))) + // scalastyle:on + + assert(Matrices.fromBreeze(extract(defaultMat)) ~== expected absTol 1e-4) + assert(Matrices.fromBreeze(extract(pearsonMat)) ~== expected absTol 1e-4) + } + + test("corr(X) spearman") { + val spearmanMat = Correlation.corr(X, "features", "spearman") + // scalastyle:off + val expected = Matrices.fromBreeze(BDM( + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000))) + // scalastyle:on + assert(Matrices.fromBreeze(extract(spearmanMat)) ~== expected absTol 1e-4) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e1ab7c2d6520b..df155b464c64b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -104,6 +104,31 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(splits.distinct.length === splits.length) } + // SPARK-16957: Use midpoints for split values. 
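To make the intent of the SPARK-16957 cases below concrete: split thresholds for a continuous feature are now placed halfway between the adjacent candidate values being separated, rather than on the left value itself. A tiny illustration of just that arithmetic (not the actual findSplitsForContinuousFeature logic, which also decides which pairs to split between):

    // Illustrative only: given the adjacent value pairs the split-finding code decides
    // to separate, the threshold is now the midpoint of each pair.
    object MidpointSplitsSketch {
      def midpointsBetween(pairs: Seq[(Double, Double)]): Seq[Double] =
        pairs.map { case (lo, hi) => (lo + hi) / 2 }

      def main(args: Array[String]): Unit = {
        // e.g. splitting between 0 and 1, and between 2 and 3, as in the cases below:
        println(midpointsBetween(Seq((0.0, 1.0), (2.0, 3.0))))  // List(0.5, 2.5)
      }
    }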
+ { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2) + assert(splits === expectedSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + } + } + // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { @@ -112,9 +137,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0, 2.0)) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) // check returned splits are distinct assert(splits.distinct.length === splits.length) } @@ -126,9 +152,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(2.0, 3.0)) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) + assert(splits === expectedSplits) } // find splits when most samples close to the maximum @@ -138,9 +166,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array(1.0)) + val expectedSplits = Array((1.0 + 2.0) / 2) + assert(splits === expectedSplits) } // find splits for constant feature diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index c90cb8ca1034c..92a236928e90b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -34,7 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite { * Convert the given data to a DataFrame, and set the features and label metadata. * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param categoricalFeatures Map: categorical feature index to number of distinct values * @param numClasses Number of classes label can take. 
If 0, mark as continuous. * @return DataFrame with metadata */ @@ -69,7 +69,9 @@ private[ml] object TreeTests extends SparkFunSuite { df("label").as("label", labelMetadata)) } - /** Java-friendly version of [[setMetadata()]] */ + /** + * Java-friendly version of `setMetadata()` + */ def setMetadata( data: JavaRDD[LabeledPoint], categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7116265474f22..2b4e6b53e4f81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -58,8 +58,7 @@ class CrossValidatorSuite .setNumFolds(3) val cvModel = cv.fit(dataset) - // copied model must have the same paren. - MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(cv, cvModel) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(parent.getRegParam === 0.001) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4463a9b6e543a..a34f930aa11c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -45,18 +45,18 @@ class TrainValidationSplitSuite .addGrid(lr.maxIter, Array(0, 10)) .build() val eval = new BinaryClassificationEvaluator - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(lr) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) - val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] - assert(cv.getTrainRatio === 0.5) + val tvsModel = tvs.fit(dataset) + val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(tvs.getTrainRatio === 0.5) assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) } test("train validation with linear regression") { @@ -71,28 +71,27 @@ class TrainValidationSplitSuite .addGrid(trainer.maxIter, Array(0, 10)) .build() val eval = new RegressionEvaluator() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(trainer) .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) .setSeed(42L) - val cvModel = cv.fit(dataset) + val tvsModel = tvs.fit(dataset) - // copied model must have the same paren. 
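The tuning changes here mostly rename cv to tvs and switch to checkCopyAndUids; for readers less familiar with the API under test, a minimal TrainValidationSplit sketch follows (the estimator, grid, and data are illustrative choices, not taken from the patch):

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
    import org.apache.spark.sql.SparkSession

    object TrainValidationSplitSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("TVSSketch").getOrCreate()
        import spark.implicits._

        // Ten toy points so the 0.75/0.25 split leaves a non-empty validation set.
        val df = (1 to 10).map(i => (i.toDouble, Vectors.dense(i.toDouble))).toDF("label", "features")

        val lr = new LinearRegression()
        val grid = new ParamGridBuilder()
          .addGrid(lr.regParam, Array(0.001, 0.1))
          .addGrid(lr.maxIter, Array(10, 50))
          .build()

        val tvs = new TrainValidationSplit()
          .setEstimator(lr)
          .setEstimatorParamMaps(grid)
          .setEvaluator(new RegressionEvaluator())  // rmse by default
          .setTrainRatio(0.75)                      // one split, unlike CrossValidator's k folds
          .setSeed(42L)

        val tvsModel = tvs.fit(df)
        println(tvsModel.validationMetrics.mkString(", "))
        tvsModel.bestModel.transform(df).show()

        spark.stop()
      }
    }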
- MLTestingUtils.checkCopy(cvModel) + MLTestingUtils.checkCopyAndUids(tvs, tvsModel) - val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) - assert(cvModel.validationMetrics.length === lrParamMaps.length) + assert(tvsModel.validationMetrics.length === lrParamMaps.length) eval.setMetricName("r2") - val cvModel2 = cv.fit(dataset) - val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + val tvsModel2 = tvs.fit(dataset) + val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression] assert(parent2.getRegParam === 0.001) assert(parent2.getMaxIter === 10) - assert(cvModel2.validationMetrics.length === lrParamMaps.length) + assert(tvsModel2.validationMetrics.length === lrParamMaps.length) } test("transformSchema should check estimatorParamMaps") { @@ -104,17 +103,17 @@ class TrainValidationSplitSuite .addGrid(est.inputCol, Array("input1", "input2")) .build() - val cv = new TrainValidationSplit() + val tvs = new TrainValidationSplit() .setEstimator(est) .setEstimatorParamMaps(paramMaps) .setEvaluator(eval) .setTrainRatio(0.5) - cv.transformSchema(new StructType()) // This should pass. + tvs.transformSchema(new StructType()) // This should pass. val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") - cv.setEstimatorParamMaps(invalidParamMaps) + tvs.setEstimatorParamMaps(invalidParamMaps) intercept[IllegalArgumentException] { - cv.transformSchema(new StructType()) + tvs.transformSchema(new StructType()) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 553b8725b30a3..27d606cb05dc2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -81,42 +81,44 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => /** * Default test for Estimator, Model pairs: * - Explicitly set Params, and train model - * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model + * - Test save/load using `testDefaultReadWrite` on Estimator and Model * - Check Params on Estimator and Model * - Compare model data * - * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * This requires that `Model`'s `Param`s should be a subset of `Estimator`'s `Param`s. * * @param estimator Estimator to test - * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testParams Set of [[Param]] values to set in estimator - * @param checkModelData Method which takes the original and loaded [[Model]] and compares their - * data. This method does not need to check [[Param]] values. - * @tparam E Type of [[Estimator]] - * @tparam M Type of [[Model]] produced by estimator + * @param dataset Dataset to pass to `Estimator.fit()` + * @param testEstimatorParams Set of `Param` values to set in estimator + * @param testModelParams Set of `Param` values to set in model + * @param checkModelData Method which takes the original and loaded `Model` and compares their + * data. This method does not need to check `Param` values. 
+ * @tparam E Type of `Estimator` + * @tparam M Type of `Model` produced by estimator */ def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: Dataset[_], - testParams: Map[String, Any], + testEstimatorParams: Map[String, Any], + testModelParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized. - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => estimator.set(estimator.getParam(p), v) } val model = estimator.fit(dataset) // Test Estimator save/load val estimator2 = testDefaultReadWrite(estimator) - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => val param = estimator.getParam(p) assert(estimator.get(param).get === estimator2.get(param).get) } // Test Model save/load val model2 = testDefaultReadWrite(model) - testParams.foreach { case (p, v) => + testModelParams.foreach { case (p, v) => val param = model.getParam(p) assert(model.get(param).get === model2.get(param).get) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60a..bef79e634f75f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -31,11 +31,15 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ object MLTestingUtils extends SparkFunSuite { - def checkCopy(model: Model[_]): Unit = { + + def checkCopyAndUids[T <: Estimator[_]](estimator: T, model: Model[_]): Unit = { + assert(estimator.uid === model.uid, "Model uid does not match parent estimator") + + // copied model must have the same parent val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] - assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) + assert(copied.parent.uid == model.parent.uid) } def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( @@ -260,12 +264,13 @@ object MLTestingUtils extends SparkFunSuite { data: Dataset[LabeledPoint], estimator: E with HasWeightCol, numClasses: Int, - modelEquals: (M, M) => Unit): Unit = { + modelEquals: (M, M) => Unit, + outlierRatio: Int): Unit = { import data.sqlContext.implicits._ val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap { case Instance(l, w, f) => val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1 - List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) + List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) } val trueModel = estimator.set(estimator.weightCol, "").fit(data) val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 141249a427a4c..54e363a8b9f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -105,8 +105,8 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { private object StopwatchSuite extends SparkFunSuite { /** - * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and - * returns the duration reported by the stopwatch. 
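With the DefaultReadWriteTest change above, testEstimatorAndModelReadWrite now takes the estimator's param settings and the model's param settings separately, so a model only needs to expose a subset of the estimator's params. Below is a hedged sketch of a round-trip test under the new five-argument signature; it assumes Spark's test sources and uses an invented suite built around MinMaxScaler.

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.ml.feature.{MinMaxScaler, MinMaxScalerModel}
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.util.DefaultReadWriteTest
    import org.apache.spark.mllib.util.MLlibTestSparkContext

    class ReadWriteSketchSuite extends SparkFunSuite with MLlibTestSparkContext
      with DefaultReadWriteTest {

      test("estimator and model read/write (sketch)") {
        import testImplicits._
        val df = Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
          .map(Tuple1.apply).toDF("features")

        def checkModelData(m1: MinMaxScalerModel, m2: MinMaxScalerModel): Unit = {
          assert(m1.originalMin === m2.originalMin)
          assert(m1.originalMax === m2.originalMax)
        }

        // Estimator params and model params are passed separately now; for MinMaxScaler
        // they happen to be the same map, whereas e.g. ALSSuite passes different ones.
        val settings = Map("inputCol" -> "features", "outputCol" -> "scaled", "max" -> 2.0)
        testEstimatorAndModelReadWrite(new MinMaxScaler(), df, settings, settings, checkModelData)
      }
    }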
+ * Checks the input stopwatch on a task that takes a random time (less than 10ms) to finish. + * Validates and returns the duration reported by the stopwatch. */ def checkStopwatch(sw: Stopwatch): Long = { val ubStart = now diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 8f11bbc8e47af..50b73e0e99a22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -30,7 +30,9 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => private var _tempDir: File = _ - /** Returns the temporary directory as a [[File]] instance. */ + /** + * Returns the temporary directory as a `File` instance. + */ protected def tempDir: File = _tempDir override def beforeAll(): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 4c2376376dd2a..c2e08d078fc1a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -360,6 +360,49 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("PrefixSpan pre-processing's cleaning test") { + + // One item per itemSet + val itemToInt1 = (4 to 5).zipWithIndex.toMap + val sequences1 = Seq( + Array(Array(4), Array(1), Array(2), Array(5), Array(2), Array(4), Array(5)), + Array(Array(6), Array(7), Array(8))) + val rdd1 = sc.parallelize(sequences1, 2).cache() + + val cleanedSequence1 = PrefixSpan.toDatabaseInternalRepr(rdd1, itemToInt1).collect() + + val expected1 = Array(Array(0, 4, 0, 5, 0, 4, 0, 5, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt1(x) + 1)) + + compareInternalSequences(expected1, cleanedSequence1) + + // Multi-item sequence + val itemToInt2 = (4 to 6).zipWithIndex.toMap + val sequences2 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd2 = sc.parallelize(sequences2, 2).cache() + + val cleanedSequence2 = PrefixSpan.toDatabaseInternalRepr(rdd2, itemToInt2).collect() + + val expected2 = Array(Array(0, 4, 5, 0, 6, 0, 5, 0, 4, 0, 5, 6, 0)) + .map(_.map(x => if (x == 0) 0 else itemToInt2(x) + 1)) + + compareInternalSequences(expected2, cleanedSequence2) + + // Emptied sequence + val itemToInt3 = (10 to 10).zipWithIndex.toMap + val sequences3 = Seq( + Array(Array(4, 5), Array(1, 6, 2), Array(2), Array(5), Array(2), Array(4), Array(5, 6, 7)), + Array(Array(8, 9), Array(1, 2))) + val rdd3 = sc.parallelize(sequences3, 2).cache() + + val cleanedSequence3 = PrefixSpan.toDatabaseInternalRepr(rdd3, itemToInt3).collect() + val expected3 = Array[Array[Int]]() + + compareInternalSequences(expected3, cleanedSequence3) + } + test("model save/load") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -409,4 +452,12 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } + + private def compareInternalSequences( + expectedValue: Array[Array[Int]], + actualValue: Array[Array[Int]]): Unit = { + val expectedSet = expectedValue.map(x => x.toSeq).toSet + val actualSet = actualValue.map(x => x.toSeq).toSet + assert(expectedSet === actualSet) + } } diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 0d2911173e03e..10a2ce1835186 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -193,8 +193,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers // With smaller convergenceTol, it takes more steps. assert(lossLBFGS3.length > lossLBFGS2.length) - // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed. - assert(lossLBFGS3.length == 6) + // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed. + assert(lossLBFGS3.length == 7) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 46fcebe132749..992b876561896 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -145,14 +145,17 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(chi(1000) != null) // SPARK-3087 // Detect continuous features or labels + val tooManyCategories: Int = 100000 + assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " + + "tooManyCategories be large enough to cause ChiSqTest to throw an exception.") val random = new Random(11L) - val continuousLabel = - Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + val continuousLabel = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) } - val continuousFeature = - Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + val continuousFeature = Seq.fill(tooManyCategories)( + LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) intercept[SparkException] { Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 14152cdd63bc7..d0f02dd966bd5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} /** - * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. + * Test suites for `GiniAggregator` and `EntropyAggregator`. */ class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 6bb7ed9c9513c..720237bd2dddd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -60,7 +60,7 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => * A helper object for importing SQL implicits. 
* * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 39a6bc37d9638..d39865a19a5c5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -207,7 +207,7 @@ object TestingUtils { if (r.fun(x, r.y, r.eps)) { throw new TestFailedException( s"Did not expect \n$x\n and \n${r.y}\n to be within " + - "${r.eps}${r.method} for all elements.", 0) + s"${r.eps}${r.method} for all elements.", 0) } true } diff --git a/pom.xml b/pom.xml index 67e9de082dbbc..4b0a335813cbe 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -58,10 +58,6 @@ https://issues.apache.org/jira/browse/SPARK - - ${maven.version} - - Dev Mailing List @@ -139,7 +135,7 @@ 1.9.1-palantir1 8.18.0 1.54 - 9.2.16.v20160414 + 9.3.11.v20160721 3.1.0 0.8.0 2.4.0 @@ -726,7 +722,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + 0.13.1 diff --git a/project/CirclePlugin.scala b/project/CirclePlugin.scala index c1858c7153276..39ba4a0d786f6 100644 --- a/project/CirclePlugin.scala +++ b/project/CirclePlugin.scala @@ -87,7 +87,7 @@ object CirclePlugin extends AutoPlugin { testOptions := (testOptions in Test).value, resourceGenerators := (resourceGenerators in Test).value, // NOTE: this is because of dependencies like: - // org.apache.spark:spark-tags:2.2.0-SNAPSHOT:test->test + // org.apache.spark:spark-tags:2.3.0-SNAPSHOT:test->test // That somehow don't get resolved properly in the 'circle' ivy configuration even though it extends test // To test, copare: // > show unsafe/test:fullClasspath diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 721e901be97f1..5abe5a4c0566e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.3.x + lazy val v23excludes = v22excludes ++ Seq( + ) + // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( // [SPARK-19652][UI] Do auth checks for REST API access. 
@@ -67,8 +71,52 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this") - ) + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + + // [SPARK-19820] Allow reason to be specified to task kill + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), + + // [SPARK-19876] Add one time trigger, and improve Trigger APIs + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime"), + + // [SPARK-17471][ML] Add compressed method to ML matrices + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressed"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSparseSizeInBytes"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDense"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparse"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseRowMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getDenseSizeInBytes"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes") + ) ++ Seq( + // [SPARK-17019] Expose on-heap and off-heap memory usage in various places + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.apply"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.RDDDataDistribution.this") + ) // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { @@ -917,6 +965,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") + ) ++ Seq( + // [SPARK-17498] StringIndexer enhancement for handling unseen labels + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") @@ -958,6 +1010,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes case v if v.startsWith("2.0") => v20excludes diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index beacadbd99ca1..f4d0c671f518f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -662,6 +662,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 26f7415e1a423..930646de9cd86 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -65,6 +65,14 @@ pyspark.ml.regression module :undoc-members: :inherited-members: +pyspark.ml.stat module +---------------------- + +.. automodule:: pyspark.ml.stat + :members: + :undoc-members: + :inherited-members: + pyspark.ml.tuning module ------------------------ @@ -80,3 +88,11 @@ pyspark.ml.evaluation module :members: :undoc-members: :inherited-members: + +pyspark.ml.fpm module +---------------------------- + +.. 
automodule:: pyspark.ml.fpm + :members: + :undoc-members: + :inherited-members: diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 74dee1420754a..b1b59f73d6718 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,6 +21,7 @@ from tempfile import NamedTemporaryFile from pyspark.cloudpickle import print_exec +from pyspark.util import _exception_message if sys.version < '3': import cPickle as pickle @@ -82,7 +83,8 @@ def dump(self, value, f): except pickle.PickleError: raise except Exception as e: - msg = "Could not serialize broadcast: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize broadcast: %s: %s" \ + % (e.__class__.__name__, _exception_message(e)) print_exec(sys.stderr) raise pickle.PicklingError(msg) f.close() diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 959fb8b357f99..389bee7eee6e9 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -56,6 +56,7 @@ import traceback import weakref +from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -152,13 +153,13 @@ def dump(self, obj): except pickle.PickleError: raise except Exception as e: - if "'i' format requires" in e.message: - msg = "Object too large to serialize: " + e.message + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: " + e.__class__.__name__ + ": " + e.message + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) print_exec(sys.stderr) raise pickle.PicklingError(msg) - def save_memoryview(self, obj): """Fallback to save_string""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index c28a7709f56e2..3daff1b64a07b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -240,6 +240,32 @@ def signal_handler(signal, frame): if isinstance(threading.current_thread(), threading._MainThread): signal.signal(signal.SIGINT, signal_handler) + def __repr__(self): + return "".format( + master=self.master, + appName=self.appName, + ) + + def _repr_html_(self): + return """ +

    +         <div>
    +             <p><b>SparkContext</b></p>
    + 
    +             <p><a href="{sc.uiWebUrl}">Spark UI</a></p>
    + 
    +             <dl>
    +               <dt>Version</dt>
    +                 <dd><code>v{sc.version}</code></dd>
    +               <dt>Master</dt>
    +                 <dd><code>{sc.master}</code></dd>
    +               <dt>AppName</dt>
    +                 <dd><code>{sc.appName}</code></dd>
    +             </dl>
    +         </div>
    + """.format( + sc=self + ) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4fc357e42d71..a9756ea4af99a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -185,36 +185,33 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> bdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), + ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), + ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() + >>> blor = LogisticRegression(regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.5...]) + DenseVector([-1.080..., -0.646...]) >>> blorModel.intercept - -2.68... - >>> mdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), - ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() - >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", - ... family="multinomial") + 3.112... + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> mdf = spark.read.format("libsvm").load(data_path) + >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial") >>> mlorModel = mlor.fit(mdf) - >>> print(mlorModel.coefficientMatrix) - DenseMatrix([[-2.3...], - [ 0.2...], - [ 2.1... ]]) + >>> mlorModel.coefficientMatrix + SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1) >>> mlorModel.interceptVector - DenseVector([2.0..., 0.8..., -2.8...]) - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + DenseVector([0.04..., -0.42..., 0.37...]) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction - 0.0 + 1.0 >>> result.probability - DenseVector([0.99..., 0.00...]) + DenseVector([0.02..., 0.97...]) >>> result.rawPrediction - DenseVector([8.22..., -8.22...]) - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + DenseVector([-3.54..., 3.54...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 >>> blor.setParams("vector") @@ -224,8 +221,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr_path = temp_path + "/lr" >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) - >>> lr2.getMaxIter() - 5 + >>> lr2.getRegParam() + 0.01 >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) @@ -1482,31 +1479,33 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ - ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)), - ... 
Row(label=1.0, features=Vectors.sparse(2, [], [])), - ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> df = spark.read.format("libsvm").load(data_path) + >>> lr = LogisticRegression(regParam=0.01) >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) - >>> [x.coefficients for x in model.models] - [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])] + >>> model.models[0].coefficients + DenseVector([0.5..., -1.0..., 3.4..., 4.2...]) + >>> model.models[1].coefficients + DenseVector([-2.1..., 3.1..., -2.6..., -2.3...]) + >>> model.models[2].coefficients + DenseVector([0.3..., -3.4..., 1.0..., -1.1...]) >>> [x.intercept for x in model.models] - [-3.64747..., 2.55078..., -1.10165...] - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() + [-2.7..., -2.5..., -1.3...] + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF() >>> model.transform(test0).head().prediction - 1.0 - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction 0.0 - >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() - >>> model.transform(test2).head().prediction + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction 2.0 + >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF() + >>> model.transform(test2).head().prediction + 0.0 >>> model_path = temp_path + "/ovr_model" >>> model.save(model_path) >>> model2 = OneVsRestModel.load(model_path) >>> model2.transform(test0).head().prediction - 1.0 + 0.0 .. versionadded:: 2.0.0 """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92f8549e9cb9e..8d25f5b3a771a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -36,6 +36,7 @@ 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', + 'Imputer', 'ImputerModel', 'IndexToString', 'MaxAbsScaler', 'MaxAbsScalerModel', 'MinHashLSH', 'MinHashLSHModel', @@ -870,6 +871,165 @@ def idf(self): return self._call_java("idf") +@inherit_doc +class Imputer(JavaEstimator, HasInputCols, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Imputation estimator for completing missing values, either using the mean or the median + of the columns in which the missing values are located. The input columns should be of + DoubleType or FloatType. Currently Imputer does not support categorical features and + possibly creates incorrect values for a categorical feature. + + Note that the mean/median value is computed after filtering out missing values. + All Null values in the input columns are treated as missing, and so are also imputed. For + computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a + relative error of `0.001`. + + >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0), + ... (4.0, 4.0), (5.0, 5.0)], ["a", "b"]) + >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"]) + >>> model = imputer.fit(df) + >>> model.surrogateDF.show() + +---+---+ + | a| b| + +---+---+ + |3.0|4.0| + +---+---+ + ... 
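A minimal end-to-end sketch of the Imputer estimator documented above, assuming an active SparkSession named `spark`; the data and output column names are illustrative only and deliberately differ from the doctest that continues below.

from pyspark.ml.feature import Imputer

# NaN marks the missing entries; column names are arbitrary.
missing_df = spark.createDataFrame(
    [(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0), (4.0, 4.0), (5.0, 5.0)],
    ["a", "b"])
median_imputer = Imputer(strategy="median", inputCols=["a", "b"],
                         outputCols=["a_filled", "b_filled"])
imputer_model = median_imputer.fit(missing_df)   # one surrogate (here the median) per input column
imputer_model.surrogateDF.show()                 # the values that will replace the missing entries
imputer_model.transform(missing_df).show()       # original columns plus the imputed output columns

The fitted ImputerModel keeps one surrogate per input column, so the same model can be reapplied to new data without recomputing the statistics.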
+ >>> model.transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 1.0| 4.0| + |2.0|NaN| 2.0| 4.0| + |NaN|3.0| 3.0| 3.0| + ... + >>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show() + +---+---+-----+-----+ + | a| b|out_a|out_b| + +---+---+-----+-----+ + |1.0|NaN| 4.0| NaN| + ... + >>> imputerPath = temp_path + "/imputer" + >>> imputer.save(imputerPath) + >>> loadedImputer = Imputer.load(imputerPath) + >>> loadedImputer.getStrategy() == imputer.getStrategy() + True + >>> loadedImputer.getMissingValue() + 1.0 + >>> modelPath = temp_path + "/imputer-model" + >>> model.save(modelPath) + >>> loadedModel = ImputerModel.load(modelPath) + >>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a + True + + .. versionadded:: 2.2.0 + """ + + outputCols = Param(Params._dummy(), "outputCols", + "output column names.", typeConverter=TypeConverters.toListString) + + strategy = Param(Params._dummy(), "strategy", + "strategy for imputation. If mean, then replace missing values using the mean " + "value of the feature. If median, then replace missing values using the " + "median value of the feature.", + typeConverter=TypeConverters.toString) + + missingValue = Param(Params._dummy(), "missingValue", + "The placeholder for the missing values. All occurrences of missingValue " + "will be imputed.", typeConverter=TypeConverters.toFloat) + + @keyword_only + def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None): + """ + super(Imputer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid) + self._setDefault(strategy="mean", missingValue=float("nan")) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, + outputCols=None): + """ + setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \ + outputCols=None) + Sets params for this Imputer. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.2.0") + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + @since("2.2.0") + def getOutputCols(self): + """ + Gets the value of :py:attr:`outputCols` or its default value. + """ + return self.getOrDefault(self.outputCols) + + @since("2.2.0") + def setStrategy(self, value): + """ + Sets the value of :py:attr:`strategy`. + """ + return self._set(strategy=value) + + @since("2.2.0") + def getStrategy(self): + """ + Gets the value of :py:attr:`strategy` or its default value. + """ + return self.getOrDefault(self.strategy) + + @since("2.2.0") + def setMissingValue(self, value): + """ + Sets the value of :py:attr:`missingValue`. + """ + return self._set(missingValue=value) + + @since("2.2.0") + def getMissingValue(self): + """ + Gets the value of :py:attr:`missingValue` or its default value. + """ + return self.getOrDefault(self.missingValue) + + def _create_model(self, java_model): + return ImputerModel(java_model) + + +class ImputerModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + .. note:: Experimental + + Model fitted by :py:class:`Imputer`. + + .. 
versionadded:: 2.2.0 + """ + + @property + @since("2.2.0") + def surrogateDF(self): + """ + Returns a DataFrame containing inputCols and their corresponding surrogates, + which are used to replace the missing values in the input DataFrame. + """ + return self._call_java("surrogateDF") + + @inherit_doc class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py new file mode 100644 index 0000000000000..b30d4edb19908 --- /dev/null +++ b/python/pyspark/ml/fpm.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import keyword_only, since +from pyspark.ml.util import * +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import * + +__all__ = ["FPGrowth", "FPGrowthModel"] + + +class HasSupport(Params): + """ + Mixin for param support. + """ + + minSupport = Param( + Params._dummy(), + "minSupport", + """Minimal support level of the frequent pattern. [0.0, 1.0]. + Any pattern that appears more than (minSupport * size-of-the-dataset) + times will be output""", + typeConverter=TypeConverters.toFloat) + + def setMinSupport(self, value): + """ + Sets the value of :py:attr:`minSupport`. + """ + return self._set(minSupport=value) + + def getMinSupport(self): + """ + Gets the value of minSupport or its default value. + """ + return self.getOrDefault(self.minSupport) + + +class HasConfidence(Params): + """ + Mixin for param confidence. + """ + + minConfidence = Param( + Params._dummy(), + "minConfidence", + """Minimal confidence for generating Association Rule. [0.0, 1.0] + Note that minConfidence has no effect during fitting.""", + typeConverter=TypeConverters.toFloat) + + def setMinConfidence(self, value): + """ + Sets the value of :py:attr:`minConfidence`. + """ + return self._set(minConfidence=value) + + def getMinConfidence(self): + """ + Gets the value of minConfidence or its default value. + """ + return self.getOrDefault(self.minConfidence) + + +class HasItemsCol(Params): + """ + Mixin for param itemsCol: items column name. + """ + + itemsCol = Param(Params._dummy(), "itemsCol", + "items column name", typeConverter=TypeConverters.toString) + + def setItemsCol(self, value): + """ + Sets the value of :py:attr:`itemsCol`. + """ + return self._set(itemsCol=value) + + def getItemsCol(self): + """ + Gets the value of itemsCol or its default value. + """ + return self.getOrDefault(self.itemsCol) + + +class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + Model fitted by FPGrowth. + + .. 
versionadded:: 2.2.0 + """ + @property + @since("2.2.0") + def freqItemsets(self): + """ + DataFrame with two columns: + * `items` - Itemset of the same type as the input column. + * `freq` - Frequency of the itemset (`LongType`). + """ + return self._call_java("freqItemsets") + + @property + @since("2.2.0") + def associationRules(self): + """ + Data with three columns: + * `antecedent` - Array of the same type as the input column. + * `consequent` - Array of the same type as the input column. + * `confidence` - Confidence for the rule (`DoubleType`). + """ + return self._call_java("associationRules") + + +class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, + HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable): + """ + .. note:: Experimental + + A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in + Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_. + PFP distributes computation in such a way that each worker executes an + independent group of mining tasks. The FP-Growth algorithm is described in + Han et al., Mining frequent patterns without candidate generation [HAN2000]_ + + .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 + .. [HAN2000] http://dx.doi.org/10.1145/335191.335372 + + .. note:: null values in the feature column are ignored during fit(). + .. note:: Internally `transform` `collects` and `broadcasts` association rules. + + >>> from pyspark.sql.functions import split + >>> data = (spark.read + ... .text("data/mllib/sample_fpgrowth.txt") + ... .select(split("value", "\s+").alias("items"))) + >>> data.show(truncate=False) + +------------------------+ + |items | + +------------------------+ + |[r, z, h, k, p] | + |[z, y, x, w, v, u, t, s]| + |[s, x, o, n, r] | + |[x, z, y, m, t, s, q, e]| + |[z] | + |[x, z, y, r, q, t, p] | + +------------------------+ + >>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7) + >>> fpm = fp.fit(data) + >>> fpm.freqItemsets.show(5) + +---------+----+ + | items|freq| + +---------+----+ + | [s]| 3| + | [s, x]| 3| + |[s, x, z]| 2| + | [s, z]| 2| + | [r]| 3| + +---------+----+ + only showing top 5 rows + >>> fpm.associationRules.show(5) + +----------+----------+----------+ + |antecedent|consequent|confidence| + +----------+----------+----------+ + | [t, s]| [y]| 1.0| + | [t, s]| [x]| 1.0| + | [t, s]| [z]| 1.0| + | [p]| [r]| 1.0| + | [p]| [z]| 1.0| + +----------+----------+----------+ + only showing top 5 rows + >>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) + >>> sorted(fpm.transform(new_data).first().prediction) + ['x', 'y', 'z'] + + .. 
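A short sketch of the FPGrowth estimator described above, assuming an active SparkSession named `spark`; the basket contents are illustrative only.

from pyspark.ml.fpm import FPGrowth

# Each row is one transaction; "items" is the default items column name.
baskets = spark.createDataFrame(
    [(["a", "b"],), (["a", "b", "c"],), (["b", "c"],), (["a", "c"],)], ["items"])
fp = FPGrowth(minSupport=0.5, minConfidence=0.6)
fp_model = fp.fit(baskets)
fp_model.freqItemsets.show()        # itemsets appearing in at least minSupport of the rows
fp_model.associationRules.show()    # rules whose confidence is at least minConfidence
fp_model.transform(baskets).show()  # "prediction" holds consequents implied by each row's items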
versionadded:: 2.2.0 + """ + @keyword_only + def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + super(FPGrowth, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid) + self._setDefault(minSupport=0.3, minConfidence=0.8, + itemsCol="items", predictionCol="prediction") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.2.0") + def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", + predictionCol="prediction", numPartitions=None): + """ + setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ + predictionCol="prediction", numPartitions=None) + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return FPGrowthModel(java_model) diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index b765343251965..ad1b487676fa7 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -72,7 +72,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 8bc899a0788bb..bcfb36880eb02 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> user_recs = model.recommendForAllUsers(3) + >>> user_recs.where(user_recs.user == 0)\ + .select("recommendations.item", "recommendations.rating").collect() + [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])] + >>> item_recs = model.recommendForAllItems(3) + >>> item_recs.where(item_recs.item == 2)\ + .select("recommendations.user", "recommendations.rating").collect() + [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -384,6 +392,28 @@ def itemFactors(self): """ return self._call_java("itemFactors") + @since("2.2.0") + def recommendForAllUsers(self, numItems): + """ + Returns top `numItems` items recommended for each user, for all users. + + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForAllUsers", numItems) + + @since("2.2.0") + def recommendForAllItems(self, numUsers): + """ + Returns top `numUsers` users recommended for each item, for all items. + + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. 
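A brief sketch of the new recommendForAllUsers and recommendForAllItems calls documented above, assuming an active SparkSession named `spark`; the ratings DataFrame is illustrative only.

from pyspark.ml.recommendation import ALS

ratings = spark.createDataFrame(
    [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 0, 1.0)],
    ["user", "item", "rating"])
als = ALS(rank=5, maxIter=5, seed=0, userCol="user", itemCol="item", ratingCol="rating")
als_model = als.fit(ratings)
# Top-2 recommendations per user and per item, returned as arrays of (id, rating) structs.
als_model.recommendForAllUsers(2).show(truncate=False)
als_model.recommendForAllItems(2).show(truncate=False)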
+ """ + return self._call_java("recommendForAllItems", numUsers) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b199bf282e4f2..3c3fcc8d9b8d8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports - "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family - is listed below. The first link function of each family is the default one. + "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for + each family is listed below. The first link function of each family is the default one. * "gaussian" -> "identity", "log", "inverse" @@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha * "gamma" -> "inverse", "identity", "log" + * "tweedie" -> power link function specified through "linkPower". \ + The default link power in the tweedie family is 1 - variancePower. + .. seealso:: `GLM `_ >>> from pyspark.ml.linalg import Vectors @@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model. Supported options: " + - "gaussian (default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson, gamma and tweedie.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + @@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha "and sqrt.", typeConverter=TypeConverters.toString) linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + "predictor) column name", typeConverter=TypeConverters.toString) + variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " + + "of the Tweedie distribution which characterizes the relationship " + + "between the variance and mean of the distribution. Only applicable " + + "for the Tweedie family. Supported values: 0 and [1, Inf).", + typeConverter=TypeConverters.toFloat) + linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. 
" + + "Only applicable to the Tweedie family.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + variancePower=0.0) kwargs = self._input_kwargs + self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1428,6 +1445,34 @@ def getLink(self): """ return self.getOrDefault(self.link) + @since("2.2.0") + def setVariancePower(self, value): + """ + Sets the value of :py:attr:`variancePower`. + """ + return self._set(variancePower=value) + + @since("2.2.0") + def getVariancePower(self): + """ + Gets the value of variancePower or its default value. + """ + return self.getOrDefault(self.variancePower) + + @since("2.2.0") + def setLinkPower(self, value): + """ + Sets the value of :py:attr:`linkPower`. + """ + return self._set(linkPower=value) + + @since("2.2.0") + def getLinkPower(self): + """ + Gets the value of linkPower or its default value. + """ + return self.getOrDefault(self.linkPower) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py new file mode 100644 index 0000000000000..079b0833e1c6d --- /dev/null +++ b/python/pyspark/ml/stat.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import since, SparkContext +from pyspark.ml.common import _java2py, _py2java +from pyspark.ml.wrapper import _jvm + + +class ChiSquareTest(object): + """ + .. note:: Experimental + + Conduct Pearson's independence test for every feature against the label. For each feature, + the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared + statistic is computed. All label and feature values must be categorical. + + The null hypothesis is that the occurrence of the outcomes is statistically independent. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def test(dataset, featuresCol, labelCol): + """ + Perform a Pearson's independence test using dataset. + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest + args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)] + return _java2py(sc, javaTestObj.test(*args)) + + +class Correlation(object): + """ + .. note:: Experimental + + Compute the correlation matrix for the input dataset of Vectors using the specified method. + Methods currently supported: `pearson` (default), `spearman`. + + .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column + and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` + to avoid recomputing the common lineage. + + :param dataset: + A dataset or a dataframe. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A dataframe that contains the correlation matrix of the column of vectors. 
This + dataframe contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) + + .. versionadded:: 2.2.0 + + """ + @staticmethod + @since("2.2.0") + def corr(dataset, column, method="pearson"): + """ + Compute the correlation matrix with specified method using dataset. + """ + sc = SparkContext._active_spark_context + javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation + args = [_py2java(sc, arg) for arg in (dataset, column, method)] + return _java2py(sc, javaCorrObj.corr(*args)) + + +if __name__ == "__main__": + import doctest + import pyspark.ml.stat + from pyspark.sql import SparkSession + + globs = pyspark.ml.stat.__dict__.copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + spark = SparkSession.builder \ + .master("local[2]") \ + .appName("ml.stat tests") \ + .getOrCreate() + sc = spark.sparkContext + globs['sc'] = sc + globs['spark'] = spark + + failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + spark.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 352416055791e..571ac4bc1c366 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -41,31 +41,30 @@ import tempfile import array as pyarray import numpy as np -from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) -from numpy import sum as array_sum +from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect from pyspark import keyword_only, SparkContext from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import * from pyspark.ml.clustering import * +from pyspark.ml.common import _java2py, _py2java from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * -from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\ - DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector +from pyspark.ml.fpm import FPGrowth, FPGrowthModel +from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \ + SparseMatrix, SparseVector, Vector, VectorUDT, Vectors from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ - GeneralizedLinearRegression +from 
pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \ + LinearRegression +from pyspark.ml.stat import ChiSquareTest from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams, JavaWrapper -from pyspark.ml.common import _java2py, _py2java from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand -from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -1223,6 +1222,63 @@ def test_apply_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(features[i])) +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + +class FPGrowthTests(SparkSessionTestCase): + def setUp(self): + super(FPGrowthTests, self).setUp() + self.data = self.spark.createDataFrame( + [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )], + ["items"]) + + def test_association_rules(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_association_rules = self.spark.createDataFrame( + [([3], [1], 1.0), ([2], [1], 1.0)], + ["antecedent", "consequent", "confidence"] + ) + actual_association_rules = fpm.associationRules + + self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0) + self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0) + + def test_freq_itemsets(self): + fp = FPGrowth() + fpm = fp.fit(self.data) + + expected_freq_itemsets = self.spark.createDataFrame( + [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)], + ["items", "freq"] + ) + actual_freq_itemsets = fpm.freqItemsets + + self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0) + self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0) + + def tearDown(self): + del self.data + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): @@ -1253,6 +1309,7 @@ class DefaultValuesTests(PySparkTestCase): """ def check_params(self, py_stage): + import pyspark.ml.feature if not hasattr(py_stage, "_to_java"): return java_stage = py_stage._to_java() @@ -1272,6 +1329,15 @@ def check_params(self, py_stage): _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) py_stage._clear(p) py_default = py_stage.getOrDefault(p) + if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue": + # SPARK-15040 - default value for Imputer param 'missingValue' is NaN, + # and NaN != NaN, so handle it specially here + import math + self.assertTrue(math.isnan(java_default) and math.isnan(py_default), + "Java default %s and python default %s are not both NaN for " + "param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + return 
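The Imputer special case above rests on the fact that NaN never compares equal to itself, so the default cannot be verified with assertEqual; a quick illustration in plain Python:

import math

nan = float("nan")
nan == nan        # False: NaN never compares equal to itself
math.isnan(nan)   # True: so NaN defaults are compared with isnan instead of ==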
self.assertEqual(java_default, py_default, "Java default %s != python default %s of param %s for Params %s" % (str(java_default), str(py_default), p.name, str(py_stage))) @@ -1672,6 +1738,22 @@ def test_new_java_array(self): self.assertEqual(_java2py(self.sc, java_array), []) +class ChiSquareTestTests(SparkSessionTestCase): + + def test_chisquaretest(self): + data = [[0, Vectors.dense([0, 1, 2])], + [1, Vectors.dense([1, 1, 1])], + [2, Vectors.dense([2, 1, 0])]] + df = self.spark.createDataFrame(data, ['label', 'feat']) + res = ChiSquareTest.test(df, 'feat', 'label') + # This line is hitting the collect bug described in #17218, commented for now. + # pValues = res.select("degreesOfFreedom").collect()) + self.assertIsInstance(res, DataFrame) + fieldNames = set(field.name for field in res.schema.fields) + expectedFields = ["pValues", "degreesOfFreedom", "statistics"] + self.assertTrue(all(field in fieldNames for field in expectedFields)) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 031f22c02098e..7b24b3c74a9fa 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -74,7 +74,10 @@ def _convert_to_vector(l): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" + # Make sure the converted csc_matrix has sorted indices. csc = l.tocsc() + if not csc.has_sorted_indices: + csc.sort_indices() return SparseVector(l.shape[0], csc.indices, csc.data) else: raise TypeError("Cannot convert type %s into Vector" % type(l)) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 600655c912ca6..4cb802514be52 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,14 +28,13 @@ from pyspark import RDD, since from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition +from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition from pyspark.mllib.stat import MultivariateStatisticalSummary from pyspark.storagelevel import StorageLevel -__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', - 'BlockMatrix'] +__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition'] class DistributedMatrix(object): @@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the RowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a RowMatrix whose + columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the Scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). 
+ It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: :py:class:`SingularValueDecomposition` + + >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]]) + >>> rm = RowMatrix(rows) + + >>> svd_model = rm.computeSVD(2, True) + >>> svd_model.U.rows.collect() + [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def computePrincipalComponents(self, k): + """ + Computes the k principal components of the given row matrix + + .. note:: This cannot be computed on matrices with more than 65535 columns. + + :param k: Number of principal components to keep. + :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix` + + >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]]) + >>> rm = RowMatrix(rows) + + >>> # Returns the two principal components of rm + >>> pca = rm.computePrincipalComponents(2) + >>> pca + DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0) + + >>> # Transform into new dimensions with the greatest variance. + >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE + [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \ + DenseVector([-4.6102, -4.9745])] + """ + return self._java_matrix_wrapper.call("computePrincipalComponents", k) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`RowMatrix` + + >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]])) + >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + j_model = self._java_matrix_wrapper.call("multiply", matrix) + return RowMatrix(j_model) + + +class SingularValueDecomposition(JavaModelWrapper): + """ + Represents singular value decomposition (SVD) factors. + + .. versionadded:: 2.2.0 + """ + + @property + @since('2.2.0') + def U(self): + """ + Returns a distributed matrix whose columns are the left + singular vectors of the SingularValueDecomposition if computeU was set to be True. + """ + u = self.call("U") + if u is not None: + mat_name = u.getClass().getSimpleName() + if mat_name == "RowMatrix": + return RowMatrix(u) + elif mat_name == "IndexedRowMatrix": + return IndexedRowMatrix(u) + else: + raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name) + + @property + @since('2.2.0') + def s(self): + """ + Returns a DenseVector with singular values in descending order. 
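# A usage sketch for the SVD factors exposed above (U, s, V), assuming an
# active SparkContext `sc` as in the doctests; NumPy is used only to verify
# that U * diag(s) * V^T approximately reconstructs the original matrix.
import numpy as np
from pyspark.mllib.linalg.distributed import RowMatrix

rows = sc.parallelize([[3.0, 1.0, 1.0], [-1.0, 3.0, 1.0]])
svd = RowMatrix(rows).computeSVD(2, computeU=True)

U = np.array([row.toArray() for row in svd.U.rows.collect()])  # (m x k)
s = svd.s.toArray()                                            # (k,)
V = svd.V.toArray()                                            # (n x k)

reconstructed = U.dot(np.diag(s)).dot(V.T)
assert np.allclose(reconstructed, [[3.0, 1.0, 1.0], [-1.0, 3.0, 1.0]], atol=1e-6)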
+ """ + return self.call("s") + + @property + @since('2.2.0') + def V(self): + """ + Returns a DenseMatrix whose columns are the right singular + vectors of the SingularValueDecomposition. + """ + return self.call("V") + class IndexedRow(object): """ @@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): colsPerBlock) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the IndexedRowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a IndexedRowMatrix + whose columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: SingularValueDecomposition object + + >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))] + >>> irm = IndexedRowMatrix(sc.parallelize(rows)) + >>> svd_model = irm.computeSVD(2, True) + >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE + [IndexedRow(0, [-0.707106781187,0.707106781187]),\ + IndexedRow(1, [-0.707106781187,-0.707106781187])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`IndexedRowMatrix` + + >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))])) + >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix)) + class MatrixEntry(object): """ diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 732300ee9c2c9..81182881352bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -249,7 +249,7 @@ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). 
:param iterations: Number of iterations of ALS. (default: 5) @@ -287,7 +287,7 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp :param ratings: RDD of `Rating` or (userID, productID, rating) tuple. :param rank: - Rank of the feature matrices computed (number of features). + Number of features to use (also referred to as the number of latent factors). :param iterations: Number of iterations of ALS. (default: 5) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c519883cdd73b..1037bab7f1088 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,6 +23,7 @@ import sys import tempfile import array as pyarray +from math import sqrt from time import time, sleep from shutil import rmtree @@ -54,6 +55,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD @@ -853,6 +855,17 @@ def serialize(l): self.assertEqual(sv, serialize(lil.tocsr())) self.assertEqual(sv, serialize(lil.todok())) + def test_convert_to_vector(self): + from scipy.sparse import csc_matrix + # Create a CSC matrix with non-sorted indices + indptr = array([0, 2]) + indices = array([3, 1]) + data = array([2.0, 1.0]) + csc = csc_matrix((data, indices, indptr)) + self.assertFalse(csc.has_sorted_indices) + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEqual(sv, _convert_to_vector(csc)) + def test_dot(self): from scipy.sparse import lil_matrix lil = lil_matrix((4, 1)) @@ -1688,6 +1701,67 @@ def test_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(output[i])) +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. 
+ rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a6089fc8b9d32..619fa16d463f5 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -199,9 +199,9 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.0) + If (feature 0 <= 0.5) Predict: 0.0 - Else (feature 0 > 0.0) + Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) @@ -383,14 +383,14 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Tree 0: Predict: 1.0 Tree 1: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 Tree 2: - If (feature 0 <= 1.0) + If (feature 0 <= 1.5) Predict: 0.0 - Else (feature 0 > 1.0) + Else (feature 0 > 1.5) Predict: 1.0 >>> model.predict([2.0]) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4a7b66268221a..873b00fbe5344 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1804,17 +1804,31 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, a one-element list) - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of a list) - - C{mergeCombiners}, to combine two C's into a single one. + - C{mergeCombiners}, to combine two C's into a single one (e.g., merges + the lists) + + To avoid memory allocation, both mergeValue and mergeCombiners are allowed to + modify and return their first argument instead of creating a new C. In addition, users can control the partitioning of the output RDD. .. note:: V and C can be different -- for example, one might group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]). - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> def add(a, b): return a + str(b) - >>> sorted(x.combineByKey(str, add, add).collect()) - [('a', '11'), ('b', '1')] + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 2)]) + >>> def to_list(a): + ... return [a] + ... + >>> def append(a, b): + ... a.append(b) + ... return a + ... + >>> def extend(a, b): + ... a.extend(b) + ... return a + ... 
+ >>> sorted(x.combineByKey(to_list, append, extend).collect()) + [('a', [1, 2]), ('b', [1])] """ if numPartitions is None: numPartitions = self._defaultReducePartitions() @@ -2072,10 +2086,12 @@ def coalesce(self, numPartitions, shuffle=False): batchSize = min(10, self.ctx._batchSize or 1024) ser = BatchedSerializer(PickleSerializer(), batchSize) selfCopy = self._reserialize(ser) + jrdd_deserializer = selfCopy._jrdd_deserializer jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) else: + jrdd_deserializer = self._jrdd_deserializer jrdd = self._jrdd.coalesce(numPartitions, shuffle) - return RDD(jrdd, self.ctx, self._jrdd_deserializer) + return RDD(jrdd, self.ctx, jrdd_deserializer) def zip(self, other): """ diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index c1917d2be69d8..b5fcf7092d93a 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -24,13 +24,13 @@ import atexit import os import platform +import warnings import py4j -import pyspark +from pyspark import SparkConf from pyspark.context import SparkContext from pyspark.sql import SparkSession, SQLContext -from pyspark.storagelevel import StorageLevel if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -39,13 +39,23 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + spark = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 253a750629170..41e68a45a6159 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -72,10 +72,10 @@ def listDatabases(self): @ignore_unicode_prefix @since(2.0) def listTables(self, dbName=None): - """Returns a list of tables in the specified database. + """Returns a list of tables/views in the specified database. If no database is specified, the current database is used. - This includes all temporary tables. + This includes all temporary views. """ if dbName is None: dbName = self.currentDatabase() @@ -115,7 +115,7 @@ def listFunctions(self, dbName=None): @ignore_unicode_prefix @since(2.0) def listColumns(self, tableName, dbName=None): - """Returns a list of columns for the given table in the specified database. + """Returns a list of columns for the given table/view in the specified database. If no database is specified, the current database is used. 
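# A usage sketch for the catalog inspection methods documented above, assuming
# an active SparkSession `spark`: temporary views are listed alongside the
# persistent tables of the current database.
spark.range(3).createOrReplaceTempView("tmp_numbers")

for table in spark.catalog.listTables():
    # Each entry carries name, database, tableType and an isTemporary flag.
    print(table.name, table.tableType, table.isTemporary)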
@@ -161,14 +161,15 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** def createTable(self, tableName, path=None, source=None, schema=None, **options): """Creates a table based on the dataset in a data source. - It returns the DataFrame associated with the external table. + It returns the DataFrame associated with the table. The data source is specified by the ``source`` and a set of ``options``. If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + ``spark.sql.sources.default`` will be used. When ``path`` is specified, an external table is + created from the data at the given path. Otherwise a managed table is created. Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and - created external table. + created table. :return: :class:`DataFrame` """ @@ -276,14 +277,24 @@ def clearCache(self): @since(2.0) def refreshTable(self, tableName): - """Invalidate and refresh all the cached metadata of the given table.""" + """Invalidates and refreshes all the cached data and metadata of the given table.""" self._jcatalog.refreshTable(tableName) @since('2.1.1') def recoverPartitions(self, tableName): - """Recover all the partitions of the given table and update the catalog.""" + """Recovers all the partitions of the given table and update the catalog. + + Only works with a partitioned table, and not a view. + """ self._jcatalog.recoverPartitions(tableName) + @since('2.2.0') + def refreshByPath(self, path): + """Invalidates and refreshes all the cached data (and the associated metadata) for any + DataFrame that contains the given data source path. + """ + self._jcatalog.refreshByPath(path) + def _reset(self): """(Internal use only) Drop all existing databases (except "default"), tables, partitions and functions, and set the current database to "default". diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index ec05c18d4f062..e753ed402cdd7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -171,6 +171,61 @@ def __init__(self, jc): __ge__ = _bin_op("geq") __gt__ = _bin_op("gt") + _eqNullSafe_doc = """ + Equality test that is safe for null values. + + :param other: a value or :class:`Column` + + >>> from pyspark.sql import Row + >>> df1 = spark.createDataFrame([ + ... Row(id=1, value='foo'), + ... Row(id=2, value=None) + ... ]) + >>> df1.select( + ... df1['value'] == 'foo', + ... df1['value'].eqNullSafe('foo'), + ... df1['value'].eqNullSafe(None) + ... ).show() + +-------------+---------------+----------------+ + |(value = foo)|(value <=> foo)|(value <=> NULL)| + +-------------+---------------+----------------+ + | true| true| false| + | null| false| true| + +-------------+---------------+----------------+ + >>> df2 = spark.createDataFrame([ + ... Row(value = 'bar'), + ... Row(value = None) + ... ]) + >>> df1.join(df2, df1["value"] == df2["value"]).count() + 0 + >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count() + 1 + >>> df2 = spark.createDataFrame([ + ... Row(id=1, value=float('NaN')), + ... Row(id=2, value=42.0), + ... Row(id=3, value=None) + ... ]) + >>> df2.select( + ... df2['value'].eqNullSafe(None), + ... df2['value'].eqNullSafe(float('NaN')), + ... df2['value'].eqNullSafe(42.0) + ... 
).show() + +----------------+---------------+----------------+ + |(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)| + +----------------+---------------+----------------+ + | false| true| false| + | false| false| true| + | true| false| false| + +----------------+---------------+----------------+ + + .. note:: Unlike Pandas, PySpark doesn't consider NaN values to be NULL. + See the `NaN Semantics`_ for details. + .. _NaN Semantics: + https://spark.apache.org/docs/latest/sql-programming-guide.html#nan-semantics + .. versionadded:: 2.3.0 + """ + eqNullSafe = _bin_op("eqNullSafe", _eqNullSafe_doc) + # `and`, `or`, `not` cannot be overloaded in Python, # so use bitwise operators as boolean operators __and__ = _bin_op('and') @@ -185,9 +240,43 @@ def __contains__(self, item): "in a string column or 'array_contains' function for an array column.") # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") + _bitwiseOR_doc = """ + Compute bitwise OR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise or(|) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseOR(df.b)).collect() + [Row((a | b)=235)] + """ + _bitwiseAND_doc = """ + Compute bitwise AND of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise and(&) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseAND(df.b)).collect() + [Row((a & b)=10)] + """ + _bitwiseXOR_doc = """ + Compute bitwise XOR of this expression with another expression. + + :param other: a value or :class:`Column` to calculate bitwise xor(^) against + this :class:`Column`. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(a=170, b=75)]) + >>> df.select(df.a.bitwiseXOR(df.b)).collect() + [Row((a ^ b)=225)] + """ + + bitwiseOR = _bin_op("bitwiseOR", _bitwiseOR_doc) + bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc) + bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc) @since(1.3) def getItem(self, key): @@ -195,7 +284,7 @@ def getItem(self, key): An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df = spark.createDataFrame([([1, 2], {"key": "value"})], ["l", "d"]) >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() +----+------+ |l[0]|d[key]| @@ -217,7 +306,7 @@ def getField(self, name): An expression that gets a field by name in a StructField. >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))]) >>> df.select(df.r.getField("b")).show() +---+ |r.b| @@ -250,11 +339,59 @@ def __iter__(self): raise TypeError("Column is not iterable") # string methods - contains = _bin_op("contains") - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") + _contains_doc = """ + Contains the other element. Returns a boolean :class:`Column` based on a string match. + + :param other: string in line + + >>> df.filter(df.name.contains('o')).collect() + [Row(age=5, name=u'Bob')] + """ + _rlike_doc = """ + SQL RLIKE expression (LIKE with Regex). Returns a boolean :class:`Column` based on a regex + match. 
+ + :param other: an extended regex expression + + >>> df.filter(df.name.rlike('ice$')).collect() + [Row(age=2, name=u'Alice')] + """ + _like_doc = """ + SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match. + + :param other: a SQL LIKE pattern + + See :func:`rlike` for a regex version + + >>> df.filter(df.name.like('Al%')).collect() + [Row(age=2, name=u'Alice')] + """ + _startswith_doc = """ + String starts with. Returns a boolean :class:`Column` based on a string match. + + :param other: string at start of line (do not use a regex `^`) + + >>> df.filter(df.name.startswith('Al')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.startswith('^Al')).collect() + [] + """ + _endswith_doc = """ + String ends with. Returns a boolean :class:`Column` based on a string match. + + :param other: string at end of line (do not use a regex `$`) + + >>> df.filter(df.name.endswith('ice')).collect() + [Row(age=2, name=u'Alice')] + >>> df.filter(df.name.endswith('ice$')).collect() + [] + """ + + contains = ignore_unicode_prefix(_bin_op("contains", _contains_doc)) + rlike = ignore_unicode_prefix(_bin_op("rlike", _rlike_doc)) + like = ignore_unicode_prefix(_bin_op("like", _like_doc)) + startswith = ignore_unicode_prefix(_bin_op("startsWith", _startswith_doc)) + endswith = ignore_unicode_prefix(_bin_op("endsWith", _endswith_doc)) @ignore_unicode_prefix @since(1.3) @@ -298,13 +435,45 @@ def isin(self, *cols): return Column(jc) # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") + _asc_doc = """ + Returns a sort expression based on the ascending order of the given column name + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.asc()).collect() + [Row(name=u'Alice'), Row(name=u'Tom')] + """ + _desc_doc = """ + Returns a sort expression based on the descending order of the given column name. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.select(df.name).orderBy(df.name.desc()).collect() + [Row(name=u'Tom'), Row(name=u'Alice')] + """ + + asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + + _isNull_doc = """ + True if the current expression is null. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNull()).collect() + [Row(height=None, name=u'Alice')] + """ + _isNotNull_doc = """ + True if the current expression is NOT null. 
+ + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df.filter(df.height.isNotNull()).collect() + [Row(height=80, name=u'Tom')] + """ - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + isNull = ignore_unicode_prefix(_unary_op("isNull", _isNull_doc)) + isNotNull = ignore_unicode_prefix(_unary_op("isNotNull", _isNotNull_doc)) @since(1.3) def alias(self, *alias, **kwargs): @@ -469,7 +638,7 @@ def _test(): .appName("sql.column tests")\ .getOrCreate() sc = spark.sparkContext - globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c22f4b87e1a78..fdb7abbad4e5f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -385,7 +385,7 @@ def sql(self, sqlQuery): @since(1.0) def table(self, tableName): - """Returns the specified table as a :class:`DataFrame`. + """Returns the specified table or view as a :class:`DataFrame`. :return: :class:`DataFrame` diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bb6df22682095..7b67985f2b320 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,6 +25,8 @@ else: from itertools import imap as map +import warnings + from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -288,13 +290,15 @@ def isStreaming(self): return self._jdf.isStreaming() @since(1.3) - def show(self, n=20, truncate=True): + def show(self, n=20, truncate=True, vertical=False): """Prints the first ``n`` rows to the console. :param n: Number of rows to show. :param truncate: If set to True, truncate strings longer than 20 chars by default. If set to a number greater than one, truncates long strings to length ``truncate`` and align cells right. + :param vertical: If set to True, print output rows vertically (one line + per column value). >>> df DataFrame[age: int, name: string] @@ -312,11 +316,18 @@ def show(self, n=20, truncate=True): | 2| Ali| | 5| Bob| +---+----+ + >>> df.show(vertical=True) + -RECORD 0----- + age | 2 + name | Alice + -RECORD 1----- + age | 5 + name | Bob """ if isinstance(truncate, bool) and truncate: - print(self._jdf.showString(n, 20)) + print(self._jdf.showString(n, 20, vertical)) else: - print(self._jdf.showString(n, int(truncate))) + print(self._jdf.showString(n, int(truncate), vertical)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -369,6 +380,35 @@ def withWatermark(self, eventTime, delayThreshold): jdf = self._jdf.withWatermark(eventTime, delayThreshold) return DataFrame(jdf, self.sql_ctx) + @since(2.2) + def hint(self, name, *parameters): + """Specifies some hint on the current DataFrame. + + :param name: A name of the hint. + :param parameters: Optional parameters. 
+ :return: :class:`DataFrame` + + >>> df.join(df2.hint("broadcast"), "name").show() + +----+---+------+ + |name|age|height| + +----+---+------+ + | Bob| 5| 85| + +----+---+------+ + """ + if len(parameters) == 1 and isinstance(parameters[0], list): + parameters = parameters[0] + + if not isinstance(name, str): + raise TypeError("name should be provided as str, got {0}".format(type(name))) + + for p in parameters: + if not isinstance(p, str): + raise TypeError( + "all parameters should be str, got {0} of type {1}".format(p, type(p))) + + jdf = self._jdf.hint(name, self._jseq(parameters)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -1236,7 +1276,7 @@ def fillna(self, value, subset=None): Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be - an int, long, float, or string. + an int, long, float, boolean, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, @@ -1281,7 +1321,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1326,43 +1366,72 @@ def replace(self, to_replace, value, subset=None): |null| null|null| +----+------+----+ """ - if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + # Helper functions + def all_of(types): + """Given a type or tuple of types and a sequence of xs + check if each x is instance of type(s) + + >>> all_of(bool)([True, False]) + True + >>> all_of(basestring)(["a", 1]) + False + """ + def all_of_(xs): + return all(isinstance(x, types) for x in xs) + return all_of_ + + all_of_bool = all_of(bool) + all_of_str = all_of(basestring) + all_of_numeric = all_of((float, int, long)) + + # Validate input types + valid_types = (bool, float, int, long, basestring, list, tuple) + if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict") + "to_replace should be a float, int, long, string, list, tuple, or dict. " + "Got {0}".format(type(to_replace))) + + if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + raise ValueError("If to_replace is not a dict, value should be " + "a float, int, long, string, list, or tuple. " + "Got {0}".format(type(value))) - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length. " + "Got {0} and {1}".format(len(to_replace), len(value))) - rep_dict = dict() + if not (subset is None or isinstance(subset, (list, tuple, basestring))): + raise ValueError("subset should be a list or tuple of column names, " + "column name or None. 
Got {0}".format(type(subset))) + # Reshape input arguments if necessary if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(to_replace, tuple): - to_replace = list(to_replace) + if isinstance(value, (float, int, long, basestring)): + value = [value for _ in range(len(to_replace))] - if isinstance(value, tuple): - value = list(value) - - if isinstance(to_replace, list) and isinstance(value, list): - if len(to_replace) != len(value): - raise ValueError("to_replace and value lists should be of the same length") - rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): - rep_dict = dict([(tr, value) for tr in to_replace]) - elif isinstance(to_replace, dict): + if isinstance(to_replace, dict): rep_dict = to_replace + if value is not None: + warnings.warn("to_replace is a dict and value is not None. value will be ignored.") + else: + rep_dict = dict(zip(to_replace, value)) - if subset is None: - return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) - elif isinstance(subset, basestring): + if isinstance(subset, basestring): subset = [subset] - if not isinstance(subset, (list, tuple)): - raise ValueError("subset should be a list or tuple of column names") + # Verify we were not passed in mixed type generics." + if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): + raise ValueError("Mixed type replacements are not supported") - return DataFrame( - self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + else: + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) def approxQuantile(self, col, probabilities, relativeError): @@ -1384,7 +1453,8 @@ def approxQuantile(self, col, probabilities, relativeError): Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. - Note that rows containing any null values will be removed before calculation. + Note that null values will be ignored in numerical columns before calculation. + For columns only containing null values, an empty list is returned. :param col: str, list. Can be a single column name, or a list of names for multiple columns. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ef1e0cc29dd6a..1e6f179d7cc19 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1327,8 +1327,8 @@ def encode(col, charset): @since(1.5) def format_number(col, d): """ - Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places, - and returns the result as a string. + Formats the number X to a format like '#,--#,--#.--', rounded to d decimal places + with HALF_EVEN round mode, and returns the result as a string. :param col: the column name of the numeric value to be formatted :param d: the N decimal places @@ -1675,8 +1675,8 @@ def array(*cols): @since(1.5) def array_contains(col, value): """ - Collection function: returns True if the array contains the given value. The collection - elements and value must be of the same type. + Collection function: returns null if the array is null, true if the array contains the + given value, and false otherwise. 
:param col: name of column containing array :param value: value to check for in array @@ -1774,10 +1774,11 @@ def json_tuple(col, *fields): def from_json(col, schema, options={}): """ Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] - with the specified schema. Returns `null`, in the case of an unparseable string. + of [[StructType]]s with the specified schema. Returns `null`, in the case of an unparseable + string. :param col: string column in json format - :param schema: a StructType or ArrayType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column :param options: options to control parsing. accepts the same options as the json datasource >>> from pyspark.sql.types import * @@ -1802,10 +1803,10 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a [[StructType]] into a JSON string. Throws an exception, - in the case of an unsupported type. + Converts a column containing a [[StructType]] or [[ArrayType]] of [[StructType]]s into a + JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct + :param col: name of column containing the struct or array of the structs :param options: options to control converting. accepts the same options as the json datasource >>> from pyspark.sql import Row @@ -1814,6 +1815,10 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'{"age":2,"name":"Alice"}')] + >>> data = [(1, [Row(name='Alice', age=2), Row(name='Bob', age=3)])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] """ sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 45fb9b7591529..960fb882cf901 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,6 +109,11 @@ def schema(self, schema): @since(1.5) def option(self, key, value): """Adds an input option for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. """ self._jreader = self._jreader.option(key, to_str(value)) return self @@ -116,6 +121,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds input options for the underlying data source. + + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. 
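# A sketch of passing the reader option documented above, assuming an active
# SparkSession `spark`; the input path is a placeholder.
df = (spark.read
      .option("timeZone", "UTC")
      .csv("/data/events.csv", header=True, inferSchema=True))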
""" for k in options: self._jreader = self._jreader.option(k, to_str(options[k])) @@ -159,17 +169,17 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ - Loads a JSON file and returns the results as a :class:`DataFrame`. + Loads JSON files and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - :param path: string represents path to the JSON dataset, + :param path: string represents path to the JSON dataset, or a list of paths, or RDD of Strings storing JSON objects. :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. :param primitivesAsString: infers all primitive values as a string type. If None is set, @@ -213,9 +223,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -234,7 +242,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -252,7 +260,7 @@ def func(iterator): jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) else: - raise TypeError("path can be only string or RDD") + raise TypeError("path can be only string, list or RDD") @since(1.4) def table(self, tableName): @@ -269,7 +277,7 @@ def table(self, tableName): @since(1.4) def parquet(self, *paths): - """Loads a Parquet file, returning the result as a :class:`DataFrame`. + """Loads Parquet files, returning the result as a :class:`DataFrame`. 
You can set the following Parquet-specific option(s) for reading Parquet files: * ``mergeSchema``: sets whether we should merge schemas collected from all \ @@ -307,7 +315,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. @@ -333,12 +341,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: A flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. @@ -355,20 +363,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. - :param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will - log for each partition. Malformed records beyond this - number will be ignored. If None is set, it - uses the default value, ``10``. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. 
* ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -399,7 +403,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] @@ -407,7 +411,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non @since(1.5) def orc(self, path): - """Loads an ORC file, returning the result as a :class:`DataFrame`. + """Loads ORC files, returning the result as a :class:`DataFrame`. .. note:: Currently ORC support is only available together with Hive support. @@ -415,7 +419,9 @@ def orc(self, path): >>> df.dtypes [('a', 'bigint'), ('b', 'int'), ('c', 'int')] """ - return self._df(self._jreader.orc(path)) + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.orc(_to_seq(self._spark._sc, path))) @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, @@ -519,6 +525,11 @@ def format(self, source): @since(1.5) def option(self, key, value): """Adds an output option for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. """ self._jwrite = self._jwrite.option(key, to_str(value)) return self @@ -526,6 +537,11 @@ def option(self, key, value): @since(1.4) def options(self, **options): """Adds output options for the underlying data source. + + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. """ for k in options: self._jwrite = self._jwrite.option(k, to_str(options[k])) @@ -617,9 +633,10 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - timeZone=None): - """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + """Saves the content of the :class:`DataFrame` in JSON format + (`JSON Lines text format or newline-delimited JSON `_) at the + specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. @@ -638,16 +655,13 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. 
- :param timeZone: sets the string that indicates a timezone to be used to format timestamps. - If None is set, it uses the default value, session local timezone. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - timeZone=timeZone) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.json(path) @since(1.4) @@ -694,7 +708,7 @@ def text(self, path, compression=None): @since(2.0) def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, - timestampFormat=None, timeZone=None): + timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None): """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system @@ -716,10 +730,10 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No empty string. :param escape: sets the single character used for escaping quotes inside an already quoted value. If None is set, it uses the default value, ``\`` - :param escapeQuotes: A flag indicating whether values containing quotes should always + :param escapeQuotes: a flag indicating whether values containing quotes should always be enclosed in quotes. If None is set, it uses the default value ``true``, escaping all values containing a quote character. - :param quoteAll: A flag indicating whether all values should always be enclosed in + :param quoteAll: a flag indicating whether all values should always be enclosed in quotes. If None is set, it uses the default value ``false``, only escaping values containing a quote character. :param header: writes the names of columns as the first line. If None is set, it uses @@ -733,16 +747,22 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being written should be skipped. If None is set, it + uses the default value, ``true``. 
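# A sketch of the new writer flags described above, assuming a DataFrame `df`;
# the output path is a placeholder. Setting both flags to False preserves
# leading and trailing whitespace in the written values.
(df.write
   .mode("overwrite")
   .csv("/tmp/whitespace_preserved",
        ignoreLeadingWhiteSpace=False,
        ignoreTrailingWhiteSpace=False))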
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, - dateFormat=dateFormat, timestampFormat=timestampFormat, timeZone=timeZone) + dateFormat=dateFormat, timestampFormat=timestampFormat, + ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace) self._jwrite.csv(path) @since(1.5) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2a..c1bf2bd76fb7c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -221,6 +221,17 @@ def __init__(self, sparkContext, jsparkSession=None): or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + def _repr_html_(self): + return """ +
+            <div>
+                <p><b>SparkSession - {catalogImplementation}</b></p>
+                {sc_HTML}
+            </div>
    + """.format( + catalogImplementation=self.conf.get("spark.sql.catalogImplementation"), + sc_HTML=self.sparkContext._repr_html_() + ) + @since(2.0) def newSession(self): """ diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 625fb9ba385af..65b59d480da36 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -277,44 +277,6 @@ def resetTerminated(self): self._jsqm.resetTerminated() -class Trigger(object): - """Used to indicate how often results should be produced by a :class:`StreamingQuery`. - - .. note:: Experimental - - .. versionadded:: 2.0 - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def _to_java_trigger(self, sqlContext): - """Internal method to construct the trigger on the jvm. - """ - pass - - -class ProcessingTime(Trigger): - """A trigger that runs a query periodically based on the processing time. If `interval` is 0, - the query will run as fast as possible. - - The interval should be given as a string, e.g. '2 seconds', '5 minutes', ... - - .. note:: Experimental - - .. versionadded:: 2.0 - """ - - def __init__(self, interval): - if type(interval) != str or len(interval.strip()) == 0: - raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.") - self.interval = interval - - def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.streaming.ProcessingTime.create( - self.interval) - - class DataStreamReader(OptionUtils): """ Interface used to load a streaming :class:`DataFrame` from external storage systems @@ -373,6 +335,11 @@ def schema(self, schema): def option(self, key, value): """Adds an input option for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.option("x", 1) @@ -384,6 +351,11 @@ def option(self, key, value): def options(self, **options): """Adds input options for the underlying data source. + You can set the following option(s) for reading files: + * ``timeZone``: sets the string that indicates a timezone to be used to parse timestamps + in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. >>> s = spark.readStream.options(x="1", y=2) @@ -429,12 +401,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None, wholeFile=None): + wholeFile=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - `JSON Lines `_(newline-delimited JSON) is supported by default. - For JSON (one record per file), set the `wholeFile` parameter to ``true``. + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``wholeFile`` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
@@ -485,9 +457,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. :param wholeFile: parse one record, which may span multiple lines, per file. If None is set, it uses the default value, ``false``. @@ -503,7 +473,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) + timestampFormat=timestampFormat, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: @@ -561,7 +531,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, - maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, timeZone=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, wholeFile=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. @@ -589,12 +559,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. - :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. - :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values - being read should be skipped. If None is set, it uses - the default value, ``false``. + :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. + :param ignoreTrailingWhiteSpace: a flag indicating whether or not trailing whitespaces from + values being read should be skipped. If None is set, it + uses the default value, ``false``. :param nullValue: sets the string representation of a null value. If None is set, it uses the default value, empty string. Since 2.0.1, this ``nullValue`` param applies to all supported types including the string type. @@ -611,16 +581,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param timestampFormat: sets the string that indicates a timestamp format. Custom date formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the - default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. 
:param maxColumns: defines a hard limit of how many columns a record can have. If None is set, it uses the default value, ``20480``. :param maxCharsPerColumn: defines the maximum number of characters allowed for any given value being read. If None is set, it uses the default value, ``-1`` meaning unlimited length. + :param maxMalformedLogPerPartition: this parameter is no longer used since Spark 2.2.0. + If specified, it is ignored. :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. - If None is set, it uses the default value, session local timezone. * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record, and puts the malformed string into a field configured by \ @@ -653,7 +623,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn, - maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, timeZone=timeZone, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) @@ -721,6 +691,11 @@ def format(self, source): def option(self, key, value): """Adds an output option for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. """ self._jwrite = self._jwrite.option(key, to_str(value)) @@ -730,6 +705,11 @@ def option(self, key, value): def options(self, **options): """Adds output options for the underlying data source. + You can set the following option(s) for writing files: + * ``timeZone``: sets the string that indicates a timezone to be used to format + timestamps in the JSON/CSV datasources or partition values. + If it isn't set, it uses the default value, session local timezone. + .. note:: Experimental. """ for k in options: @@ -772,7 +752,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None): + def trigger(self, processingTime=None, once=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -782,17 +762,26 @@ def trigger(self, processingTime=None): >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') + >>> # trigger the query for just one batch of data + >>> writer = sdf.writeStream.trigger(once=True) """ - from pyspark.sql.streaming import ProcessingTime - trigger = None + jTrigger = None if processingTime is not None: + if once is not None: + raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: - raise ValueError('The processing time must be a non empty string. Got: %s' % + raise ValueError('Value for processingTime must be a non empty string.
Got: %s' % processingTime) - trigger = ProcessingTime(processingTime) - if trigger is None: - raise ValueError('A trigger was not provided. Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) + interval = processingTime.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( + interval) + elif once is not None: + if once is not True: + raise ValueError('Value for once must be True. Got: %s' % once) + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: + raise ValueError('No trigger provided') + self._jwrite = self._jwrite.trigger(jTrigger) return self @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 81f3d1d36a342..f644624f7f317 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -450,6 +450,24 @@ def test_wholefile_csv(self): Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')] self.assertEqual(ages_newlines.collect(), expected) + def test_ignorewhitespace_csv(self): + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv( + tmpPath, + ignoreLeadingWhiteSpace=False, + ignoreTrailingWhiteSpace=False) + + expected = [Row(value=u' a,b , c ')] + readback = self.spark.read.text(tmpPath) + self.assertEqual(readback.collect(), expected) + shutil.rmtree(tmpPath) + + def test_read_multiple_orc_file(self): + df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0", + "python/test_support/sql/orc_partitioned/b=1/c=1"]) + self.assertEqual(2, df.count()) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType @@ -964,7 +982,7 @@ def test_column_operators(self): cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\ - cs.startswith('a'), cs.endswith('a') + cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs) self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) self.assertRaisesRegexp(ValueError, @@ -1111,6 +1129,14 @@ def test_rand_functions(self): rndn2 = df.select('key', functions.randn(0)).collect() self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_array_contains_function(self): + from pyspark.sql.functions import array_contains + + df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data']) + actual = df.select(array_contains(df.data, 1).alias('b')).collect() + # The value argument can be implicitly castable to the element's type of the array. 
+ self.assertEqual([Row(b=True), Row(b=False)], actual) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), @@ -1237,13 +1263,26 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) - def test_stream_trigger_takes_keyword_args(self): + def test_stream_trigger(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + + # Should take at least one arg + try: + df.writeStream.trigger() + except ValueError: + pass + + # Should not take multiple args + try: + df.writeStream.trigger(once=True, processingTime='5 seconds') + except ValueError: + pass + + # Should take only keyword args try: df.writeStream.trigger('5 seconds') self.fail("Should have thrown an exception") except TypeError: - # should throw error pass def test_stream_read_options(self): @@ -1555,6 +1594,14 @@ def test_time_with_timezone(self): self.assertEqual(now, now1) self.assertEqual(now, utcnow1) + # regression test for SPARK-19561 + def test_datetime_at_epoch(self): + epoch = datetime.datetime.fromtimestamp(0) + df = self.spark.createDataFrame([Row(date=epoch)]) + first = df.select('date', lit(epoch).alias('lit_date')).first() + self.assertEqual(first['date'], epoch) + self.assertEqual(first['lit_date'], epoch) + def test_decimal(self): from decimal import Decimal schema = StructType([StructField("decimal", DecimalType(10, 5))]) @@ -1664,6 +1711,10 @@ def test_fillna(self): self.assertEqual(row.age, None) self.assertEqual(row.height, None) + # fillna with dictionary for boolean types + row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() + self.assertEqual(row.a, True) + def test_bitwise_operations(self): from pyspark.sql import functions row = Row(a=170, b=75) @@ -1732,6 +1783,78 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + # replace with lists + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first() + self.assertTupleEqual(row, (u'Ann', 10, 80.1)) + + # replace with dict + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first() + self.assertTupleEqual(row, (u'Alice', 11, 80.1)) + + # test backward compatibility with dummy value + dummy_value = 1 + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # test dict with mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', -10, 90.5)) + + # replace with tuples + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first() + self.assertTupleEqual(row, (u'Bob', 10, 80.1)) + + # replace multiple columns + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.0)) + + # test for mixed numerics + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first() + self.assertTupleEqual(row, (u'Alice', 20, 90.5)) + + # replace with boolean + row = (self + .spark.createDataFrame([(u'Alice', 10, 80.0)], schema) + .selectExpr("name = 'Bob'", 'age <= 15') + .replace(False, 
True).first()) + self.assertTupleEqual(row, (True, True)) + + # should fail if subset is not list, tuple or None + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first() + + # should fail if to_replace and value have different length + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first() + + # should fail if when received unexpected type + with self.assertRaises(ValueError): + from datetime import datetime + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first() + + # should fail if provided mixed type replacements + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first() + + with self.assertRaises(ValueError): + self.spark.createDataFrame( + [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) @@ -1783,6 +1906,22 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_generic_hints(self): + from pyspark.sql import DataFrame + + df1 = self.spark.range(10e10).toDF("id") + df2 = self.spark.range(10e10).toDF("id") + + self.assertIsInstance(df1.hint("broadcast"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", []), DataFrame) + + # Dummy rules + self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) + + plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2e8ed698278d0..ffba99502b148 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -55,7 +55,7 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 10 # seconds + timeout = 30 # seconds duration = .5 @classmethod @@ -903,11 +903,11 @@ def updater(vs, s): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.5) + ssc = StreamingContext(sc, 2) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.5) + wc.checkpoint(2) self.setupCalled = True return ssc @@ -921,21 +921,22 @@ def setup(): def check_output(n): while not os.listdir(outputd): - time.sleep(0.01) + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) while True: + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") p = os.path.join(outputd, max(os.listdir(outputd))) if '_SUCCESS' not in os.listdir(p): # not finished - time.sleep(0.01) continue ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = 
ordd.values().map(int).collect() if not d: - time.sleep(0.01) continue self.assertEqual(10, len(d)) s = set(d) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b7704d0d5395d..ef16eaa30afa2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1037,6 +1037,12 @@ def test_repartition_no_skewed(self): zeros = len([x for x in l if x == 0]) self.assertTrue(zeros == 0) + def test_repartition_on_textfile(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt") + rdd = self.sc.textFile(path) + result = rdd.repartition(1).collect() + self.assertEqual(u"Hello World!", result[0]) + def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) self.assertEqual(rdd.getNumPartitions(), 10) diff --git a/python/pyspark/util.py b/python/pyspark/util.py new file mode 100644 index 0000000000000..e5d332ce54429 --- /dev/null +++ b/python/pyspark/util.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = [] + + +def _exception_message(excp): + """Return the message from an exception as either a str or unicode object. Supports both + Python 2 and Python 3. + + >>> msg = "Exception message" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + + >>> msg = u"unicöde" + >>> excp = Exception(msg) + >>> msg == _exception_message(excp) + True + """ + if hasattr(excp, "message"): + return excp.message + return str(excp) + + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 08a301695fda7..41bf8c269b795 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
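The new ``pyspark/util.py`` helper above exists because ``Exception.message`` was removed in Python 3, so it returns ``message`` when the attribute is present and falls back to ``str(excp)`` otherwise. A short sketch of the intended call pattern; the surrounding ``try``/``except`` is illustrative and not taken from the patch:

    from pyspark.util import _exception_message

    try:
        raise ValueError(u"unicöde in the error text")
    except Exception as e:
        # Python 2 returns e.message; Python 3 falls back to str(e).
        print("Job failed: %s" % _exception_message(e))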
-__version__ = "2.1.0.dev0" +__version__ = "2.2.0.dev0" diff --git a/python/run-tests.py b/python/run-tests.py index 53a0aef229b08..b2e50435bb192 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -111,9 +111,9 @@ def run_individual_python_test(test_name, pyspark_python): def get_default_python_executables(): - python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] - if "python2.6" not in python_execs: - LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + python_execs = [x for x in ["python2.7", "python3.4", "pypy"] if which(x)] + if "python2.7" not in python_execs: + LOGGER.warning("Not testing against `python2.7` because it could not be found; falling" " back to `python` instead") python_execs.insert(0, "python") return python_execs diff --git a/python/setup.py b/python/setup.py index 47eab98e0f7b3..f50035435e26b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -167,7 +167,6 @@ def _supports_symlinks(): 'pyspark.ml', 'pyspark.ml.linalg', 'pyspark.ml.param', - 'pyspark.ml.stat', 'pyspark.sql', 'pyspark.streaming', 'pyspark.bin', diff --git a/repl/pom.xml b/repl/pom.xml index a256ae3b84183..6d133a3cfff7d 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 7f2ec01cc9676..39fc621de7807 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -18,6 +18,7 @@ package org.apache.spark.repl import java.io.File +import java.util.Locale import scala.tools.nsc.GenericRunnerSettings @@ -88,7 +89,7 @@ object Main extends Logging { } val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { if (SparkSession.hiveClassesArePresent) { // In the case that the property is not set at all, builder's config // does not have this value set to 'hive' yet. 
The original default diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 55c91675ed3ba..8fe27080cac66 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -473,4 +473,16 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("AssertionError", output) assertDoesNotContain("Exception", output) } + + // TODO: [SPARK-20548] Fix and re-enable + ignore("newProductSeqEncoder with REPL defined class") { + val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", + """ + |case class Click(id: Int) + |spark.implicits.newProductSeqEncoder[Click] + """.stripMargin) + + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } } diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 44a387a420c99..2ca0bb1c312a9 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/docker-minimal-bundle/pom.xml b/resource-managers/kubernetes/docker-minimal-bundle/pom.xml index 7f4d935e0e243..595dc046efd6d 100644 --- a/resource-managers/kubernetes/docker-minimal-bundle/pom.xml +++ b/resource-managers/kubernetes/docker-minimal-bundle/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml b/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml index f99838636b349..c4fe21604fb44 100644 --- a/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml +++ b/resource-managers/kubernetes/integration-tests-spark-jobs-helpers/pom.xml @@ -20,13 +20,16 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../../pom.xml spark-kubernetes-integration-tests-spark-jobs-helpers_2.11 jar Spark Project Kubernetes Integration Tests Spark Jobs Helpers + + kubernetes-integration-tests-spark-jobs-helpers + diff --git a/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml b/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml index 59e59aca5109b..db5ec86e0dfd8 100644 --- a/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml +++ b/resource-managers/kubernetes/integration-tests-spark-jobs/pom.xml @@ -20,13 +20,16 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../../pom.xml spark-kubernetes-integration-tests-spark-jobs_2.11 jar Spark Project Kubernetes Integration Tests Spark Jobs + + kubernetes-integration-tests-spark-jobs + diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 048d2cc86e185..e50cb229894cb 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,13 +20,16 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../../pom.xml spark-kubernetes-integration-tests_2.11 jar Spark Project Kubernetes Integration Tests + + kubernetes-integration-tests + diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/sslutil/SSLUtils.scala 
b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/sslutil/SSLUtils.scala index 2078e0585e8f0..418af13f8a8cc 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/sslutil/SSLUtils.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/sslutil/SSLUtils.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.kubernetes.integrationtest.sslutil import java.io.{File, FileOutputStream, OutputStreamWriter} import java.math.BigInteger import java.nio.file.Files -import java.security.cert.X509Certificate import java.security.{KeyPair, KeyPairGenerator, KeyStore, SecureRandom} +import java.security.cert.X509Certificate import java.util.{Calendar, Random} import javax.security.auth.x500.X500Principal diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 03846d9f5a3be..20b53f2d8f987 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index cd98110ddcc02..127fadabcce53 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -101,7 +101,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Launch Time - {state.startDate} + {UIUtils.formatDate(state.startDate)} Finish Time @@ -154,7 +154,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Memory{driver.mem} - Submitted{driver.submissionDate} + Submitted{UIUtils.formatDate(driver.submissionDate)} Supervise{driver.supervise} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 13ba7d311e57d..c9107c3e73d3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -68,7 +68,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -88,10 +88,10 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {id} {historyCol} - {state.driverDescription.submissionDate} + {UIUtils.formatDate(state.driverDescription.submissionDate)} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} - {state.startDate} + {UIUtils.formatDate(state.startDate)} {state.slaveId.getValue} {stateString(state.mesosTaskStatus)} @@ -101,7 +101,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} 
{submission.retryState.get.lastFailureStatus} {submission.retryState.get.nextRetry} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b252539782580..a086ec7ea2da6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -104,7 +104,8 @@ private[spark] class MesosExecutorBackend logError("Received KillTask but executor was null") } else { // TODO: Determine the 'interruptOnCancel' property set for the given job. - executor.killTask(t.getValue.toLong, interruptThread = false) + executor.killTask( + t.getValue.toLong, interruptThread = false, reason = "killed by mesos") } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index ed29b346ba263..911a0857917ef 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} /** - * Cluster Manager for creation of Yarn scheduler and backend + * Cluster Manager for creation of Mesos scheduler and backend */ private[spark] class MesosClusterManager extends ExternalClusterManager { private val MESOS_REGEX = """mesos://(.*)""".r diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 2760f31b12fa7..1bc6f71860c3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -152,6 +152,7 @@ private[spark] class MesosClusterScheduler( // is registered with Mesos master. @volatile protected var ready = false private var masterInfo: Option[MasterInfo] = None + private var schedulerDriver: SchedulerDriver = _ def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { val c = new CreateSubmissionResponse @@ -168,9 +169,8 @@ private[spark] class MesosClusterScheduler( return c } c.submissionId = desc.submissionId - queuedDriversState.persist(desc.submissionId, desc) - queuedDrivers += desc c.success = true + addDriverToQueue(desc) } c } @@ -191,7 +191,7 @@ private[spark] class MesosClusterScheduler( // 4. Check if it has already completed. 
if (launchedDrivers.contains(submissionId)) { val task = launchedDrivers(submissionId) - mesosDriver.killTask(task.taskId) + schedulerDriver.killTask(task.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -324,7 +324,7 @@ private[spark] class MesosClusterScheduler( ready = false metricsSystem.report() metricsSystem.stop() - mesosDriver.stop(true) + schedulerDriver.stop(true) } override def registered( @@ -340,6 +340,8 @@ private[spark] class MesosClusterScheduler( stateLock.synchronized { this.masterInfo = Some(masterInfo) + this.schedulerDriver = driver + if (!pendingRecover.isEmpty) { // Start task reconciliation if we need to recover. val statuses = pendingRecover.collect { @@ -506,11 +508,10 @@ private[spark] class MesosClusterScheduler( } private class ResourceOffer( - val offerId: OfferID, - val slaveId: SlaveID, - var resources: JList[Resource]) { + val offer: Offer, + var remainingResources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offerId}, resources: ${resources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}" } } @@ -518,16 +519,16 @@ private[spark] class MesosClusterScheduler( val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.resources, "cpus", desc.cores) + partitionResources(offer.remainingResources, "cpus", desc.cores) val (finalResources, memResourcesToUse) = partitionResources(remainingResources.asJava, "mem", desc.mem) - offer.resources = finalResources.asJava + offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") - .setSlaveId(offer.slaveId) + .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) @@ -549,23 +550,29 @@ private[spark] class MesosClusterScheduler( val driverCpu = submission.cores val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") - val offerOption = currentOffers.find { o => - getResource(o.resources, "cpus") >= driverCpu && - getResource(o.resources, "mem") >= driverMem + val offerOption = currentOffers.find { offer => + getResource(offer.remainingResources, "cpus") >= driverCpu && + getResource(offer.remainingResources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) try { val task = createTaskInfo(submission, offer) queuedTasks += task - logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, task.getTaskId, offer.slaveId, - None, new Date(), None, getDriverFrameworkID(submission)) + val newState = new MesosClusterSubmissionState( + submission, + task.getTaskId, + offer.offer.getSlaveId, + None, + new Date(), + None, + getDriverFrameworkID(submission)) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, 
newState) afterLaunchCallback(submission.submissionId) @@ -588,7 +595,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList) }.toList stateLock.synchronized { @@ -615,8 +622,8 @@ private[spark] class MesosClusterScheduler( driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - for (o <- currentOffers if !tasks.contains(o.offerId)) { - driver.declineOffer(o.offerId) + for (offer <- currentOffers if !tasks.contains(offer.offer.getId)) { + declineOffer(driver, offer.offer, None, Some(getRejectOfferDuration(conf))) } } @@ -662,6 +669,12 @@ private[spark] class MesosClusterScheduler( override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { val taskId = status.getTaskId.getValue + + logInfo(s"Received status update: taskId=${taskId}" + + s" state=${status.getState}" + + s" message=${status.getMessage}" + + s" reason=${status.getReason}"); + stateLock.synchronized { if (launchedDrivers.contains(taskId)) { if (status.getReason == Reason.REASON_RECONCILIATION && @@ -682,8 +695,7 @@ private[spark] class MesosClusterScheduler( val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - pendingRetryDrivers += newDriverDescription - pendingRetryDriversState.persist(taskId, newDriverDescription) + addDriverToPending(newDriverDescription, taskId); } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { removeFromLaunchedDrivers(taskId) state.finishDate = Some(new Date()) @@ -746,4 +758,21 @@ private[spark] class MesosClusterScheduler( def getQueuedDriversSize: Int = queuedDrivers.size def getLaunchedDriversSize: Int = launchedDrivers.size def getPendingRetryDriversSize: Int = pendingRetryDrivers.size + + private def addDriverToQueue(desc: MesosDriverDescription): Unit = { + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + revive() + } + + private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { + pendingRetryDriversState.persist(taskId, desc) + pendingRetryDrivers += desc + revive() + } + + private def revive(): Unit = { + logInfo("Reviving Offers.") + schedulerDriver.reviveOffers() + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f69c223ab9b6d..8f5b97ccb1f85 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf @@ -59,13 +60,23 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) + private val executorCoresOption = conf.getOption("spark.executor.cores").map(_.toInt) + + private val minCoresPerExecutor = 
executorCoresOption.getOrElse(1) + // Maximum number of cores to acquire - private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) + private val maxCores = { + val cores = maxCoresOption.getOrElse(Int.MaxValue) + // Set maxCores to a multiple of smallest executor we can launch + cores - (cores % minCoresPerExecutor) + } private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) private val maxGpus = conf.getInt("spark.mesos.gpus.max", 0) + private val taskLabels = conf.get("spark.mesos.task.labels", "") + private[this] val shutdownTimeoutMS = conf.getTimeAsMs("spark.mesos.coarse.shutdownTimeout", "10s") .ensuring(_ >= 0, "spark.mesos.coarse.shutdownTimeout must be >= 0") @@ -119,11 +130,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) // Reject offers when we reached the maximum number of cores for this framework private val rejectOfferDurationForReachedMaxCores = - getRejectOfferDurationForReachedMaxCores(sc) + getRejectOfferDurationForReachedMaxCores(sc.conf) // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { @@ -146,6 +157,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( @volatile var appId: String = _ + private var schedulerDriver: SchedulerDriver = _ + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -172,11 +185,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val environment = Environment.newBuilder() - val extraClassPath = conf.getOption("spark.executor.extraClassPath") - extraClassPath.foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") // Set the environment variable through a command prefix @@ -252,9 +260,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - appId = frameworkId.getValue - mesosExternalShuffleClient.foreach(_.init(appId)) + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { + this.appId = frameworkId.getValue + this.mesosExternalShuffleClient.foreach(_.init(appId)) + this.schedulerDriver = driver markRegistered() } @@ -293,46 +304,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def declineUnmatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { offers.foreach { offer => - declineOffer(d, offer, Some("unmet constraints"), + declineOffer( + driver, + offer, + Some("unmet constraints"), Some(rejectOfferDurationForUnmetConstraints)) } } - private def declineOffer( - d: org.apache.mesos.SchedulerDriver, - offer: Offer, - reason: Option[String] = None, - refuseSeconds: Option[Long] = None): Unit = { - - val id = offer.getId.getValue - val offerAttributes = toAttributeMap(offer.getAttributesList) - val mem = getResource(offer.getResourcesList, "mem") - 
val cpus = getResource(offer.getResourcesList, "cpus") - val ports = getRangeResource(offer.getResourcesList, "ports") - - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" + - s" cpu: $cpus port: $ports for $refuseSeconds seconds" + - reason.map(r => s" (reason: $r)").getOrElse("")) - - refuseSeconds match { - case Some(seconds) => - val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() - d.declineOffer(offer.getId, filters) - case _ => d.declineOffer(offer.getId) - } - } - /** * Launches executors on accepted offers, and declines unused offers. Executors are launched * round-robin on offers. * - * @param d SchedulerDriver + * @param driver SchedulerDriver * @param offers Mesos offers that match attribute constraints */ private def handleMatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -358,15 +348,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" ports: $ports") } - d.launchTasks( + driver.launchTasks( Collections.singleton(offer.getId), offerTasks.asJava) } else if (totalCoresAcquired >= maxCores) { // Reject an offer for a configurable amount of time to avoid starving other frameworks - declineOffer(d, offer, Some("reached spark.cores.max"), + declineOffer(driver, + offer, + Some("reached spark.cores.max"), Some(rejectOfferDurationForReachedMaxCores)) } else { - declineOffer(d, offer) + declineOffer( + driver, + offer) } } } @@ -419,10 +413,18 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) - .setName("Task " + taskId) + .setName(s"${sc.appName} $taskId") + taskBuilder.addAllResources(resourcesToUse.asJava) taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) + val labelsBuilder = taskBuilder.getLabelsBuilder + val labels = buildMesosLabels().asJava + + labelsBuilder.addAllLabels(labels) + + taskBuilder.setLabels(labelsBuilder) + tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava totalCoresAcquired += taskCPUs @@ -437,6 +439,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( tasks.toMap } + private def buildMesosLabels(): List[Label] = { + taskLabels.split(",").flatMap(label => + label.split(":") match { + case Array(key, value) => + Some(Label.newBuilder() + .setKey(key) + .setValue(value) + .build()) + case _ => + logWarning(s"Unable to parse $label into a key:value label for the task.") + None + } + ).toList + } + /** Extracts task needed resources from a list of available resources. 
*/ private def partitionTaskResources( resources: JList[Resource], @@ -480,8 +497,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def executorCores(offerCPUs: Int): Int = { - sc.conf.getInt("spark.executor.cores", - math.min(offerCPUs, maxCores - totalCoresAcquired)) + executorCoresOption.getOrElse( + math.min(offerCPUs, maxCores - totalCoresAcquired) + ) } override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { @@ -582,8 +600,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Close the mesos external shuffle client if used mesosExternalShuffleClient.foreach(_.close()) - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } @@ -634,13 +652,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { - if (mesosDriver == null) { + if (schedulerDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false } else { for (executorId <- executorIds) { val taskId = TaskID.newBuilder().setValue(executorId).build() - mesosDriver.killTask(taskId) + schedulerDriver.killTask(taskId) } // no need to adjust `executorLimitOption` since the AllocationManager already communicated // the desired limit through a call to `doRequestTotalExecutors`. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 7e561916a71e2..735c879c63c55 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} @@ -65,7 +66,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( // reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) + + private var schedulerDriver: SchedulerDriver = _ @volatile var appId: String = _ @@ -89,6 +92,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * * @param availableResources Available resources that is offered by Mesos * @param execId The executor id to assign to this new executor. * @return A tuple of the new mesos executor info and the remaining available resources. 
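Taken together, the coarse-grained backend changes above introduce two user-facing settings: ``spark.mesos.task.labels``, which ``buildMesosLabels`` parses as comma-separated ``key:value`` pairs (entries that do not split into exactly one key and one value are logged and dropped), and executor sizing in which ``spark.cores.max`` is rounded down to a multiple of ``spark.executor.cores``. A minimal PySpark-side sketch of setting them; the values are illustrative and only the configuration keys come from the patch:

    from pyspark import SparkConf

    conf = (SparkConf()
            .set("spark.mesos.task.labels", "team:data,env:prod")  # attached to each Mesos task as labels
            .set("spark.executor.cores", "2")
            .set("spark.cores.max", "8"))                          # rounded down to a multiple of executor cores
    # A malformed entry such as "bad:label:here" is skipped with a warning instead of failing the task.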
@@ -102,10 +106,6 @@ private[spark] class MesosFineGrainedSchedulerBackend( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val environment = Environment.newBuilder() - sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => - environment.addVariables( - Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) - } val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => @@ -178,10 +178,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) + this.schedulerDriver = driver markRegistered() } } @@ -383,13 +386,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def stop() { - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } override def reviveOffers() { - mesosDriver.reviveOffers() + schedulerDriver.reviveOffers() } override def frameworkMessage( @@ -425,8 +428,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - mesosDriver.killTask( + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { + schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index a2adb228dc299..fbcbc55099ec5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Volume} +import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Volume} import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.spark.{SparkConf, SparkException} @@ -99,6 +99,28 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } + /** + * Parse a list of docker parameters, each of which + * takes the form key=value + */ + private def parseParamsSpec(params: String): List[Parameter] = { + // split with limit of 2 to avoid parsing error when '=' + // exists in the parameter value + params.split(",").map(_.split("=", 2)).flatMap { spec: Array[String] => + val param: Parameter.Builder = Parameter.newBuilder() + spec match { + case Array(key, value) => + Some(param.setKey(key).setValue(value)) + case spec => + logWarning(s"Unable to parse arbitrary parameters: $params. 
" + + "Expected form: \"key=value(, ...)\"") + None + } + } + .map { _.build() } + .toList + } + def containerInfo(conf: SparkConf): ContainerInfo = { val containerType = if (conf.contains("spark.mesos.executor.docker.image") && conf.get("spark.mesos.containerizer", "docker") == "docker") { @@ -120,8 +142,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .map(parsePortMappingsSpec) .getOrElse(List.empty) + val params = conf + .getOption("spark.mesos.executor.docker.parameters") + .map(parseParamsSpec) + .getOrElse(List.empty) + if (containerType == ContainerInfo.Type.DOCKER) { - containerInfo.setDocker(dockerInfo(image, forcePullImage, portMaps)) + containerInfo + .setDocker(dockerInfo(image, forcePullImage, portMaps, params)) } else { containerInfo.setMesos(mesosInfo(image, forcePullImage)) } @@ -144,11 +172,13 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { private def dockerInfo( image: String, forcePullImage: Boolean, - portMaps: List[ContainerInfo.DockerInfo.PortMapping]): DockerInfo = { + portMaps: List[ContainerInfo.DockerInfo.PortMapping], + params: List[Parameter]): DockerInfo = { val dockerBuilder = ContainerInfo.DockerInfo.newBuilder() .setImage(image) .setForcePullImage(forcePullImage) portMaps.foreach(dockerBuilder.addPortMappings(_)) + params.foreach(dockerBuilder.addParameters(_)) dockerBuilder.build } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 1d742fefbbacf..9d81025a3016b 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -46,9 +46,6 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) - // Driver for talking to Mesos - protected var mesosDriver: SchedulerDriver = null - /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. * @@ -115,10 +112,6 @@ trait MesosSchedulerUtils extends Logging { */ def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { - if (mesosDriver != null) { - registerLatch.await() - return - } @volatile var error: Option[Exception] = None @@ -128,8 +121,7 @@ trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { try { - mesosDriver = newDriver - val ret = mesosDriver.run() + val ret = newDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) @@ -247,7 +239,7 @@ trait MesosSchedulerUtils extends Logging { } /** - * Converts the attributes from the resource offer into a Map of name -> Attribute Value + * Converts the attributes from the resource offer into a Map of name to Attribute Value * The attribute values are the mesos attribute types and they are * * @param offerAttributes the attributes offered @@ -304,7 +296,7 @@ trait MesosSchedulerUtils extends Logging { /** * Parses the attributes constraints provided to spark and build a matching data struct: - * Map[, Set[values-to-match]] + * {@literal Map[, Set[values-to-match]} * The constraints are specified as ';' separated key-value pairs where keys and values * are separated by ':'. 
The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: @@ -362,7 +354,7 @@ trait MesosSchedulerUtils extends Logging { * container overheads. * * @param sc SparkContext to use to get `spark.mesos.executor.memoryOverhead` value - * @return memory requirement as (0.1 * ) or MEMORY_OVERHEAD_MINIMUM + * @return memory requirement as (0.1 * memoryOverhead) or MEMORY_OVERHEAD_MINIMUM * (whichever is larger) */ def executorMemory(sc: SparkContext): Int = { @@ -379,12 +371,24 @@ trait MesosSchedulerUtils extends Logging { } } - protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + private def getRejectOfferDurationStr(conf: SparkConf): String = { + conf.get("spark.mesos.rejectOfferDuration", "120s") + } + + protected def getRejectOfferDuration(conf: SparkConf): Long = { + Utils.timeStringAsSeconds(getRejectOfferDurationStr(conf)) + } + + protected def getRejectOfferDurationForUnmetConstraints(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForUnmetConstraints", + getRejectOfferDurationStr(conf)) } - protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s") + protected def getRejectOfferDurationForReachedMaxCores(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForReachedMaxCores", + getRejectOfferDurationStr(conf)) } /** @@ -438,6 +442,7 @@ trait MesosSchedulerUtils extends Logging { /** * The values of the non-zero ports to be used by the executor process. + * * @param conf the spark config to use * @return the ono-zero values of the ports */ @@ -521,4 +526,33 @@ trait MesosSchedulerUtils extends Logging { case TaskState.KILLED => MesosTaskState.TASK_KILLED case TaskState.LOST => MesosTaskState.TASK_LOST } + + protected def declineOffer( + driver: org.apache.mesos.SchedulerDriver, + offer: Offer, + reason: Option[String] = None, + refuseSeconds: Option[Long] = None): Unit = { + + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val ports = getRangeResource(offer.getResourcesList, "ports") + + logDebug(s"Declining offer: $id with " + + s"attributes: $offerAttributes " + + s"mem: $mem " + + s"cpu: $cpus " + + s"port: $ports " + + refuseSeconds.map(s => s"for ${s} seconds ").getOrElse("") + + reason.map(r => s" (reason: $r)").getOrElse("")) + + refuseSeconds match { + case Some(seconds) => + val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() + driver.declineOffer(offer.getId, filters) + case _ => + driver.declineOffer(offer.getId) + } + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index b9d098486b675..32967b04cd346 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -53,19 +53,32 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi override 
def start(): Unit = { ready = true } } scheduler.start() + scheduler.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) + } + + private def testDriverDescription(submissionId: String): MesosDriverDescription = { + new MesosDriverDescription( + "d1", + "jar", + 1000, + 1, + true, + command, + Map[String, String](), + submissionId, + new Date()) } test("can queue drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + verify(driver, times(1)).reviveOffers() + + val response2 = scheduler.submitDriver(testDriverDescription("s2")) assert(response2.success) + val state = scheduler.getSchedulerState() val queuedDrivers = state.queuedDrivers.toList assert(queuedDrivers(0).submissionId == response.submissionId) @@ -75,9 +88,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi test("can kill queued drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) val killResponse = scheduler.killDriver(response.submissionId) assert(killResponse.success) @@ -238,18 +249,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("can kill supervised drivers") { - val driver = mock[SchedulerDriver] val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") - scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { - ready = true - mesosDriver = driver - } - } - scheduler.start() + setScheduler(conf.getAll.toMap) val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, @@ -291,4 +294,16 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(state.launchedDrivers.isEmpty) assert(state.finishedDrivers.size == 1) } + + test("Declines offer with refuse seconds = 120.") { + setScheduler() + + val filter = Filters.newBuilder().setRefuseSeconds(120).build() + val offerId = OfferID.newBuilder().setValue("o1").build() + val offer = Utils.createOffer(offerId.getValue, "s1", 1000, 1) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + verify(driver, times(1)).declineOffer(offerId, filter) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 78346e9744957..0418bfbaa5ed8 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -199,6 +199,40 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o2"), true) } + test("mesos declines offers with a filter when 
maxCores not a multiple of executor.cores") { + val maxCores = 4 + val executorCores = 3 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyDeclinedOffer(driver, createOfferId("o2"), true) + } + + test("mesos declines offers with a filter when reached spark.cores.max with executor.cores") { + val maxCores = 4 + val executorCores = 2 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyTaskLaunched(driver, "o2") + verifyDeclinedOffer(driver, createOfferId("o3"), true) + } + test("mesos assigns tasks round-robin on offers") { val executorCores = 4 val maxCores = executorCores * 2 @@ -464,6 +498,63 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!uris.asScala.head.getCache) } + test("mesos sets task name to spark.app.name") { + setBackend() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + // Add " 0" to the taskName to match the executor number that is appended + assert(launchedTasks.head.getName == "test-mesos-dynamic-alloc 0") + } + + test("mesos sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + + test("mesos ignored invalid labels and sets configurable labels on tasks") { + val taskLabelsString = "mesos:test,label:test,incorrect:label:here" + setBackend(Map( + "spark.mesos.task.labels" -> taskLabelsString + )) + + // Build up the labels + val taskLabels = Protos.Labels.newBuilder() + .addLabels(Protos.Label.newBuilder() + .setKey("mesos").setValue("test").build()) + .addLabels(Protos.Label.newBuilder() + .setKey("label").setValue("test").build()) + .build() + + val offers = List(Resources(backend.executorMemory(sc), 1)) + offerResources(offers) + val launchedTasks = verifyTaskLaunched(driver, "o1") + + val labels = launchedTasks.head.getLabels + + assert(launchedTasks.head.getLabels.equals(taskLabels)) + } + test("mesos supports spark.mesos.network.name") { setBackend(Map( "spark.mesos.network.name" -> "test-network-name" @@ -552,17 +643,14 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient // override to avoid race condition with the driver thread on `mesosDriver` - override def startScheduler(newDriver: SchedulerDriver): Unit = { - mesosDriver = newDriver - } + override def 
startScheduler(newDriver: SchedulerDriver): Unit = {} override def stopExecutors(): Unit = { stopCalled = true } - - markRegistered() } backend.start() + backend.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) backend } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala new file mode 100644 index 0000000000000..caf9d89fdd201 --- /dev/null +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class MesosSchedulerBackendUtilSuite extends SparkFunSuite { + + test("ContainerInfo fails to parse invalid docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a,b") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + + assert(params.size() == 0) + } + + test("ContainerInfo parses docker parameters") { + val conf = new SparkConf() + conf.set("spark.mesos.executor.docker.parameters", "a=1,b=2,c=3") + conf.set("spark.mesos.executor.docker.image", "test") + + val containerInfo = MesosSchedulerBackendUtil.containerInfo(conf) + val params = containerInfo.getDocker.getParametersList + assert(params.size() == 3) + assert(params.get(0).getKey == "a") + assert(params.get(0).getValue == "1") + assert(params.get(1).getKey == "b") + assert(params.get(1).getValue == "2") + assert(params.get(2).getKey == "c") + assert(params.get(2).getValue == "3") + } +} diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 7ebb294aa9080..2a67cbc913ffe 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -28,6 +28,17 @@ import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ object Utils { + + val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() + .setValue("test-framework-id") + .build() + + val TEST_MASTER_INFO = MasterInfo.newBuilder() + .setId("test-master") + .setIp(0) + .setPort(0) + .build() + def createOffer( offerId: String, slaveId: String, diff --git 
a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index a1b641c8eeb84..71d4ad681e169 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e86bd5459311d..b817570c0abf7 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.util.{Properties, UUID} +import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -371,6 +371,9 @@ private[spark] class Client( val nearestTimeOfNextRenewal = credentialManager.obtainCredentials(hadoopConf, credentials) if (credentials != null) { + // Add credentials to current user's UGI, so that following operations don't need to use the + // Kerberos tgt to get delegations again in the client side. + UserGroupInformation.getCurrentUser.addCredentials(credentials) logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } @@ -529,7 +532,7 @@ private[spark] class Client( try { jarsStream.setLevel(0) jarsDir.listFiles().foreach { f => - if (f.isFile && f.getName.toLowerCase().endsWith(".jar") && f.canRead) { + if (f.isFile && f.getName.toLowerCase(Locale.ROOT).endsWith(".jar") && f.canRead) { jarsStream.putNextEntry(new ZipEntry(f.getName)) Files.copy(f, jarsStream) jarsStream.closeEntry() @@ -574,7 +577,7 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => flist.foreach { file => val (_, localizedPath) = distribute(file, resType = resType) - // If addToClassPath, we ignore adding jar multiple times to distitrbuted cache. + // If addToClassPath, we ignore adding jar multiple times to distributed cache. if (addToClasspath) { if (localizedPath != null) { cachedSecondaryJarLinks += localizedPath @@ -748,14 +751,6 @@ private[spark] class Client( .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - // Allow users to specify some environment variables. - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. - env("SPARK_YARN_USER_ENV") = userEnvs - } - // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH // of the container processes too. Add all non-.py files directly to PYTHONPATH. // @@ -782,35 +777,7 @@ private[spark] class Client( sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } - // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to - // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's - // SparkContext will not let that set spark* system properties, which is expected behavior for - // Yarn clients. So propagate it through the environment. 
- // - // Note that to warn the user about the deprecation in cluster mode, some code from - // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition - // described above). if (isClusterMode) { - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - """.stripMargin - logWarning(warning) - for (proc <- Seq("driver", "executor")) { - val key = s"spark.$proc.extraJavaOptions" - if (sparkConf.contains(key)) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } - } - env("SPARK_JAVA_OPTS") = value - } // propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname => if (!env.contains(envname)) { @@ -883,8 +850,7 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { - val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS")) - driverOpts.foreach { opts => + sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ee85c043b8bc0..3f4d236571ffd 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -143,9 +143,6 @@ private[yarn] class ExecutorRunnable( sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - sys.env.get("SPARK_JAVA_OPTS").foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) - } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } @@ -229,11 +226,6 @@ private[yarn] class ExecutorRunnable( YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) } - // Keep this for backwards compatibility but users should move to the config - sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala index f2b6324db619a..257dc83621e98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import 
org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.hadoop.yarn.util.RackResolver import org.apache.spark.SparkConf import org.apache.spark.internal.config._ @@ -83,7 +82,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack private[yarn] class LocalityPreferredContainerPlacementStrategy( val sparkConf: SparkConf, val yarnConf: Configuration, - val resource: Resource) { + val resource: Resource, + resolver: SparkRackResolver) { /** * Calculate each container's node locality and rack locality @@ -139,7 +139,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy( // still be allocated with new container request. val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray val racks = hosts.map { h => - RackResolver.resolve(yarnConf, h).getNetworkLocation + resolver.resolve(yarnConf, h) }.toSet containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala new file mode 100644 index 0000000000000..c711d088f2116 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + +/** + * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily override the + * default behavior, since YARN's class self-initializes the first time it's called, and + * future calls all use the initial configuration. + */ +private[yarn] class SparkRackResolver { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. 
+ if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + + def resolve(conf: Configuration, hostName: String): String = { + RackResolver.resolve(conf, hostName).getNetworkLocation() + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index abd2de75c6450..ed77a6e4a1c7c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -65,16 +64,12 @@ private[yarn] class YarnAllocator( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, - localResources: Map[String, LocalResource]) + localResources: Map[String, LocalResource], + resolver: SparkRackResolver) extends Logging { import YarnAllocator._ - // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. - if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { - Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) - } - // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] @@ -159,7 +154,7 @@ private[yarn] class YarnAllocator( // A container placement strategy based on pending tasks' locality preference private[yarn] val containerPlacementStrategy = - new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) /** * Use a different clock for YarnAllocator. This is mainly used for testing. 
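[Editor's note - illustrative sketch, not part of the patch] The SparkRackResolver wrapper introduced above exists so that tests can override rack resolution directly instead of wiring a DNSToSwitchMapping through Hadoop configuration (YARN's RackResolver self-initializes on first use, which made the old approach fragile). A minimal override might look like the following; it assumes the class is compiled in the org.apache.spark.deploy.yarn package (SparkRackResolver is private[yarn]), and FixedRackResolver/rackByHost are invented names for the sketch.

package org.apache.spark.deploy.yarn

import org.apache.hadoop.conf.Configuration

// Test-only resolver: returns a canned rack per host instead of consulting the
// cluster topology, falling back to "/default-rack" for unknown hosts.
class FixedRackResolver(rackByHost: Map[String, String]) extends SparkRackResolver {
  override def resolve(conf: Configuration, hostName: String): String =
    rackByHost.getOrElse(hostName, "/default-rack")
}

An instance would then be passed where the default new SparkRackResolver() is passed today, for example as the extra constructor argument YarnAllocator now takes; the MockResolver in YarnAllocatorSuite further down in this diff does essentially this.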
@@ -424,7 +419,7 @@ private[yarn] class YarnAllocator( // Match remaining by rack val remainingAfterRackMatches = new ArrayBuffer[Container] for (allocatedContainer <- remainingAfterHostMatches) { - val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation + val rack = resolver.resolve(conf, allocatedContainer.getNodeId.getHost) matchContainerToRequest(allocatedContainer, rack, containersToUse, remainingAfterRackMatches) } @@ -494,7 +489,8 @@ private[yarn] class YarnAllocator( val containerId = container.getId val executorId = executorIdCounter.toString assert(container.getResource.getMemory >= resource.getMemory) - logInfo(s"Launching container $containerId on host $executorHostname") + logInfo(s"Launching container $containerId on host $executorHostname " + + s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { numExecutorsRunning += 1 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 53fb467f6408d..72f4d273ab53b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -75,7 +75,7 @@ private[spark] class YarnRMClient extends Logging { registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, - localResources) + localResources, new SparkRackResolver()) } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 2fdb70a73c754..41b7b5d60b038 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -60,7 +60,7 @@ private[spark] class CredentialUpdater( if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.") + logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -81,8 +81,8 @@ private[spark] class CredentialUpdater( UserGroupInformation.getCurrentUser.addCredentials(newCredentials) logInfo("Credentials updated from credentials file.") - val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis() + val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) + - System.currentTimeMillis()) if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime } else { // If current credential file is older than expected, sleep 1 hour and check again. 
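[Editor's note - illustrative sketch, not part of the patch] The parentheses added around the remainingTime subtraction above matter: when the minus sign sits at the start of the continuation line, as it does in the new version, Scala's semicolon inference would otherwise treat "- System.currentTimeMillis()" as a separate unary-minus statement, leaving remainingTime equal to the raw next-update timestamp. The intended computation, written out standalone (delayUntilNextUpdateMs is an invented name; the real code obtains the timestamp from getTimeOfNextUpdateFromFileName):

import java.util.concurrent.TimeUnit

// Delay until the next credentials refresh: the difference between the scheduled update
// time and now, clamped to a one-minute retry when the update is already due or overdue.
def delayUntilNextUpdateMs(nextUpdateTimeMs: Long, nowMs: Long = System.currentTimeMillis()): Long = {
  val remaining = nextUpdateTimeMs - nowMs
  if (remaining <= 0) TimeUnit.MINUTES.toMillis(1) else remaining
}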
@@ -100,6 +100,7 @@ private[spark] class CredentialUpdater( TimeUnit.HOURS.toMillis(1) } + logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") credentialUpdater.schedule( credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala index 5571df09a2ec9..5adeb8e605ff4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils private[security] class HBaseCredentialProvider extends ServiceCredentialProvider with Logging { @@ -36,7 +37,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide sparkConf: SparkConf, creds: Credentials): Option[Long] = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val obtainToken = mirror.classLoader. loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). getMethod("obtainToken", classOf[Configuration]) @@ -60,7 +61,7 @@ private[security] class HBaseCredentialProvider extends ServiceCredentialProvide private def hbaseConf(conf: Configuration): Configuration = { try { - val mirror = universe.runtimeMirror(getClass.getClassLoader) + val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader) val confCreate = mirror.classLoader. loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). getMethod("create", classOf[Configuration]) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala index fb80ff9f31322..b7f25656e49ac 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.yarn +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet, Set} -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Mockito._ @@ -51,9 +50,6 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { private def runTest(): Unit = { val yarnConf = new YarnConfiguration() - yarnConf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) // The numbers below have been chosen to balance being large enough to replicate the // original issue while not taking too long to run when the issue is fixed. 
The main @@ -62,7 +58,7 @@ class LocalityPlacementStrategySuite extends SparkFunSuite { val resource = Resource.newInstance(8 * 1024, 4) val strategy = new LocalityPreferredContainerPlacementStrategy(new SparkConf(), - yarnConf, resource) + yarnConf, resource, new MockResolver()) val totalTasks = 32 * 1024 val totalContainers = totalTasks / 16 diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index fcc0594cf6d80..97b0e8aca3330 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -17,12 +17,9 @@ package org.apache.spark.deploy.yarn -import java.util.{Arrays, List => JList} - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.CommonConfigurationKeysPublic -import org.apache.hadoop.net.DNSToSwitchMapping +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest @@ -38,24 +35,16 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.ManualClock -class MockResolver extends DNSToSwitchMapping { +class MockResolver extends SparkRackResolver { - override def resolve(names: JList[String]): JList[String] = { - if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2") - else Arrays.asList("/rack1") + override def resolve(conf: Configuration, hostName: String): String = { + if (hostName == "host3") "/rack2" else "/rack1" } - override def reloadCachedMappings() {} - - def reloadCachedMappings(names: JList[String]) {} } class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new YarnConfiguration() - conf.setClass( - CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, - classOf[MockResolver], classOf[DNSToSwitchMapping]) - val sparkConf = new SparkConf() sparkConf.set("spark.driver.host", "localhost") sparkConf.set("spark.driver.port", "4040") @@ -111,7 +100,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient, appAttemptId, new SecurityManager(sparkConf), - Map()) + Map(), + new MockResolver()) } def createContainer(host: String): Container = { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 895a01e36f23f..37dc4148bbfab 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -24,6 +24,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} @@ -124,24 +125,30 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } - test("run Spark in yarn-client mode with different configurations") { + test("run Spark in yarn-client mode with different configurations, ensuring redaction") { testBasicYarnApp(true, Map( "spark.driver.memory" -> "512m", "spark.executor.cores" -> "1", 
"spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } - test("run Spark in yarn-cluster mode with different configurations") { + test("run Spark in yarn-cluster mode with different configurations, ensuring redaction") { testBasicYarnApp(false, Map( "spark.driver.memory" -> "512m", "spark.driver.cores" -> "1", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } @@ -443,6 +450,7 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 + val SECRET_PASSWORD = "secret_password" def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -489,6 +497,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(executorInfos.nonEmpty) executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { url => + val log = Source.fromURL(url).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Executor logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } } // If we are running in yarn-cluster mode, verify that driver logs links and present and are @@ -500,8 +515,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(driverLogs.contains("stderr")) assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") - // Ensure that this is a valid URL, else this will throw an exception - new URL(urlStr) + driverLogs.foreach { kv => + val log = Source.fromURL(kv._2).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 4079d9e40fc41..0a413b2c23de1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster +import scala.language.reflectiveCalls + import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9e..8d80f8eca5dba 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 59f93b3c469d5..14c511f670606 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -56,6 +56,10 @@ singleTableIdentifier : tableIdentifier EOF ; +singleFunctionIdentifier + : functionIdentifier EOF + ; + singleDataType : dataType EOF ; @@ -85,6 +89,8 @@ statement LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq)? #analyze + | ALTER TABLE tableIdentifier + ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable | ALTER (TABLE | VIEW) tableIdentifier @@ -198,7 +204,6 @@ unsupportedHiveNativeCommands | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT - | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=ADD kw4=COLUMNS | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS | kw1=START kw2=TRANSACTION | kw1=COMMIT @@ -492,6 +497,10 @@ tableIdentifier : (db=identifier '.')? table=identifier ; +functionIdentifier + : (db=identifier '.')? function=identifier + ; + namedExpression : expression (AS? (identifier | identifierList))? ; @@ -506,10 +515,10 @@ expression booleanExpression : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists | predicated #booleanDefault | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists ; // workaround for: @@ -525,6 +534,7 @@ predicate | NOT? kind=IN '(' query ')' | NOT? kind=(RLIKE | LIKE) pattern=valueExpression | IS NOT? kind=NULL + | IS NOT? kind=DISTINCT FROM right=valueExpression ; valueExpression @@ -543,12 +553,15 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' 
fieldName=identifier #dereference @@ -700,7 +713,7 @@ nonReserved | VIEW | REPLACE | IF | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK + | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT @@ -826,6 +839,7 @@ TRANSACTION: 'TRANSACTION'; COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; +IGNORE: 'IGNORE'; IF: 'IF'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java new file mode 100644 index 0000000000000..5f1032d1229da --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.plans.logical.*; + +/** + * Represents the type of timeouts possible for the Dataset operations + * `mapGroupsWithState` and `flatMapGroupsWithState`. See documentation on + * `GroupState` for more details. + * + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +public class GroupStateTimeout { + + /** + * Timeout based on processing time. The duration of timeout can be set for each group in + * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation + * on `GroupState` for more details. + */ + public static GroupStateTimeout ProcessingTimeTimeout() { + return ProcessingTimeTimeout$.MODULE$; + } + + /** + * Timeout based on event-time. The event-time timestamp for timeout can be set for each + * group in `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutTimestamp()`. + * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. + * When the watermark advances beyond the set timestamp of a group and the group has not + * received any data, then the group times out. See documentation on + * `GroupState` for more details. + */ + public static GroupStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + + /** No timeout. 
*/ + public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index ff8576157305b..50ee6cd4085ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -43,7 +43,7 @@ class AnalysisException protected[sql] ( } override def getMessage: String = { - val planAnnotation = plan.map(p => s";\n$p").getOrElse("") + val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") getSimpleMessage + planAnnotation } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala deleted file mode 100644 index 5f50ce1ba68ff..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import java.util.TimeZone - -import org.apache.spark.sql.catalyst.analysis._ - -/** - * Interface for configuration options used in the catalyst module. - */ -trait CatalystConf { - def caseSensitiveAnalysis: Boolean - - def orderByOrdinal: Boolean - def groupByOrdinal: Boolean - - def optimizerMaxIterations: Int - def optimizerInSetConversionThreshold: Int - def maxCaseBranchesForCodegen: Int - - def tableRelationCacheSize: Int - - def runSQLonFile: Boolean - - def warehousePath: String - - def sessionLocalTimeZone: String - - /** If true, cartesian products between relations will be allowed for all - * join types(inner, (left|right|full) outer). - * If false, cartesian products will require explicit CROSS JOIN syntax. - */ - def crossJoinEnabled: Boolean - - /** - * Returns the [[Resolver]] for the current configuration, which can be used to determine if two - * identifiers are equal. - */ - def resolver: Resolver = { - if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution - } - - /** - * Enables CBO for estimation of plan statistics when set true. - */ - def cboEnabled: Boolean -} - - -/** A CatalystConf that can be used for local testing. 
*/ -case class SimpleCatalystConf( - caseSensitiveAnalysis: Boolean, - orderByOrdinal: Boolean = true, - groupByOrdinal: Boolean = true, - optimizerMaxIterations: Int = 100, - optimizerInSetConversionThreshold: Int = 10, - maxCaseBranchesForCodegen: Int = 20, - tableRelationCacheSize: Int = 1000, - runSQLonFile: Boolean = true, - crossJoinEnabled: Boolean = false, - cboEnabled: Boolean = false, - warehousePath: String = "/user/hive/warehouse", - sessionLocalTimeZone: String = TimeZone.getDefault().getID) - extends CatalystConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5b9161551a7af..d4ebdb139fe0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,11 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - if (decimal.changePrecision(dataType.precision, dataType.scale)) { - decimal - } else { - null - } + decimal.toPrecision(dataType.precision, dataType.scale).orNull } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e9d9508e5adfe..86a73a319ec3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -69,7 +69,8 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty) + : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) @@ -104,26 +105,32 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) case _ if typeToken.isArray => - val (dataType, nullable) = inferDataType(typeToken.getComponentType) + val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) (ArrayType(dataType, nullable), true) case _ if iterableType.isAssignableFrom(typeToken) => - val (dataType, nullable) = inferDataType(elementType(typeToken)) + val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet) (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val (keyDataType, _) = inferDataType(keyType) - val (valueDataType, nullable) = inferDataType(valueType) + val (keyDataType, _) = inferDataType(keyType, seenTypeSet) + val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) (MapType(keyDataType, valueDataType, nullable), true) case other => + if (seenTypeSet.contains(other)) { + throw new UnsupportedOperationException( + "Cannot have circular references in bean class, but got the circular reference " + + s"of class $other") + } + // TODO: we should only collect properties that have getter and setter. 
However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(other) val fields = properties.map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType) + val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other) new StructField(property.getName, dataType, nullable) } (new StructType(fields), true) @@ -197,20 +204,19 @@ object JavaTypeInference { typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath - case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, ObjectType(c)) - case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, ObjectType(c)) + case c if c == classOf[java.lang.Short] || + c == classOf[java.lang.Integer] || + c == classOf[java.lang.Long] || + c == classOf[java.lang.Double] || + c == classOf[java.lang.Float] || + c == classOf[java.lang.Byte] || + c == classOf[java.lang.Boolean] => + StaticInvoke( + c, + ObjectType(c), + "valueOf", + getPath :: Nil, + propagateNull = true) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -336,7 +342,11 @@ object JavaTypeInference { */ def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) + serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { + case expressions.If(_, _, s: CreateNamedStruct) => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) + } } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { @@ -420,7 +430,7 @@ object JavaTypeInference { case other => val properties = getJavaBeanReadableAndWritableProperties(other) - CreateNamedStruct(properties.flatMap { p => + val nonNullOutput = CreateNamedStruct(properties.flatMap { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val fieldValue = Invoke( @@ -429,6 +439,9 @@ object JavaTypeInference { inferExternalType(fieldType.getRawType)) expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) + + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7f7dd51aa2650..82710a2a183ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. 
Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -132,7 +132,7 @@ object ScalaReflection extends ScalaReflection { def deserializerFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil deserializerFor(tpe, None, walkedTypePath) } @@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * don't need to cast struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it, thus we only need to handle leaf type. + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. */ def upCastToExpectedType( expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. 
case _ => UpCast(expr, expected, walkedTypePath) } @@ -202,37 +204,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, objectType) + StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, propagateNull = true) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( @@ -249,112 +251,83 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + returnNullable = false) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, elementNullable) = schemaFor(elementType) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - // TODO: add runtime 
null check for primitive array - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. + val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } } - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - Invoke( - MapObjects( - p => deserializerFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) + val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayCls = arrayClassFor(elementType) + + if (elementNullable) { + Invoke(arrayData, "array", arrayCls, returnNullable = false) + } else { + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) + } + Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) + val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - val mapFunction: Expression => Expression = p => { - val converter = deserializerFor(elementType, Some(p), newTypePath) - if (nullable) { + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expected. 
+ val casted = upCastToExpectedType(element, dataType, newTypePath) + val converter = deserializerFor(elementType, Some(casted), newTypePath) + if (elementNullable) { converter } else { AssertNotNull(converter, newTypePath) } } - val array = Invoke( - MapObjects(mapFunction, getPath, dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val wrappedArray = StaticInvoke( - scala.collection.mutable.WrappedArray.getClass, - ObjectType(classOf[Seq[_]]), - "make", - array :: Nil) - - if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) { - wrappedArray - } else { - // Convert to another type using `to` - val cls = mirror.runtimeClass(t.typeSymbol.asClass) - import scala.collection.generic.CanBuildFrom - import scala.reflect.ClassTag - - // Some canBuildFrom methods take an implicit ClassTag parameter - val cbfParams = try { - cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) - StaticInvoke( - ClassTag.getClass, - ObjectType(classOf[ClassTag[_]]), - "apply", - StaticInvoke( - cls, - ObjectType(classOf[Class[_]]), - "getClass" - ) :: Nil - ) :: Nil - } catch { - case _: NoSuchMethodException => Nil - } - - Invoke( - wrappedArray, - "to", - ObjectType(cls), - StaticInvoke( - cls, - ObjectType(classOf[CanBuildFrom[_, _, _]]), - "canBuildFrom", - cbfParams - ) :: Nil - ) + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + val cls = companion.declaration(newTermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) } + UnresolvedMapObjects(mapFunction, getPath, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map @@ -364,19 +337,21 @@ object ScalaReflection extends ScalaReflection { Invoke( MapObjects( p => deserializerFor(keyType, Some(p), walkedTypePath), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType), + returnNullable = false), schemaFor(keyType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) val valueData = Invoke( MapObjects( p => deserializerFor(valueType, Some(p), walkedTypePath), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType), + returnNullable = false), schemaFor(valueType).dataType), "array", - ObjectType(classOf[Array[Any]])) + ObjectType(classOf[Array[Any]]), returnNullable = false) StaticInvoke( ArrayBasedMapData.getClass, @@ -470,14 +445,15 @@ object ScalaReflection extends ScalaReflection { private def serializerFor( inputObject: Expression, tpe: `Type`, - walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + walkedTypePath: Seq[String], + seenTypeSet: Set[`Type`] = Set.empty): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { dataTypeFor(elementType) match { case dt: ObjectType => val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, dt) + MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt) case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => @@ -511,7 +487,7 @@ object ScalaReflection extends ScalaReflection { val className = getClassNameFromType(optType) val newPath = s"""- option 
value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) + serializerFor(unwrapped, optType, newPath, seenTypeSet) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -534,9 +510,9 @@ object ScalaReflection extends ScalaReflection { ExternalMapToCatalyst( inputObject, dataTypeFor(keyType), - serializerFor(_, keyType, keyPath), + serializerFor(_, keyType, keyPath, seenTypeSet), dataTypeFor(valueType), - serializerFor(_, valueType, valuePath), + serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) case t if t <:< localTypeOf[String] => @@ -622,6 +598,11 @@ object ScalaReflection extends ScalaReflection { Invoke(obj, "serialize", udt, inputObject :: Nil) case t if definedByConstructorParams(t) => + if (seenTypeSet.contains(t)) { + throw new UnsupportedOperationException( + s"cannot have circular references in class, but got the circular reference of class $t") + } + val params = getConstructorParameters(t) val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => if (javaKeywords.contains(fieldName)) { @@ -634,7 +615,8 @@ object ScalaReflection extends ScalaReflection { returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: + serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) @@ -643,7 +625,6 @@ object ScalaReflection extends ScalaReflection { throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } - } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index ec56fe7729c2a..57f7a80bedc6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -44,6 +44,3 @@ class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[Tabl class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") - -class TempFunctionAlreadyExistsException(func: String) - extends AnalysisException(s"Temporary function '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d569b612de7d..72e7d5dd3638d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -21,18 +21,20 @@ import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf, 
TableIdentifier} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -41,13 +43,13 @@ import org.apache.spark.sql.types._ * to resolve attribute references. */ object SimpleAnalyzer extends Analyzer( - new SessionCatalog( - new InMemoryCatalog, - EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) { - override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} - }, - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) { + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} + }, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns @@ -57,13 +59,12 @@ object SimpleAnalyzer extends Analyzer( * * @param defaultDatabase The default database used in the view resolution, this overrules the * current catalog database. - * @param nestedViewLevel The nested level in the view resolution, this enables us to limit the + * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the * depth of nested views. - * TODO Limit the depth of nested views. 
*/ case class AnalysisContext( defaultDatabase: Option[String] = None, - nestedViewLevel: Int = 0) + nestedViewDepth: Int = 0) object AnalysisContext { private val value = new ThreadLocal[AnalysisContext]() { @@ -76,7 +77,7 @@ object AnalysisContext { def withAnalysisContext[A](database: Option[String])(f: => A): A = { val originContext = value.get() val context = AnalysisContext(defaultDatabase = database, - nestedViewLevel = originContext.nestedViewLevel + 1) + nestedViewDepth = originContext.nestedViewDepth + 1) set(context) try f finally { set(originContext) } } @@ -89,11 +90,11 @@ object AnalysisContext { */ class Analyzer( catalog: SessionCatalog, - conf: CatalystConf, + conf: SQLConf, maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis { - def this(catalog: SessionCatalog, conf: CatalystConf) = { + def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } @@ -117,6 +118,8 @@ class Analyzer( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), ResolveHints.RemoveAllHints), + Batch("Simple Sanity Check", Once, + LookupFunctions), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -133,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -147,6 +151,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -158,8 +163,8 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveTimeZone", Once, - ResolveTimeZone), + Batch("Subquery", Once, + UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -168,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -196,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. 
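Stepping back to the AnalysisContext change at the top of this hunk: the renamed nestedViewDepth is carried in a thread-local holder that is bumped for the duration of a nested view resolution and restored afterwards. A minimal, self-contained sketch of that pattern follows; names mirror the hunk but this is illustrative only, not the patched class.

case class Ctx(defaultDatabase: Option[String] = None, nestedViewDepth: Int = 0)

object Ctx {
  private val value = new ThreadLocal[Ctx]() {
    override def initialValue: Ctx = Ctx()
  }

  def get: Ctx = value.get()

  // Run `f` with the depth incremented, always restoring the previous context afterwards.
  def withContext[A](database: Option[String])(f: => A): A = {
    val previous = value.get()
    value.set(Ctx(defaultDatabase = database, nestedViewDepth = previous.nestedViewDepth + 1))
    try f finally value.set(previous)
  }
}

// A resolver can then compare Ctx.get.nestedViewDepth against a configured maximum,
// which is what the view-resolution hunk below does via spark.sql.view.maxNestedViewDepth.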
case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -238,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -482,14 +487,16 @@ class Analyzer( case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { + val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") if (singleAgg) { - value.toString + stringValue } else { val suffix = aggregate match { case n: NamedExpression => n.name case _ => toPrettySQL(aggregate) } - value + "_" + suffix + stringValue + "_" + suffix } } if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { @@ -520,7 +527,7 @@ class Analyzer( } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { @@ -593,18 +600,28 @@ class Analyzer( case view @ View(desc, _, child) if !child.resolved => // Resolve all the UnresolvedRelations and Views in the child. val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) { + if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { + view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + + "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + + "aroud this.") + } execute(child) } view.copy(child = newChild) - case p @ SubqueryAlias(_, view: View, _) => + case p @ SubqueryAlias(_, view: View) => val newChild = resolveRelation(view) p.copy(child = newChild) case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) + EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { + case v: View => + u.failAnalysis(s"Inserting into a view is not allowed. 
View: ${v.desc.identifier}.") + case other => i.copy(table = other) + } case u: UnresolvedRelation => resolveRelation(u) } @@ -704,14 +721,73 @@ class Analyzer( } transformUp { case other => other transformExpressions { case a: Attribute => - attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier) + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } newRight } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier) + } + + /** + * The outer plan may have been de-duplicated and the function below updates the + * outer references to refer to the de-duplicated attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. + */ + private def dedupOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(dedupAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -769,11 +845,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -886,16 +961,16 @@ class Analyzer( * have no effect on the results. 
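As a concrete illustration of what ResolveOrdinalInOrderByAndGroupBy (defined just below) and the new ResolveAggAliasInGroupBy rule handle, here is a hedged SQL-level sketch. `spark` is an assumed SparkSession, the table and column names are invented, and the configuration keys are the usual SQLConf names assumed to back conf.orderByOrdinal / conf.groupByAliases.

// Ordinals are 1-based positions in the select list; with spark.sql.orderByOrdinal and
// spark.sql.groupByOrdinal at their defaults, GROUP BY 1 resolves to `dept` and ORDER BY 2 to `cnt`.
spark.sql("SELECT dept, count(*) AS cnt FROM employees GROUP BY 1 ORDER BY 2 DESC")

// With spark.sql.groupByAliases enabled, the alias `d` from the select list can be used as a
// grouping key; the ResolveAggAliasInGroupBy rule further down substitutes the aliased expression.
spark.sql("SELECT upper(dept) AS d, count(*) FROM employees GROUP BY d")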
*/ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering) + SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + @@ -907,17 +982,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + @@ -928,6 +997,27 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + // This is a strict check though, we put this to apply the rule only in alias expressions + def notResolvableByChild(attrName: String): Boolean = + !child.output.exists(a => resolver(a.name, attrName)) + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute if notResolvableByChild(u.name) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + case e => e + }) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -937,7 +1027,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1038,11 +1128,30 @@ class Analyzer( } } + /** + * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the + * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It + * only performs simple existence check according to the function identifier to quickly identify + * undefined functions without triggering relation resolution, which may incur potentially + * expensive partition/schema discovery process in some cases. + * + * @see [[ResolveFunctions]] + * @see https://issues.apache.org/jira/browse/SPARK-19737 + */ + object LookupFunctions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case f: UnresolvedFunction if !catalog.functionExists(f.name) => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1107,31 +1216,49 @@ class Analyzer( } /** - * Pull out all (outer) correlated predicates from a given subquery. This method removes the - * correlated predicates from subquery [[Filter]]s and adds the references of these predicates - * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to - * be able to evaluate the predicates at the top level. - * - * This method returns the rewritten subquery and correlated predicates. + * Validates to make sure the outer references appearing inside the subquery + * are legal. This function also returns the list of expressions + * that contain outer references. These outer references would be kept as children + * of subquery expressions by the caller of this function. */ - private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { - val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { + val outerReferences = ArrayBuffer.empty[Expression] + + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. + |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. 
+ """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } // Make sure a plan's subtree does not contain outer references def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (p.collectFirst(predicateMap).nonEmpty) { + if (hasOuterReferences(p)) { failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - // Helper function for locating outer references. - def containsOuter(e: Expression): Boolean = { - e.find(_.isInstanceOf[OuterReference]).isDefined - } - - // Make sure a plan's expressions do not contain outer references - def failOnOuterReference(p: LogicalPlan): Unit = { - if (p.expressions.exists(containsOuter)) { + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + s"clauses:\n$p") @@ -1169,20 +1296,11 @@ class Analyzer( } } - /** Determine which correlated predicate references are missing from this plan. */ - def missingReferences(p: LogicalPlan): AttributeSet = { - val localPredicateReferences = p.collect(predicateMap) - .flatten - .map(_.references) - .reduceOption(_ ++ _) - .getOrElse(AttributeSet.empty) - localPredicateReferences -- p.outputSet - } - var foundNonEqualCorrelatedPred : Boolean = false - // Simplify the predicates before pulling them out. - val transformed = BooleanSimplification(sub) transformUp { + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { // Whitelist operators allowed in a correlated subquery // There are 4 categories: @@ -1204,33 +1322,22 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case p: BroadcastHint => - p - case p: Distinct => - p - case p: LeafNode => - p - case p: Repartition => - p - case p: SubqueryAlias => - p + case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. // so long as they do not host outer references in the operators. - case p: Sort => - failOnOuterReference(p) - p - case p: RepartitionByExpression => - failOnOuterReference(p) - p + case s: Sort => + failOnInvalidOuterReference(s) + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) // Category 3: // Filter is one of the two operators allowed to host correlated expressions. // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f @ Filter(cond, child) => + case f: Filter => // Find all predicates with an outer reference. - val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) // Find any non-equality correlated predicates foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { @@ -1238,52 +1345,33 @@ class Analyzer( case _ => true } - // Rewrite the filter without the correlated predicates if any. 
- correlated match { - case Nil => f - case xs if local.nonEmpty => - val newFilter = Filter(local.reduce(And), child) - predicateMap += newFilter -> xs - newFilter - case xs => - predicateMap += child -> xs - child - } + failOnInvalidOuterReference(f) + // The aggregate expressions are treated in a special way by getOuterReferences. If the + // aggregate expression contains only outer reference attributes then the entire aggregate + // expression is isolated as an OuterReference. + // i.e min(OuterReference(b)) => OuterReference(min(b)) + outerReferences ++= getOuterReferences(correlated) // Project cannot host any correlated expressions // but can be anywhere in a correlated subquery. - case p @ Project(expressions, child) => - failOnOuterReference(p) - - val referencesToAdd = missingReferences(p) - if (referencesToAdd.nonEmpty) { - Project(expressions ++ referencesToAdd, child) - } else { - p - } + case p: Project => + failOnInvalidOuterReference(p) // Aggregate cannot host any correlated expressions // It can be on a correlation path if the correlation contains // only equality correlated predicates. // It cannot be on a correlation path if the correlation has // non-equality correlated predicates. - case a @ Aggregate(grouping, expressions, child) => - failOnOuterReference(a) + case a: Aggregate => + failOnInvalidOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - val referencesToAdd = missingReferences(a) - if (referencesToAdd.nonEmpty) { - Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) - } else { - a - } - // Join can host correlated expressions. case j @ Join(left, right, joinType, _) => joinType match { // Inner join, like Filter, can be anywhere. case _: InnerLike => - failOnOuterReference(j) + failOnInvalidOuterReference(j) // Left outer join's right operand cannot be on a correlation path. // LeftAnti and ExistenceJoin are special cases of LeftOuter. @@ -1294,12 +1382,12 @@ class Analyzer( // Any correlated references in the subplan // of the right operand cannot be pulled up. case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(right) // Likewise, Right outer join's left operand cannot be on a correlation path. case RightOuter => - failOnOuterReference(j) + failOnInvalidOuterReference(j) failOnOuterReferenceInSubTree(left) // Any other join types not explicitly listed above, @@ -1307,7 +1395,6 @@ class Analyzer( case _ => failOnOuterReferenceInSubTree(j) } - j // Generator with join=true, i.e., expressed with // LATERAL VIEW [OUTER], similar to inner join, @@ -1315,9 +1402,8 @@ class Analyzer( // but must not host any outer references. // Note: // Generator with join=false is treated as Category 4. - case p @ Generate(generator, true, _, _, _, _) => - failOnOuterReference(p) - p + case g: Generate if g.join => + failOnInvalidOuterReference(g) // Category 4: Any other operators not in the above 3 categories // cannot be on a correlation path, that is they are allowed only @@ -1325,54 +1411,17 @@ class Analyzer( // are not allowed to have any correlated expressions. case p => failOnOuterReferenceInSubTree(p) - p } - (transformed, predicateMap.values.flatten.toSeq) + outerReferences } /** - * Rewrite the subquery in a safe way by preventing that the subquery and the outer use the same - * attributes. 
- */ - private def rewriteSubQuery( - sub: LogicalPlan, - outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { - // Pull out the tagged predicates and rewrite the subquery in the process. - val (basePlan, baseConditions) = pullOutCorrelatedPredicates(sub) - - // Make sure the inner and the outer query attributes do not collide. - val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = basePlan.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = basePlan.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, basePlan) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (basePlan, baseConditions) - } - // Remove outer references from the correlated predicates. We wait with extracting - // these until collisions between the inner and outer query attributes have been - // solved. - val conditions = deDuplicatedConditions.map(_.transform { - case OuterReference(ref) => ref - }) - (plan, conditions) - } - - /** - * Resolve and rewrite a subquery. The subquery is resolved using its outer plans. This method + * Resolves the subquery. The subquery is resolved using its outer plans. This method * will resolve the subquery by alternating between the regular analyzer and by applying the * resolveOuterReferences rule. * - * All correlated conditions are pulled out of the subquery as soon as the subquery is resolved. + * Outer references from the correlated predicates are updated as children of + * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, @@ -1395,7 +1444,8 @@ class Analyzer( } } while (!current.resolved && !current.fastEquals(previous)) - // Step 2: Pull out the predicates if the plan is resolved. + // Step 2: If the subquery plan is fully resolved, pull the outer references and record + // them as children of SubqueryExpression. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only // needed for Scalar and IN subqueries. @@ -1403,41 +1453,44 @@ class Analyzer( failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + s"does not match the required number of columns ($requiredColumns)") } - // Pullout predicates and construct a new plan. - f.tupled(rewriteSubQuery(current, plans)) + // Validate the outer reference and record the outer references as children of + // subquery expression. + f(current, checkAndGetOuterReferences(current)) } else { e.withNewPlan(current) } } /** - * Resolve and rewrite all subqueries in a LogicalPlan. This method transforms IN and EXISTS - * expressions into PredicateSubquery expression once the are resolved. + * Resolves the subquery. Apart of resolving the subquery and outer references (if any) + * in the subquery plan, the children of subquery expression are updated to record the + * outer references. This is needed to make sure + * (1) The column(s) referred from the outer query are not pruned from the plan during + * optimization. + * (2) Any aggregate expression(s) that reference outer attributes are pushed down to + * outer plan to get evaluated. 
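The doc comment above describes recording outer references as children of the subquery expression instead of pulling correlated predicates out during analysis. A hedged illustration of the kind of query involved; `spark` is an assumed SparkSession and the relation and column names are invented.

// t1.c1 inside the subquery is an outer reference. With this change the analyzer only
// validates it and records it on the Exists expression; the actual decorrelation is
// deferred to the optimizer.
spark.sql("""
  SELECT * FROM t1
  WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.c1 = t1.c1)
""")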
*/ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) - case e @ Exists(sub, exprId) => - resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) - case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => + case e @ Exists(sub, _, exprId) if !sub.resolved => + resolveSubQuery(e, plans)(Exists(_, _, exprId)) + case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => // Get the left hand side expressions. - val expressions = e match { + val expressions = value match { case cns : CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - resolveSubQuery(l, plans, expressions.size) { (rewrite, conditions) => - // Construct the IN conditions. - val inConditions = expressions.zip(rewrite.output).map(EqualTo.tupled) - PredicateSubquery(rewrite, inConditions ++ conditions, nullAware = true, exprId) - } + val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + In(value, Seq(expr)) } } /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1452,7 +1505,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1478,7 +1531,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1650,7 +1703,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1708,7 +1761,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2025,7 +2078,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. 
*/ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2070,7 +2123,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2135,7 +2188,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2200,7 +2253,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2215,8 +2268,21 @@ class Analyzer( validateTopLevelTupleFields(deserializer, inputs) val resolved = resolveExpression( deserializer, LocalRelation(inputs), throws = true) - validateNestedTupleFields(resolved) - resolved + val result = resolved transformDown { + case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => + inputData.dataType match { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + expr + case other => + throw new AnalysisException("need an array field but got " + other.simpleString) + } + } + validateNestedTupleFields(result) + result } } @@ -2273,7 +2339,7 @@ class Analyzer( * constructed is an inner class. 
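Since the deserializer in ScalaReflection now routes Seq and Array fields through UnresolvedMapObjects, which the ResolveDeserializer hunk above turns into MapObjects once the input type is known, a small end-to-end sketch of the path being exercised may help; `spark` is an assumed SparkSession.

// Round-tripping a case class with collection fields goes through the
// UnresolvedMapObjects -> MapObjects resolution shown above.
case class Record(id: Int, tags: Seq[String], scores: Array[Double])

import spark.implicits._
val ds = spark.createDataset(Seq(Record(1, Seq("a", "b"), Array(1.0, 2.0))))
ds.collect()  // deserializes the rows back into Record instances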
*/ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2296,14 +2362,18 @@ class Analyzer( */ object ResolveUpCast extends Rule[LogicalPlan] { private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + val fromStr = from match { + case l: LambdaVariable => "array element" + case e => e.sql + } + throw new AnalysisException(s"Cannot up cast $fromStr from " + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + "You can either add an explicit cast to the input data or choose a higher precision " + "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2318,18 +2388,6 @@ class Analyzer( } } } - - /** - * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local - * time zone. - */ - object ResolveTimeZone extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - } - } } /** @@ -2338,7 +2396,7 @@ class Analyzer( */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child, _) => child + case SubqueryAlias(_, child) => child } } @@ -2369,7 +2427,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2437,7 +2495,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = @@ -2486,7 +2544,7 @@ object TimeWindowing extends Rule[LogicalPlan] { substitutedPlan.withNewChildren(expandedPlan :: Nil) } else if (windowExpressions.size > 1) { p.failAnalysis("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are not currently not supported.") + "of rows, therefore they are currently not supported.") } else { p // Return unchanged. Analyzer will throw exception later } @@ -2508,3 +2566,67 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { CreateNamedStruct(children.toList) } } + +/** + * The aggregate expressions from subquery referencing outer query block are pushed + * down to the outer query block for evaluation. This rule below updates such outer references + * as AttributeReference referring attributes from the parent/outer query block. 
+ * + * For example (SQL): + * {{{ + * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) + * }}} + * Plan before the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < min(outer(b#227))) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + * Plan after the rule. + * Project [a#226] + * +- Filter exists#245 [min(b#227)#249] + * : +- Project [1 AS 1#247] + * : +- Filter (d#238 < outer(min(b#227)#249)) <----- + * : +- SubqueryAlias r + * : +- Project [_1#234 AS c#237, _2#235 AS d#238] + * : +- LocalRelation [_1#234, _2#235] + * +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249] + * +- SubqueryAlias l + * +- Project [_1#223 AS a#226, _2#224 AS b#227] + * +- LocalRelation [_1#223, _2#224] + */ +object UpdateOuterReferences extends Rule[LogicalPlan] { + private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } + + private def updateOuterReferenceInSubquery( + plan: LogicalPlan, + refExprs: Seq[Expression]): LogicalPlan = { + plan transformAllExpressions { case e => + val outerAlias = + refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) + outerAlias match { + case Some(a: Alias) => OuterReference(a.toAttribute) + case _ => e + } + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + plan transform { + case f @ Filter(_, a: Aggregate) if f.resolved => + f transformExpressions { + case s: SubqueryExpression if s.children.nonEmpty => + // Collect the aliases from output of aggregate. + val outerAliases = a.aggregateExpressions collect { case a: Alias => a } + // Update the subquery plan to record the OuterReference to point to outer query plan. 
+ s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7529f9028498c..61797bc34dc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -44,6 +45,18 @@ trait CheckAnalysis extends PredicateHelper { }).length > 1 } + protected def hasMapType(dt: DataType): Boolean = { + dt.existsRecursively(_.isInstanceOf[MapType]) + } + + protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match { + case _: Intersect | _: Except | _: Distinct => + plan.output.find(a => hasMapType(a.dataType)) + case d: Deduplicate => + d.keys.find(a => hasMapType(a.dataType)) + case _ => None + } + private def checkLimitClause(limitExpr: Expression): Unit = { limitExpr match { case e if !e.foldable => failAnalysis( @@ -123,9 +136,6 @@ trait CheckAnalysis extends PredicateHelper { s"Scalar subquery must return only one column, but got ${query.output.size}") } else if (conditions.nonEmpty) { - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. @@ -141,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper { // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns // are not part of the correlated columns. val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) val correlatedCols = AttributeSet(subqueryColumns) val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates @@ -156,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper { // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. 
- subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } - - cleanQuery(p.child) + case p: Project => cleanQuery(p.child) case child => child } @@ -200,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case f @ Filter(condition, child) => - splitConjunctivePredicates(condition).foreach { - case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => - failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + - s" conditions: $e") - case e => - } + case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) => + failAnalysis("Null-aware predicate sub-queries cannot be used in nested " + + s"conditions: $condition") case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => failAnalysis( @@ -256,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -273,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => @@ -295,8 +298,11 @@ trait CheckAnalysis extends PredicateHelper { s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") } - case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => - failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => + p match { + case _: Filter => // Ok + case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + } case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) @@ -374,6 +380,14 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) + // TODO: although map type is not orderable, technically map type should be able to be + // used in equality comparison, remove this type check once we support it. 
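The TODO above and the case right below concern the new guard against map-typed columns in set operations: map values are not orderable, and equality comparison on them is not supported yet, so INTERSECT, EXCEPT, DISTINCT and dropDuplicates keys now fail analysis. A hedged sketch of the user-visible effect; `spark` is an assumed SparkSession.

import org.apache.spark.sql.functions.{lit, map}

val df = spark.range(3).withColumn("m", map(lit("k"), lit(1)))
// df.distinct()       // now fails analysis: "Cannot have map type columns in DataFrame ..."
// df.intersect(df)    // same error
df.select("id").distinct()  // fine: the map column is projected away first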
+ case o if mapColumnInSetOperation(o).isDefined => + val mapCol = mapColumnInSetOperation(o).get + failAnalysis("Cannot have map type columns in DataFrame which calls " + + s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + + "is " + mapCol.dataType.simpleString) + case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c9465f6b8def..e1d83a86f99dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -64,6 +64,8 @@ trait FunctionRegistry { /** Clear all registered functions. */ def clear(): Unit + /** Create a copy of this registry with identical functions as this registry. */ + override def clone(): FunctionRegistry = throw new CloneNotSupportedException() } class SimpleFunctionRegistry extends FunctionRegistry { @@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.clear() } - def copy(): SimpleFunctionRegistry = synchronized { + override def clone(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => registry.registerFunction(name, info, builder) @@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def clone(): FunctionRegistry = this } @@ -421,6 +424,10 @@ object FunctionRegistry { expression[BitwiseOr]("|"), expression[BitwiseXor]("^"), + // json + expression[StructsToJson]("to_json"), + expression[JsonToStructs]("from_json"), + // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), castAlias("tinyint", ByteType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 70438eb5912b8..df688fa0e58ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf +import java.util.Locale + import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf /** * Collection of rules related to hints. The only hint currently available is broadcast join hint. * - * Note that this is separatedly into two rules because in the future we might introduce new hint + * Note that this is separately into two rules because in the future we might introduce new hint * rules that have different ordering requirements from broadcast. */ object ResolveHints { @@ -43,7 +45,7 @@ object ResolveHints { * * This rule must happen before common table expressions. 
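For the ResolveBroadcastHints rule declared just below (hint names are now matched via Locale.ROOT, and a hint with no arguments broadcasts the whole hinted subtree), a usage sketch; `big` and `small` are assumed DataFrames sharing a join key `k`.

// SQL form: name the relations to broadcast. With no table names given in the hint,
// the entire hinted subtree is wrapped in a BroadcastHint instead.
big.createOrReplaceTempView("big")
small.createOrReplaceTempView("small")
spark.sql("SELECT /*+ BROADCAST(small) */ * FROM big JOIN small ON big.k = small.k")

// DataFrame form with the same intent:
import org.apache.spark.sql.functions.broadcast
big.join(broadcast(small), "k")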
*/ - class ResolveBroadcastHints(conf: CatalystConf) extends Rule[LogicalPlan] { + class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver @@ -83,8 +85,14 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => - applyBroadcastHint(h.child, h.parameters.toSet) + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + if (h.parameters.isEmpty) { + // If there is no table alias specified, turn the entire subtree into a BroadcastHint. + BroadcastHint(h.child) + } else { + // Otherwise, find within the subtree query plans that should be broadcasted. + applyBroadcastHint(h.child, h.parameters.toSet) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index d5b3ea8c37c66..f2df3e132629f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. 
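ResolveInlineTables (next hunk) now mixes in CastSupport, so the per-column casts it inserts pick up the session-local time zone rather than being patched with a time zone afterwards. A hedged sketch of the inline-table syntax the rule folds into a LocalRelation; `spark` is an assumed SparkSession.

// The VALUES rows are validated, cast to a common type per column, and evaluated
// eagerly into a LocalRelation; timestamp casts use the session time zone.
spark.sql("""
  SELECT * FROM VALUES
    (1, CAST('2017-01-01 00:00:00' AS TIMESTAMP)),
    (2, CAST('2017-06-01 12:00:00' AS TIMESTAMP))
  AS events(id, ts)
""")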
*/ -case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -98,12 +98,9 @@ case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { val castedExpr = if (e.dataType.sameType(targetType)) { e } else { - Cast(e, targetType) + cast(e, targetType) } - castedExpr.transform { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - }.eval() + castedExpr.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 6b3bb68538dd1..de6de24350f23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.{SparkConf, SparkContext} +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName) match { + builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index af0a565f73ae9..256b18771052a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType /** * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. 
*/ -class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { +class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { private def isIntLiteral(e: Expression) = e match { case Literal(_, IntegerType) => true case _ => false @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2c00957bd6afb..e1dd010d37a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -108,6 +108,28 @@ object TypeCoercion { case _ => None } + /** + * This function determines the target type of a comparison operator when one operand + * is a String and the other is not. It also handles when one op is a Date and the + * other is a Timestamp by making the target type to be String. + */ + val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + // We should cast all relative timestamp/date/string comparison into string comparisons + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true + case (StringType, DateType) => Some(StringType) + case (DateType, StringType) => Some(StringType) + case (StringType, TimestampType) => Some(StringType) + case (TimestampType, StringType) => Some(StringType) + case (TimestampType, DateType) => Some(StringType) + case (DateType, TimestampType) => Some(StringType) + case (StringType, NullType) => Some(StringType) + case (NullType, StringType) => Some(StringType) + case (l: StringType, r: AtomicType) if r != StringType => Some(r) + case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l, r) => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -305,6 +327,14 @@ object TypeCoercion { * Promotes strings that appear in arithmetic expressions. */ object PromoteStrings extends Rule[LogicalPlan] { + private def castExpr(expr: Expression, targetType: DataType): Expression = { + (expr.dataType, targetType) match { + case (NullType, dt) => Literal.create(null, targetType) + case (l, dt) if (l != dt) => Cast(expr, targetType) + case _ => expr + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -321,37 +351,10 @@ object TypeCoercion { case p @ Equality(left @ TimestampType(), right @ StringType()) => p.makeCopy(Array(left, Cast(right, TimestampType))) - // We should cast all relative timestamp/date/string comparison into string comparisons - // This behaves as a user would expect because timestamp strings sort lexicographically. - // i.e. TimeStamp(2013-01-01 00:00 ...) 
< "2014" = true - case p @ BinaryComparison(left @ StringType(), right @ DateType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(left, Cast(right, StringType))) - case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(Cast(left, StringType), right)) - - // Comparisons between dates and timestamps. - case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) - - // Checking NullType - case p @ BinaryComparison(left @ StringType(), right @ NullType()) => - p.makeCopy(Array(left, Literal.create(null, StringType))) - case p @ BinaryComparison(left @ NullType(), right @ StringType()) => - p.makeCopy(Array(Literal.create(null, StringType), right)) - - // When compare string with atomic type, case string to that type. - case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) - if right.dataType != StringType => - p.makeCopy(Array(Cast(left, right.dataType), right)) - case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) - if left.dataType != StringType => - p.makeCopy(Array(left, Cast(right, left.dataType))) + case p @ BinaryComparison(left, right) + if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) @@ -365,17 +368,72 @@ object TypeCoercion { } /** - * Convert the value and in list expressions to the common operator type - * by looking at all the argument types and finding the closest one that - * all the arguments can be cast to. When no common operator type is found - * the original expression will be returned and an Analysis Exception will - * be raised at type checking phase. + * Handles type coercion for both IN expression with subquery and IN + * expressions without subquery. + * 1. In the first case, find the common type by comparing the left hand side (LHS) + * expression types against corresponding right hand side (RHS) expression derived + * from the subquery expression's plan output. Inject appropriate casts in the + * LHS and RHS side of IN expression. + * + * 2. In the second case, convert the value and in list expressions to the + * common operator type by looking at all the argument types and finding + * the closest one that all the arguments can be cast to. When no common + * operator type is found the original expression will be returned and an + * Analysis Exception will be raised at the type checking phase. */ object InConversion extends Rule[LogicalPlan] { + private def flattenExpr(expr: Expression): Seq[Expression] = { + expr match { + // Multi columns in IN clause is represented as a CreateNamedStruct. + // flatten the named struct to get the list of expressions. 
+ case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Handle type casting required between value expression and subquery output + // in IN subquery. + case i @ In(a, Seq(ListQuery(sub, children, exprId))) + if !i.resolved && flattenExpr(a).length == sub.output.length => + // LHS is the value expression of IN subquery. + val lhs = flattenExpr(a) + + // RHS is the subquery output. + val rhs = sub.output + + val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => + findCommonTypeForBinaryComparison(l.dataType, r.dataType) + .orElse(findTightestCommonType(l.dataType, r.dataType)) + } + + // The number of columns/expressions must match between LHS and RHS of an + // IN subquery expression. + if (commonTypes.length == lhs.length) { + val castedRhs = rhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() + case (e, _) => e + } + val castedLhs = lhs.zip(commonTypes).map { + case (e, dt) if e.dataType != dt => Cast(e, dt) + case (e, _) => e + } + + // Before constructing the In expression, wrap the multiple values in LHS + // in a CreateNamedStruct. + val newLhs = castedLhs match { + case Seq(lhs) => lhs + case _ => CreateStruct(castedLhs) + } + + In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) + } else { + i + } + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) @@ -513,6 +571,7 @@ object TypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 397f5cfe2a540..6ab4153bac70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -51,6 +51,37 @@ object UnsupportedOperationChecker { subplan.collect { case a: Aggregate if a.isStreaming => a } } + val mapGroupsWithStates = plan.collect { + case f: FlatMapGroupsWithState if f.isStreaming && f.isMapGroupsWithState => f + } + + // Disallow multiple `mapGroupsWithState`s. + if (mapGroupsWithStates.size >= 2) { + throwError( + "Multiple mapGroupsWithStates are not supported on a streaming DataFrames/Datasets")(plan) + } + + val flatMapGroupsWithStates = plan.collect { + case f: FlatMapGroupsWithState if f.isStreaming && !f.isMapGroupsWithState => f + } + + // Disallow mixing `mapGroupsWithState`s and `flatMapGroupsWithState`s + if (mapGroupsWithStates.nonEmpty && flatMapGroupsWithStates.nonEmpty) { + throwError( + "Mixing mapGroupsWithStates and flatMapGroupsWithStates are not supported on a " + + "streaming DataFrames/Datasets")(plan) + } + + // Only allow multiple `FlatMapGroupsWithState(Append)`s in append mode.
+ if (flatMapGroupsWithStates.size >= 2 && ( + outputMode != InternalOutputModes.Append || + flatMapGroupsWithStates.exists(_.outputMode != InternalOutputModes.Append) + )) { + throwError( + "Multiple flatMapGroupsWithStates are not supported when they are not all in append mode" + + " or the output mode is not append on a streaming DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -108,17 +139,76 @@ object UnsupportedOperationChecker { } throwErrorIf( child.isStreaming && distinctAggExprs.nonEmpty, - "Distinct aggregations are not supported on streaming DataFrames/Datasets, unless " + - "it is on aggregated DataFrame/Dataset in Complete output mode. Consider using " + - "approximate distinct aggregation (e.g. approx_count_distinct() instead of count()).") + "Distinct aggregations are not supported on streaming DataFrames/Datasets. Consider " + + "using approx_count_distinct() instead.") case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") - case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty => - throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + - "streaming DataFrame/Dataset") + // mapGroupsWithState and flatMapGroupsWithState + case m: FlatMapGroupsWithState if m.isStreaming => + + // Check compatibility with output modes and aggregations in query + val aggsAfterFlatMapGroups = collectStreamingAggregates(plan) + + if (m.isMapGroupsWithState) { // check mapGroupsWithState + // allowed only in update query output mode and without aggregation + if (aggsAfterFlatMapGroups.nonEmpty) { + throwError( + "mapGroupsWithState is not supported with aggregation " + + "on a streaming DataFrame/Dataset") + } else if (outputMode != InternalOutputModes.Update) { + throwError( + "mapGroupsWithState is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + } else { // check latMapGroupsWithState + if (aggsAfterFlatMapGroups.isEmpty) { + // flatMapGroupsWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "flatMapGroupsWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // flatMapGroupsWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "flatMapGroupsWithState in append mode is not supported after " + + s"aggregation on a streaming DataFrame/Dataset") + } + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. 
+ val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "[map|flatMap]GroupsWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + @@ -177,7 +267,7 @@ object UnsupportedOperationChecker { throwError("Limits are not supported on streaming DataFrames/Datasets") case Sort(_, _, _) if !containsCompleteData(subPlan) => - throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala new file mode 100644 index 0000000000000..a27aa845bf0ae --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local + * time zone. + */ +case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. 
+ case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressions(transformTimeZoneExprs) + + def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) +} + +/** + * Mix-in trait for constructing valid [[Cast]] expressions. + */ +trait CastSupport { + /** + * Configuration used to create a valid cast expression. + */ + def conf: SQLConf + + /** + * Create a Cast expression with the session local time zone. + */ + def cast(child: Expression, dataType: DataType): Cast = { + Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index a5640a6c967a1..ea46dd7282401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** * This file defines analysis rules related to views. @@ -47,7 +47,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. */ -case class AliasViewChild(conf: CatalystConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver @@ -78,7 +78,7 @@ case class AliasViewChild(conf: CatalystConf) extends Rule[LogicalPlan] { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") } else { - Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, + Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) } case (_, originAttr) => originAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 31eded4deba7d..974ef900e2eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression - +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). 
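// The CastSupport trait added above lets an analyzer rule build Casts that already carry
// conf.sessionLocalTimeZone instead of an empty timeZoneId, so rules such as AliasViewChild and
// ResolveInlineTables no longer need a separate time-zone fix-up pass. A minimal, self-contained
// Scala sketch of that mix-in pattern follows; SimpleConf, SimpleCast and ExampleRule are
// illustrative stand-ins, not Spark classes.
object CastSupportSketch {
  final case class SimpleConf(sessionLocalTimeZone: String)
  final case class SimpleCast(child: String, dataType: String, timeZoneId: Option[String])

  // Mirrors CastSupport: implementors only expose `conf`; every cast they build
  // picks up the session-local time zone automatically.
  trait SimpleCastSupport {
    def conf: SimpleConf
    def cast(child: String, dataType: String): SimpleCast =
      SimpleCast(child, dataType, Option(conf.sessionLocalTimeZone))
  }

  final class ExampleRule(val conf: SimpleConf) extends SimpleCastSupport

  def main(args: Array[String]): Unit = {
    val rule = new ExampleRule(SimpleConf("America/Los_Angeles"))
    // Prints: SimpleCast(viewCol,timestamp,Some(America/Los_Angeles))
    println(rule.cast("viewCol", "timestamp"))
  }
}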
@@ -30,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog { +abstract class ExternalCatalog + extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { @@ -61,9 +63,22 @@ abstract class ExternalCatalog { // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + doCreateDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + doDropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -88,11 +103,39 @@ abstract class ExternalCatalog { // Tables // -------------------------------------------------------------------------- - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(CreateTablePreEvent(db, name)) + doCreateTable(tableDefinition, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + + final def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + doDropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + protected def doDropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit - def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit + final def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + doRenameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } - def renameTable(db: String, oldName: String, newName: String): Unit + protected def doRenameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -104,6 +147,19 @@ abstract class ExternalCatalog { */ def alterTable(tableDefinition: CatalogTable): Unit + /** + * Alter the schema of a table identified by the provided database and table name. The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). 
+ * + * @param db Database that table to alter schema for exists in + * @param table Name of table to alter schema for + * @param schema Updated schema to be used for the table (must contain existing partition and + * bucket columns) + */ + def alterTableSchema(db: String, table: String, schema: StructType): Unit + def getTable(db: String, table: String): CatalogTable def getTableOption(db: String, table: String): Option[CatalogTable] @@ -256,11 +312,30 @@ abstract class ExternalCatalog { // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: CatalogFunction): Unit + final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + doCreateFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit - def dropFunction(db: String, funcName: String): Unit + final def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + doDropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } - def renameFunction(db: String, oldName: String, newName: String): Unit + protected def doDropFunction(db: String, funcName: String): Unit + + final def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + doRenameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + protected def doRenameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction @@ -268,4 +343,9 @@ abstract class ExternalCatalog { def listFunctions(db: String, pattern: String): Seq[String] + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 58ced549bafe9..1fc3a654cfeb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI +import java.util.Locale + import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate} object ExternalCatalogUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't @@ -116,13 +120,45 @@ object ExternalCatalogUtils { } def getPartitionPathString(col: String, value: String): String = { - val partitionString = if (value == null) { + val partitionString = if (value == null || value.isEmpty) { DEFAULT_PARTITION_NAME } else { escapePathName(value) } escapePathName(col) + "=" + partitionString } + + def prunePartitionsByFilter( + catalogTable: CatalogTable, + inputPartitions: Seq[CatalogTablePartition], + 
predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + if (predicates.isEmpty) { + inputPartitions + } else { + val partitionSchema = catalogTable.partitionSchema + val partitionColumnNames = catalogTable.partitionColumnNames.toSet + + val nonPartitionPruningPredicates = predicates.filterNot { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + if (nonPartitionPruningPredicates.nonEmpty) { + throw new AnalysisException("Expected only partition pruning predicates: " + + nonPartitionPruningPredicates) + } + + val boundPredicate = + InterpretedPredicate.create(predicates.reduce(And).transform { + case att: AttributeReference => + val index = partitionSchema.indexWhere(_.name == att.name) + BoundReference(index, partitionSchema(index).dataType, nullable = true) + }) + + inputPartitions.filter { p => + boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId)) + } + } + } } object CatalogUtils { @@ -132,8 +168,10 @@ object CatalogUtils { */ def maskCredentials(options: Map[String, String]): Map[String, String] = { options.map { - case (key, _) if key.toLowerCase == "password" => (key, "###") - case (key, value) if key.toLowerCase == "url" && value.toLowerCase.contains("password") => + case (key, _) if key.toLowerCase(Locale.ROOT) == "password" => (key, "###") + case (key, value) + if key.toLowerCase(Locale.ROOT) == "url" && + value.toLowerCase(Locale.ROOT).contains("password") => (key, "###") case o => o } @@ -162,6 +200,30 @@ object CatalogUtils { BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols) } + /** + * Convert URI to String. + * Since URI.toString does not decode the uri, e.g. change '%25' to '%'. + * Here we create a hadoop Path with the given URI, and rely on Path.toString + * to decode the uri + * @param uri the URI of the path + * @return the String of the path + */ + def URIToString(uri: URI): String = { + new Path(uri).toString + } + + /** + * Convert String to URI. + * Since new URI(string) does not encode string, e.g. change '%' to '%25'. + * Here we create a hadoop Path with the given String, and rely on Path.toUri + * to encode the string + * @param str the String of the path + * @return the URI of the path + */ + def stringToURI(str: String): URI = { + new Path(str).toUri + } + private def normalizeColumnName( tableName: String, tableCols: Seq[String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 340e8451f14ee..81dd8efc0015f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,9 +28,10 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types.StructType /** * An in-memory (ephemeral) implementation of the system catalog. 
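// ExternalCatalog above now wraps each mutating operation in a final method that posts a "pre"
// event, calls a protected do* hook, and then posts a "post" event; the InMemoryCatalog hunks
// below simply rename their implementations to those do* hooks. A self-contained Scala sketch of
// that wrapper pattern (EventfulCatalog and the event names here are illustrative stand-ins, not
// Spark classes):
object CatalogEventSketch {
  sealed trait Event
  final case class CreateDatabasePre(db: String) extends Event
  final case class CreateDatabaseDone(db: String) extends Event

  abstract class EventfulCatalog {
    private val listeners = scala.collection.mutable.ArrayBuffer.empty[Event => Unit]
    def addListener(listener: Event => Unit): Unit = listeners += listener
    private def postToAll(event: Event): Unit = listeners.foreach(_(event))

    // Final so every implementation fires the same events around the real work.
    final def createDatabase(db: String): Unit = {
      postToAll(CreateDatabasePre(db))
      doCreateDatabase(db)               // implementation-specific work
      postToAll(CreateDatabaseDone(db))  // only reached if the hook did not throw
    }
    protected def doCreateDatabase(db: String): Unit
  }

  final class InMemory extends EventfulCatalog {
    private val databases = scala.collection.mutable.Set.empty[String]
    override protected def doCreateDatabase(db: String): Unit = databases += db
  }

  def main(args: Array[String]): Unit = {
    val catalog = new InMemory
    catalog.addListener(event => println(s"catalog event: $event"))
    catalog.createDatabase("sales")  // prints CreateDatabasePre(sales) then CreateDatabaseDone(sales)
  }
}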
@@ -97,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -118,7 +119,7 @@ class InMemoryCatalog( } } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -126,7 +127,7 @@ class InMemoryCatalog( if (!cascade) { // If cascade is false, make sure the database is empty. if (catalog(db).tables.nonEmpty) { - throw new AnalysisException(s"Database '$db' is not empty. One or more tables exist.") + throw new AnalysisException(s"Database $db is not empty. One or more tables exist.") } if (catalog(db).functions.nonEmpty) { throw new AnalysisException(s"Database '$db' is not empty. One or more functions exist.") @@ -179,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -202,7 +203,7 @@ class InMemoryCatalog( tableDefinition.storage.locationUri.isEmpty val tableWithLocation = if (needDefaultTableLocation) { - val defaultTableLocation = new Path(catalog(db).db.locationUri, table) + val defaultTableLocation = new Path(new Path(catalog(db).db.locationUri), table) try { val fs = defaultTableLocation.getFileSystem(hadoopConfig) fs.mkdirs(defaultTableLocation) @@ -211,7 +212,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to create table $table as failed " + s"to create its directory $defaultTableLocation", e) } - tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri.toString)) + tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri)) } else { tableDefinition } @@ -220,7 +221,7 @@ class InMemoryCatalog( } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -263,7 +264,10 @@ class InMemoryCatalog( } } - override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireTableExists(db, oldName) requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) @@ -274,7 +278,7 @@ class InMemoryCatalog( "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") val oldDir = new Path(oldDesc.table.location) - val newDir = new Path(catalog(db).db.locationUri, newName) + val newDir = new Path(new Path(catalog(db).db.locationUri), newName) try { val fs = oldDir.getFileSystem(hadoopConfig) fs.rename(oldDir, newDir) @@ -283,7 +287,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to rename table $oldName to $newName as failed " + s"to rename its directory $oldDir", e) } - oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri.toString)) + oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri)) } catalog(db).tables.put(newName, oldDesc) @@ -297,6 +301,15 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition } + override def 
alterTableSchema( + db: String, + table: String, + schema: StructType): Unit = synchronized { + requireTableExists(db, table) + val origTable = catalog(db).tables(table).table + catalog(db).tables(table).table = origTable.copy(schema = schema) + } + override def getTable(db: String, table: String): CatalogTable = synchronized { requireTableExists(db, table) catalog(db).tables(table).table @@ -389,7 +402,7 @@ class InMemoryCatalog( existingParts.put( p.spec, - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri)))) } } @@ -462,7 +475,7 @@ class InMemoryCatalog( } oldPartition.copy( spec = newSpec, - storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toUri))) } else { oldPartition.copy(spec = newSpec) } @@ -546,27 +559,30 @@ class InMemoryCatalog( table: String, predicates: Seq[Expression], defaultTimeZoneId: String): Seq[CatalogTablePartition] = { - // TODO: Provide an implementation - throw new UnsupportedOperationException( - "listPartitionsByFilter is not implemented for InMemoryCatalog") + val catalogTable = getTable(db, table) + val allPartitions = listPartitions(db, table) + prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override def dropFunction(db: String, funcName: String): Unit = synchronized { + override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f6412e42c13d5..6c6d600190b66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI +import java.util.Locale import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.util.{Failure, Success, Try} import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -34,6 +37,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import 
org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StructField, StructType} object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -47,13 +52,13 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - externalCatalog: ExternalCatalog, + val externalCatalog: ExternalCatalog, globalTempViewManager: GlobalTempViewManager, - functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: CatalystConf, + conf: SQLConf, hadoopConf: Configuration, - parser: ParserInterface) extends Logging { + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends Logging { import SessionCatalog._ import CatalogTypes.TablePartitionSpec @@ -61,20 +66,23 @@ class SessionCatalog( def this( externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: SQLConf) { this( externalCatalog, new GlobalTempViewManager("global_temp"), - DummyFunctionResourceLoader, functionRegistry, conf, new Configuration(), - CatalystSqlParser) + CatalystSqlParser, + DummyFunctionResourceLoader) } // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) + this( + externalCatalog, + new SimpleFunctionRegistry, + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } /** List of temporary tables, mapping from table name to their logical plan. */ @@ -86,7 +94,7 @@ class SessionCatalog( // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") - protected var currentDb = formatDatabaseName(DEFAULT_DATABASE) + protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), @@ -107,14 +115,14 @@ class SessionCatalog( * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** @@ -131,10 +139,10 @@ class SessionCatalog( * does not contain a scheme, this path will not be changed after the default * FileSystem is changed. 
*/ - private def makeQualifiedPath(path: String): Path = { + private def makeQualifiedPath(path: URI): URI = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(hadoopConf) - fs.makeQualified(hadoopPath) + fs.makeQualified(hadoopPath).toUri } private def requireDbExists(db: String): Unit = { @@ -156,6 +164,20 @@ class SessionCatalog( throw new TableAlreadyExistsException(db = db, table = name.table) } } + + private def checkDuplication(fields: Seq[StructField]): Unit = { + val columnNames = if (conf.caseSensitiveAnalysis) { + fields.map(_.name) + } else { + fields.map(_.name.toLowerCase) + } + if (columnNames.distinct.length != columnNames.length) { + val duplicateColumns = columnNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => x + } + throw new AnalysisException(s"Found duplicate column(s): ${duplicateColumns.mkString(", ")}") + } + } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -170,7 +192,7 @@ class SessionCatalog( "you cannot create a database with this name.") } validateName(dbName) - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri) externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -228,9 +250,9 @@ class SessionCatalog( * Get the path for creating a non-default database when database location is not provided * by users. */ - def getDefaultDBPath(db: String): String = { + def getDefaultDBPath(db: String): URI = { val database = formatDatabaseName(db) - new Path(new Path(conf.warehousePath), database + ".db").toString + new Path(new Path(conf.warehousePath), database + ".db").toUri } // ---------------------------------------------------------------------------- @@ -254,7 +276,19 @@ class SessionCatalog( val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) validateName(table) - val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + + val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined + && !tableDefinition.storage.locationUri.get.isAbsolute) { + // make the location of the table qualified. + val qualifiedTableLocation = + makeQualifiedPath(tableDefinition.storage.locationUri.get) + tableDefinition.copy( + storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), + identifier = TableIdentifier(table, Some(db))) + } else { + tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + } + requireDbExists(db) externalCatalog.createTable(newTableDefinition, ignoreIfExists) } @@ -278,6 +312,47 @@ class SessionCatalog( externalCatalog.alterTable(newTableDefinition) } + /** + * Alter the schema of a table identified by the provided table identifier. The new schema + * should still contain the existing bucket columns and partition columns used by the table. This + * method will also update any Spark SQL-related parameters stored as Hive table properties (such + * as the schema itself). 
+ * + * @param identifier TableIdentifier + * @param newSchema Updated schema to be used for the table (must contain existing partition and + * bucket columns, and partition columns need to be at the end) + */ + def alterTableSchema( + identifier: TableIdentifier, + newSchema: StructType): Unit = { + val db = formatDatabaseName(identifier.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) + requireDbExists(db) + requireTableExists(tableIdentifier) + checkDuplication(newSchema) + + val catalogTable = externalCatalog.getTable(db, table) + val oldSchema = catalogTable.schema + + // not supporting dropping columns yet + val nonExistentColumnNames = oldSchema.map(_.name).filterNot(columnNameResolved(newSchema, _)) + if (nonExistentColumnNames.nonEmpty) { + throw new AnalysisException( + s""" + |Some existing schema fields (${nonExistentColumnNames.mkString("[", ",", "]")}) are + |not present in the new schema. We don't support dropping columns yet. + """.stripMargin) + } + + // assuming the newSchema has all partition columns at the end as required + externalCatalog.alterTableSchema(db, table, newSchema) + } + + private def columnNameResolved(schema: StructType, colName: String): Boolean = { + schema.fields.map(_.name).exists(conf.resolver(_, colName)) + } + /** * Return whether a table/view with the specified name exists. If no database is specified, check * with current database. @@ -351,11 +426,11 @@ class SessionCatalog( db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } - def defaultTablePath(tableIdent: TableIdentifier): String = { + def defaultTablePath(tableIdent: TableIdentifier): URI = { val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) val dbLocation = getDatabaseMetadata(dbName).locationUri - new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString + new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri } // ---------------------------------------------- @@ -577,7 +652,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef, None) + SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -590,17 +665,17 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) + SubqueryAlias(table, child) } else { val tableRelation = CatalogRelation( metadata, // we assume all the columns are nullable. metadata.dataSchema.asNullable.toAttributes, metadata.partitionSchema.asNullable.toAttributes) - SubqueryAlias(table, tableRelation, None) + SubqueryAlias(table, tableRelation) } } else { - SubqueryAlias(table, tempTables(table), None) + SubqueryAlias(table, tempTables(table)) } } } @@ -976,7 +1051,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. 
*/ - def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -990,18 +1065,20 @@ class SessionCatalog( } /** - * Create a temporary function. - * This assumes no database is specified in `funcDefinition`. + * Registers a temporary or permanent function into a session-specific [[FunctionRegistry]] */ - def createTempFunction( - name: String, - info: ExpressionInfo, - funcDefinition: FunctionBuilder, - ignoreIfExists: Boolean): Unit = { - if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new TempFunctionAlreadyExistsException(name) + def registerFunction( + funcDefinition: CatalogFunction, + ignoreIfExists: Boolean, + functionBuilder: Option[FunctionBuilder] = None): Unit = { + val func = funcDefinition.identifier + if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) { + throw new AnalysisException(s"Function $func already exists") } - functionRegistry.registerFunction(name, info, funcDefinition) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + val builder = + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionRegistry.registerFunction(func.unquotedString, info, builder) } /** @@ -1025,7 +1102,7 @@ class SessionCatalog( name.database.isEmpty && functionRegistry.functionExists(name.funcName) && !FunctionRegistry.builtin.functionExists(name.funcName) && - !hiveFunctions.contains(name.funcName.toLowerCase) + !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } protected def failFunctionLookup(name: String): Nothing = { @@ -1106,12 +1183,7 @@ class SessionCatalog( // catalog. So, it is possible that qualifiedName is not exactly the same as // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). // At here, we preserve the input from the user. - val info = new ExpressionInfo( - catalogFunction.className, - qualifiedName.database.orNull, - qualifiedName.funcName) - val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className) - createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false) + registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(qualifiedName.unquotedString, children) } @@ -1131,15 +1203,25 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { val dbName = formatDatabaseName(db) requireDbExists(dbName) - val dbFunctions = externalCatalog.listFunctions(dbName, pattern) - .map { f => FunctionIdentifier(f, Some(dbName)) } - val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) - .map { f => FunctionIdentifier(f) } + val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f => + FunctionIdentifier(f, Some(dbName)) } + val loadedFunctions = + StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f => + // In functionRegistry, function names are stored as an unquoted format. 
+ Try(parser.parseFunctionIdentifier(f)) match { + case Success(e) => e + case Failure(_) => + // The names of some built-in functions are not parsable by our parser, e.g., % + FunctionIdentifier(f) + } + } val functions = dbFunctions ++ loadedFunctions + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. functions.map { case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") case f => (f, "USER") - } + }.distinct } @@ -1155,6 +1237,7 @@ class SessionCatalog( */ def reset(): Unit = synchronized { setCurrentDatabase(DEFAULT_DATABASE) + externalCatalog.setCurrentDatabase(DEFAULT_DATABASE) listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } @@ -1181,4 +1264,17 @@ class SessionCatalog( } } + /** + * Copy the current state of the catalog to another catalog. + * + * This function is synchronized on this [[SessionCatalog]] (the source) to make sure the copied + * state is consistent. The target [[SessionCatalog]] is not synchronized, and should not be + * because the target [[SessionCatalog]] should not be published at this point. The caller must + * synchronize on the target if this assumption does not hold. + */ + private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { + target.currentDb = currentDb + // copy over temporary tables + tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala new file mode 100644 index 0000000000000..459973a13bb10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * Event emitted by the external catalog when it is modified. Events are either fired before or + * after the modification (the event should document this). + */ +trait ExternalCatalogEvent extends SparkListenerEvent + +/** + * Listener interface for external catalog modification events. + */ +trait ExternalCatalogEventListener { + def onEvent(event: ExternalCatalogEvent): Unit +} + +/** + * Event fired when a database is create or dropped. + */ +trait DatabaseEvent extends ExternalCatalogEvent { + /** + * Database of the object that was touched. + */ + val database: String +} + +/** + * Event fired before a database is created. + */ +case class CreateDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been created. 
+ */ +case class CreateDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired before a database is dropped. + */ +case class DropDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been dropped. + */ +case class DropDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired when a table is created, dropped or renamed. + */ +trait TableEvent extends DatabaseEvent { + /** + * Name of the table that was touched. + */ + val name: String +} + +/** + * Event fired before a table is created. + */ +case class CreateTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been created. + */ +case class CreateTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is dropped. + */ +case class DropTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been dropped. + */ +case class DropTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is renamed. + */ +case class RenameTablePreEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired after a table has been renamed. + */ +case class RenameTableEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired when a function is created, dropped or renamed. + */ +trait FunctionEvent extends DatabaseEvent { + /** + * Name of the function that was touched. + */ + val name: String +} + +/** + * Event fired before a function is created. + */ +case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been created. + */ +case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is dropped. + */ +case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been dropped. + */ +case class DropFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is renamed. + */ +case class RenameFunctionPreEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent + +/** + * Event fired after a function has been renamed. + */ +case class RenameFunctionEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala index 8e46b962ff432..67bf2d06c95dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.util.Locale + import org.apache.spark.sql.AnalysisException /** A trait that represents the type of a resourced needed by a function. 
*/ @@ -33,7 +35,7 @@ object ArchiveResource extends FunctionResourceType("archive") object FunctionResourceType { def fromString(resourceType: String): FunctionResourceType = { - resourceType.toLowerCase match { + resourceType.toLowerCase(Locale.ROOT) match { case "jar" => JarResource case "file" => FileResource case "archive" => ArchiveResource diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 887caf07d1481..cc0cbba275b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI import java.util.Date +import scala.collection.mutable + import com.google.common.base.Objects import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -48,10 +52,7 @@ case class CatalogFunction( * Storage format, used to describe how a partition or a table is stored. */ case class CatalogStorageFormat( - // TODO(ekl) consider storing this field as java.net.URI for type safety. Note that this must - // be converted to/from a hadoop Path object using new Path(new URI(locationUri)) and - // path.toUri respectively before use as a filesystem path due to URI char escaping. 
- locationUri: Option[String], + locationUri: Option[URI], inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], @@ -59,20 +60,25 @@ case class CatalogStorageFormat( properties: Map[String, String]) { override def toString: String = { - val serdePropsToString = CatalogUtils.maskCredentials(properties) match { - case props if props.isEmpty => "" - case props => "Properties: " + props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") - } - val output = - Seq(locationUri.map("Location: " + _).getOrElse(""), - inputFormat.map("InputFormat: " + _).getOrElse(""), - outputFormat.map("OutputFormat: " + _).getOrElse(""), - if (compressed) "Compressed" else "", - serde.map("Serde: " + _).getOrElse(""), - serdePropsToString) - output.filter(_.nonEmpty).mkString("Storage(", ", ", ")") + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("Storage(", ", ", ")") } + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() + locationUri.foreach(l => map.put("Location", l.toString)) + serde.foreach(map.put("Serde Library", _)) + inputFormat.foreach(map.put("InputFormat", _)) + outputFormat.foreach(map.put("OutputFormat", _)) + if (compressed) map.put("Compressed", "") + CatalogUtils.maskCredentials(properties) match { + case props if props.isEmpty => // No-op + case props => + map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) + } + map + } } object CatalogStorageFormat { @@ -93,19 +99,32 @@ case class CatalogTablePartition( storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty) { - override def toString: String = { + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") - val output = - Seq( - s"Partition Values: [$specString]", - s"$storage", - s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + map.put("Partition Values", s"[$specString]") + map ++= storage.toLinkedHashMap + if (parameters.nonEmpty) { + map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + } + map + } - output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogPartition(\n\t", "\n\t", ")") + } + + /** Readable string representation for the CatalogTablePartition. */ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } /** Return the partition location, assuming it is specified. 
*/ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") throw new AnalysisException(s"Partition [$specString] did not specify locationUri") } @@ -115,9 +134,15 @@ case class CatalogTablePartition( */ def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) - val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) + val timeZoneId = caseInsensitiveProperties.getOrElse( + DateTimeUtils.TIMEZONE_OPTION, defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() + val partValue = if (spec(field.name) == ExternalCatalogUtils.DEFAULT_PARTITION_NAME) { + null + } else { + spec(field.name) + } + Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval() }) } } @@ -150,6 +175,14 @@ case class BucketSpec( } s"$numBuckets buckets, $bucketString$sortString" } + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + mutable.LinkedHashMap[String, String]( + "Num Buckets" -> numBuckets.toString, + "Bucket Columns" -> bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]"), + "Sort Columns" -> sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") + ) + } } /** @@ -165,6 +198,11 @@ case class BucketSpec( * @param tracksPartitionsInCatalog whether this table's partition metadata is stored in the * catalog. If false, it is inferred automatically based on file * structure. + * @param schemaPreservesCase Whether or not the schema resolved for this table is case-sensitive. + * When using a Hive Metastore, this flag is set to false if a case- + * sensitive schema was unable to be read from the table properties. + * Used to trigger case-sensitive schema inference at query time, when + * configured. */ case class CatalogTable( identifier: TableIdentifier, @@ -182,7 +220,8 @@ case class CatalogTable( viewText: Option[String] = None, comment: Option[String] = None, unsupportedFeatures: Seq[String] = Seq.empty, - tracksPartitionsInCatalog: Boolean = false) { + tracksPartitionsInCatalog: Boolean = false, + schemaPreservesCase: Boolean = true) { import CatalogTable._ @@ -210,7 +249,7 @@ case class CatalogTable( } /** Return the table location, assuming it is specified. */ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { throw new AnalysisException(s"table $identifier did not specify locationUri") } @@ -241,7 +280,7 @@ case class CatalogTable( /** Syntactic sugar to update a field in `storage`. 
*/ def withNewStorage( - locationUri: Option[String] = storage.locationUri, + locationUri: Option[URI] = storage.locationUri, inputFormat: Option[String] = storage.inputFormat, outputFormat: Option[String] = storage.outputFormat, compressed: Boolean = false, @@ -251,40 +290,50 @@ case class CatalogTable( locationUri, inputFormat, outputFormat, serde, compressed, properties)) } - override def toString: String = { + + def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { + val map = new mutable.LinkedHashMap[String, String]() val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") val partitionColumns = partitionColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val bucketStrings = bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - val bucketColumnsString = bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - val sortColumnsString = sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]") - Seq( - s"Num Buckets: $numBuckets", - if (bucketColumnNames.nonEmpty) s"Bucket Columns: $bucketColumnsString" else "", - if (sortColumnNames.nonEmpty) s"Sort Columns: $sortColumnsString" else "" - ) - - case _ => Nil + + identifier.database.foreach(map.put("Database", _)) + map.put("Table", identifier.table) + if (owner.nonEmpty) map.put("Owner", owner) + map.put("Created", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) + map.put("Type", tableType.name) + provider.foreach(map.put("Provider", _)) + bucketSpec.foreach(map ++= _.toLinkedHashMap) + comment.foreach(map.put("Comment", _)) + if (tableType == CatalogTableType.VIEW) { + viewText.foreach(map.put("View Text", _)) + viewDefaultDatabase.foreach(map.put("View Default Database", _)) + if (viewQueryColumnNames.nonEmpty) { + map.put("View Query Output Columns", viewQueryColumnNames.mkString("[", ", ", "]")) + } } - val output = - Seq(s"Table: ${identifier.quotedString}", - if (owner.nonEmpty) s"Owner: $owner" else "", - s"Created: ${new Date(createTime).toString}", - s"Last Access: ${new Date(lastAccessTime).toString}", - s"Type: ${tableType.name}", - if (schema.nonEmpty) s"Schema: ${schema.mkString("[", ", ", "]")}" else "", - if (provider.isDefined) s"Provider: ${provider.get}" else "", - if (partitionColumnNames.nonEmpty) s"Partition Columns: $partitionColumns" else "" - ) ++ bucketStrings ++ Seq( - viewText.map("View: " + _).getOrElse(""), - comment.map("Comment: " + _).getOrElse(""), - if (properties.nonEmpty) s"Properties: $tableProperties" else "", - if (stats.isDefined) s"Statistics: ${stats.get.simpleString}" else "", - s"$storage", - if (tracksPartitionsInCatalog) "Partition Provider: Catalog" else "") - - output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") + if (properties.nonEmpty) map.put("Properties", tableProperties) + stats.foreach(s => map.put("Statistics", s.simpleString)) + map ++= storage.toLinkedHashMap + if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog") + if (partitionColumnNames.nonEmpty) map.put("Partition Columns", partitionColumns) + if (schema.nonEmpty) map.put("Schema", schema.treeString) + + map + } + + override def toString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("CatalogTable(\n", "\n", ")") + } + + /** Readable string representation for the CatalogTable. 
*/ + def simpleString: String = { + toLinkedHashMap.map { case ((key, value)) => + if (value.isEmpty) key else s"$key: $value" + }.mkString("", "\n", "") } } @@ -337,7 +386,7 @@ object CatalogTableType { case class CatalogDatabase( name: String, description: String, - locationUri: String, + locationUri: URI, properties: Map[String, String]) @@ -354,14 +403,14 @@ object CatalogTypes { */ case class CatalogRelation( tableMeta: CatalogTable, - dataCols: Seq[Attribute], - partitionCols: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { + dataCols: Seq[AttributeReference], + partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation { assert(tableMeta.identifier.database.isDefined) assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) // The partition column should always appear after data columns. - override def output: Seq[Attribute] = dataCols ++ partitionCols + override def output: Seq[AttributeReference] = dataCols ++ partitionCols def isPartitioned: Boolean = partitionCols.nonEmpty @@ -374,10 +423,17 @@ case class CatalogRelation( Objects.hashCode(tableMeta.identifier, output) } - /** Only compare table identifier. */ - override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier) - - override def computeStats(conf: CatalystConf): Statistics = { + override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable( + identifier = tableMeta.identifier, + tableType = tableMeta.tableType, + storage = CatalogStorageFormat.empty, + schema = tableMeta.schema, + partitionColumnNames = tableMeta.partitionColumnNames, + bucketSpec = tableMeta.bucketSpec, + createTime = -1 + )) + + override def computeStats(conf: SQLConf): Statistics = { // For data source tables, we will create a `LogicalRelation` and won't call this method, for // hive serde tables, we will always generate a statistics. // TODO: unify the table stats generation. 
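The `toLinkedHashMap` refactoring above funnels the fields of `CatalogStorageFormat`, `CatalogTablePartition` and `CatalogTable` through a single insertion-ordered map, and `toString` and the human-readable `simpleString` are just two renderings of that map. A minimal, self-contained sketch of the rendering pattern (plain Scala; the field names and values below are illustrative, not the exact catalog fields):

import scala.collection.mutable

object RenderSketch {
  // Insertion order is preserved, so fields print in the order they were added.
  // An empty value means "print the key alone", mirroring flags such as Compressed.
  val fields: mutable.LinkedHashMap[String, String] = mutable.LinkedHashMap(
    "Database" -> "default",
    "Table" -> "t",
    "Compressed" -> "",
    "Location" -> "file:/tmp/t")

  def simpleString: String = fields.map { case (key, value) =>
    if (value.isEmpty) key else s"$key: $value"
  }.mkString("", "\n", "")

  def main(args: Array[String]): Unit = print(simpleString)
}

Keeping the formatting in one place means the one-line `toString` and the multi-line `simpleString` render the same fields and cannot drift apart.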
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c062e4e84bcdd..75bf780d41424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,9 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } @@ -346,7 +346,7 @@ package object dsl { orderSpec: Seq[SortOrder]): LogicalPlan = Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) - def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan, None) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) @@ -368,7 +368,10 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None) + def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) + + def coalesce(num: Integer): LogicalPlan = + Repartition(num, shuffle = false, logicalPlan) def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0782143d465b3..ec003cdc17b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. - val mirror = typeTag[T].mirror - val tpe = typeTag[T].tpe + val mirror = ScalaReflection.mirror + val tpe = typeTag[T].in(mirror).tpe if (ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( @@ -229,9 +229,9 @@ case class ExpressionEncoder[T]( // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here - // (intermediate value is not an attribute). We assume that all serializer expressions use a same - // `BoundReference` to refer to the object, and throw exception if they don't. - assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.") + // (intermediate value is not an attribute). We assume that all serializer expressions use the + // same `BoundReference` to refer to the object, and throw exception if they don't. 
+ assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.") assert(serializer.flatMap { ser => val boundRefs = ser.collect { case b: BoundReference => b } assert(boundRefs.nonEmpty, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e95e97b9dc6cb..0f8282d3b2f1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -89,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false) case TimestampType => StaticInvoke( @@ -136,16 +136,18 @@ object RowEncoder { case t @ MapType(kt, vt, valueNullable) => val keys = Invoke( - Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( - Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]), + returnNullable = false), "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) + ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false) val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( @@ -262,17 +264,18 @@ object RowEncoder { input :: Nil) case _: DecimalType => - Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + returnNullable = false) case StringType => - Invoke(input, "toString", ObjectType(classOf[String])) + Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false) case ArrayType(et, nullable) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), "array", - ObjectType(classOf[Array[_]])) + ObjectType(classOf[Array[_]]), returnNullable = false) StaticInvoke( scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a36d3507d92ec..a53ef426f79b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - +import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast { @@ -89,6 +89,31 @@ object Cast { case _ => false } + /** + * Return true if we need to use the `timeZone` information casting `from` type to `to` type. + * The patterns matched reflect the current implementation in the Cast node. + * c.f. 
usage of `timeZone` in: + * * Cast.castToString + * * Cast.castToDate + * * Cast.castToTimestamp + */ + def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + case (TimestampType, StringType) => true + case (TimestampType, DateType) => true + case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).exists { + case (fromField, toField) => + needsTimeZone(fromField.dataType, toField.dataType) + } + case _ => false + } + /** * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int, * timestamp -> date. @@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -277,9 +309,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -293,9 +324,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -309,8 +339,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toShort catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toShort(result)) { + result.value.toShort + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -325,8 +358,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toByte catch { - case _: NumberFormatException => null + val result = new IntWrapper() 
+ buildCast[UTF8String](_, s => if (s.toByte(result)) { + result.value.toByte + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -348,6 +384,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } + /** + * Create new `Decimal` with precision and scale given in `decimalType` (if any), + * returning null if it overflows or creating a new `value` and returning it if successful. + * + */ + private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = + value.toPrecision(decimalType.precision, decimalType.scale).orNull + + private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { @@ -356,14 +401,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) + buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => - b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + b => toPrecision(b.asInstanceOf[Decimal], target) case t: IntegralType => b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) case x: FractionalType => @@ -449,35 +494,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String }) } - private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == from => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case CalendarIntervalType => castToInterval(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) - case udt: UserDefinedType[_] - if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => - identity[Any] - case _: UserDefinedType[_] => - throw new SparkException(s"Cannot cast $from to $to.") + private[this] def cast(from: DataType, to: DataType): Any => Any = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the codegen path. 
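+ // For example, casting struct<a:int> to struct<b:int> only renames a field, so (assuming
+ // `equalsStructurally` treats types with the same shape but different field names as equal)
+ // the child's value can be passed through unchanged.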
+ if (DataType.equalsStructurally(from, to)) { + identity + } else { + to match { + case dt if dt == from => identity[Any] + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case CalendarIntervalType => castToInterval(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => + castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") + } + } } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) + override def genCode(ctx: CodegenContext): ExprCode = { + // If the cast does not change the structure, then we don't really need to cast anything. + // We can return what the children return. Same thing should happen in the interpreted path. + if (DataType.equalsStructurally(child.dataType, dataType)) { + child.genCode(ctx) + } else { + super.genCode(ctx) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) @@ -503,11 +567,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) - case ByteType => castToByteCode(from) - case ShortType => castToShortCode(from) - case IntegerType => castToIntCode(from) + case ByteType => castToByteCode(from, ctx) + case ShortType => castToShortCode(from, ctx) + case IntegerType => castToIntCode(from, ctx) case FloatType => castToFloatCode(from) - case LongType => castToLongCode(from) + case LongType => castToLongCode(from, ctx) case DoubleType => castToDoubleCode(from) case array: ArrayType => @@ -734,13 +798,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = $c != 0;" } - private[this] def castToByteCode(from: DataType): CastFunction = from match { + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toByte(); - } catch (java.lang.NumberFormatException e) { + if ($c.toByte($wrapper)) { + $evPrim = (byte) $wrapper.value; + } else { $evNull = true; } """ @@ -756,13 +823,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (byte) $c;" } - private[this] def castToShortCode(from: DataType): CastFunction = from match { + private[this] def castToShortCode( + from: DataType, + ctx: CodegenContext): CastFunction = from 
match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toShort(); - } catch (java.lang.NumberFormatException e) { + if ($c.toShort($wrapper)) { + $evPrim = (short) $wrapper.value; + } else { $evNull = true; } """ @@ -778,13 +850,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (short) $c;" } - private[this] def castToIntCode(from: DataType): CastFunction = from match { + private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toInt(); - } catch (java.lang.NumberFormatException e) { + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ @@ -800,13 +875,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (int) $c;" } - private[this] def castToLongCode(from: DataType): CastFunction = from match { + private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.LongWrapper", wrapper, + s"$wrapper = new UTF8String.LongWrapper();") + (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toLong(); - } catch (java.lang.NumberFormatException e) { + if ($c.toLong($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b93a5d0b7a0e5..b847ef7bfaa97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -184,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. */ - def prettyName: String = nodeName.toLowerCase + def prettyName: String = nodeName.toLowerCase(Locale.ROOT) protected def flatArguments: Iterator[Any] = productIterator.flatMap { case t: Traversable[_] => t @@ -491,7 +493,7 @@ abstract class BinaryExpression extends Expression { * A [[BinaryExpression]] that is an operator, with two properties: * * 1. The string representation is "x symbol y", rather than "funcName(x, y)". - * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * 2. Two inputs are expected to be of the same type. If the two inputs have different types, * the analyzer will find the tightest common type and do the proper type casting. 
 */
 abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3bebd552ef51a..abcb9a2b939b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{
 /**
  * An expression that can be used to sort a tuple. This class extends expression primarily so that
  * transformations over expression will descend into its child.
+ * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is
+ * derived from the equivalence relation in an operator, e.g. the left/right keys of an inner
+ * sort merge join.
  */
-case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
+case class SortOrder(
+    child: Expression,
+    direction: SortDirection,
+    nullOrdering: NullOrdering,
+    sameOrderExpressions: Set[Expression])
   extends UnaryExpression with Unevaluable {
 
   /** Sort order is not foldable because we don't have an eval for it. */
@@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
   override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql
 
   def isAscending: Boolean = direction == Ascending
+
+  def satisfies(required: SortOrder): Boolean = {
+    (sameOrderExpressions + child).exists(required.child.semanticEquals) &&
+      direction == required.direction && nullOrdering == required.nullOrdering
+  }
 }
 
 object SortOrder {
-  def apply(child: Expression, direction: SortDirection): SortOrder = {
-    new SortOrder(child, direction, direction.defaultNullOrdering)
+  def apply(
+      child: Expression,
+      direction: SortDirection,
+      sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
+    new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index db062f1a543fe..1ec2e4a9e9319 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -245,7 +245,8 @@ object ApproximatePercentile {
     val result = new Array[Double](percentages.length)
     var i = 0
     while (i < percentages.length) {
-      result(i) = summaries.query(percentages(i))
+      // Since summaries.count != 0, the query here never returns None.
+ result(i) = summaries.query(percentages(i)).get i += 1 } result diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 9ad31243e4122..523714869242d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -91,14 +91,12 @@ case class PivotFirst( override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) - if (pivotColValue != null) { - // We ignore rows whose pivot column value is not in the list of pivot column values. - val index = pivotIndex.getOrElse(pivotColValue, -1) - if (index >= 0) { - val value = valueColumn.eval(inputRow) - if (value != null) { - updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) - } + // We ignore rows whose pivot column value is not in the list of pivot column values. + val index = pivotIndex.getOrElse(pivotColValue, -1) + if (index >= 0) { + val value = valueColumn.eval(inputRow) + if (value != null) { + updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) } } } @@ -140,7 +138,9 @@ case class PivotFirst( override val aggBufferAttributes: Seq[AttributeReference] = - pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) + pivotIndex.toList.sortBy(_._2).map { kv => + AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)() + } override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4870093e9250f..f2b252259b89d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -113,7 +113,7 @@ case class Abs(child: Expression) protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType @@ -146,7 +146,7 @@ object BinaryArithmetic { > SELECT 1 _FUNC_ 2; 3 """) -case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit > SELECT 2 _FUNC_ 1; 1 """) -case class Subtract(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression) > SELECT 2 _FUNC_ 3; 6 """) -case class Multiply(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = 
NumericType @@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression) 1.0 """) // scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression) > SELECT 2 _FUNC_ 1.8; 0.2 """) -case class Remainder(left: Expression, right: Expression) - extends BinaryArithmetic with NullIntolerant { +case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = NumericType @@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression) > SELECT _FUNC_(-10, 3); 2 """) -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 2918040771433..425efbb6c96c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -86,7 +86,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor of two numbers. + * A function that calculates bitwise xor({@literal ^}) of two numbers. * * Code generation inherited from BinaryArithmetic. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 22277ad8d56ee..b6675a84ece48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -390,6 +390,8 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateName Examples: > SELECT _FUNC_('a:1,b:2,c:3', ',', ':'); map("a":"1","b":"2","c":"3") + > SELECT _FUNC_('a'); + map("a":null) """) // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) @@ -407,7 +409,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false) + override def dataType: DataType = MapType(StringType, StringType) override def checkInputDataTypes(): TypeCheckResult = { if (Seq(pairDelim, keyValueDelim).exists(! 
_.foldable)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 0c256c3d890f1..ef88cfb543ebb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child" + s"Can't extract value from $child: need struct type but got ${other.simpleString}" } throw new AnalysisException(errorMsg) } @@ -104,7 +104,7 @@ trait ExtractValue extends Expression * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with NullIntolerant { lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -152,7 +152,7 @@ case class GetArrayStructFields( field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" @@ -213,7 +213,7 @@ case class GetArrayStructFields( * We need to do type checking here as `ordinal` expression maybe unresolved. */ case class GetArrayItem(child: Expression, ordinal: Expression) - extends BinaryExpression with ExpectsInputTypes with ExtractValue { + extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant { // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) * We need to do type checking here as `key` expression maybe unresolved. */ case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue { + extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { private def keyType = child.dataType.asInstanceOf[MapType].keyType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f8fe774823e5b..bb8fd5032d63d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * Common base class for time zone aware expressions. 
*/ trait TimeZoneAwareExpression extends Expression { + /** The expression is only resolved when the time zone has been set. */ + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined /** the timezone ID to be used to evaluate value. */ def timeZoneId: Option[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841149..c2211ae5d594b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true - override def nullSafeEval(input: Any): Any = { - val d = input.asInstanceOf[Decimal].clone() - if (d.changePrecision(dataType.precision, dataType.scale)) { - d - } else { - null - } - } + override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2d9c2e42064b3..2a5963d37f5e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal, RoundingMode} import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 @@ -334,6 +335,8 @@ abstract class HashExpression[E] extends Expression { } } + protected def genHashTimestamp(t: String, result: String): String = genHashLong(t, result) + protected def genHashCalendarInterval(input: String, result: String): String = { val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)" s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);" @@ -399,7 +402,8 @@ abstract class HashExpression[E] extends Expression { case NullType => "" case BooleanType => genHashBoolean(input, result) case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result) - case LongType | TimestampType => genHashLong(input, result) + case LongType => genHashLong(input, result) + case TimestampType => genHashTimestamp(input, result) case FloatType => genHashFloat(input, result) case DoubleType => genHashDouble(input, result) case d: DecimalType => genHashDecimal(ctx, d, input, result) @@ -432,6 +436,10 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + /** + * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity + * of input `value`. 
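+ * For example, `hash(42L, LongType, seed)` hashes the long value with the given seed; the
+ * per-type primitives such as `hashLong` are supplied by the concrete hash function
+ * (e.g. Murmur3, xxHash64, or the Hive hash).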
+ */ def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => seed @@ -579,8 +587,6 @@ object XxHash64Function extends InterpretedHashFunction { * * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution - * - * TODO: Support Decimal and date related types */ @ExpressionDescription( usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.") @@ -635,13 +641,28 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashBytes(b: String, result: String): String = s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + override protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, + input: String, + result: String): String = { + s""" + $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal( + $input.toJavaBigDecimal()).hashCode();""" + } + override protected def genHashCalendarInterval(input: String, result: String): String = { s""" - $result = (31 * $hasherClassName.hashInt($input.months)) + - $hasherClassName.hashLong($input.microseconds);" + $result = (int) + ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashCalendarInterval($input); """ } + override protected def genHashTimestamp(input: String, result: String): String = + s""" + $result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input); + """ + override protected def genHashString(input: String, result: String): String = { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" @@ -732,6 +753,87 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashUnsafeBytes(base, offset, len) } + private val HIVE_DECIMAL_MAX_PRECISION = 38 + private val HIVE_DECIMAL_MAX_SCALE = 38 + + // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize() + def normalizeDecimal(input: BigDecimal): BigDecimal = { + if (input == null) return null + + def trimDecimal(input: BigDecimal) = { + var result = input + if (result.compareTo(BigDecimal.ZERO) == 0) { + // Special case for 0, because java doesn't strip zeros correctly on that number. + result = BigDecimal.ZERO + } else { + result = result.stripTrailingZeros + if (result.scale < 0) { + // no negative scale decimals + result = result.setScale(0) + } + } + result + } + + var result = trimDecimal(input) + val intDigits = result.precision - result.scale + if (intDigits > HIVE_DECIMAL_MAX_PRECISION) { + return null + } + + val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, + Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale)) + if (result.scale > maxScale) { + result = result.setScale(maxScale, RoundingMode.HALF_UP) + // Trimming is again necessary, because rounding may introduce new trailing 0's. 
+ result = trimDecimal(result) + } + result + } + + /** + * Mimics TimestampWritable.hashCode() in Hive + */ + def hashTimestamp(timestamp: Long): Long = { + val timestampInSeconds = timestamp / 1000000 + val nanoSecondsPortion = (timestamp % 1000000) * 1000 + + var result = timestampInSeconds + result <<= 30 // the nanosecond part fits in 30 bits + result |= nanoSecondsPortion + ((result >>> 32) ^ result).toInt + } + + /** + * Hive allows input intervals to be defined using units below but the intervals + * have to be from the same category: + * - year, month (stored as HiveIntervalYearMonth) + * - day, hour, minute, second, nanosecond (stored as HiveIntervalDayTime) + * + * eg. (INTERVAL '30' YEAR + INTERVAL '-23' DAY) fails in Hive + * + * This method mimics HiveIntervalDayTime.hashCode() in Hive. + * + * Two differences wrt Hive due to how intervals are stored in Spark vs Hive: + * + * - If the `INTERVAL` is backed as HiveIntervalYearMonth in Hive, then this method will not + * produce Hive compatible result. The reason being Spark's representation of calendar does not + * have such categories based on the interval and is unified. + * + * - Spark's [[CalendarInterval]] has precision upto microseconds but Hive's + * HiveIntervalDayTime can store data with precision upto nanoseconds. So, any input intervals + * with nanosecond values will lead to wrong output hashes (ie. non adherent with Hive output) + */ + def hashCalendarInterval(calendarInterval: CalendarInterval): Long = { + val totalSeconds = calendarInterval.microseconds / CalendarInterval.MICROS_PER_SECOND.toInt + val result: Int = (17 * 37) + (totalSeconds ^ totalSeconds >> 32).toInt + + val nanoSeconds = + (calendarInterval.microseconds - + (totalSeconds * CalendarInterval.MICROS_PER_SECOND.toInt)).toInt * 1000 + (result * 37) + nanoSeconds + } + override def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => 0 @@ -785,6 +887,9 @@ object HiveHashFunction extends InterpretedHashFunction { } result + case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode() + case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp) + case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval) case _ => super.hash(value, dataType, 0) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index dbff62efdddb6..6b90354367f40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes} +import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -147,7 +149,9 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -330,7 +334,7 @@ case class GetJsonObject(json: Expression, path: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Return a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", extended = """ Examples: > SELECT _FUNC_('{"a":1, "b":2}', 'a', 'b'); @@ -391,8 +395,10 @@ case class JsonTuple(children: Seq[Expression]) } try { - Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { - parser => parseRow(parser, input) + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parseRow(parser, input) } } catch { case _: JsonProcessingException => @@ -480,9 +486,21 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. + * Converts an json input string to a [[StructType]] or [[ArrayType]] of [[StructType]]s + * with the specified schema. 
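The parser-creation change above, which hands Jackson a Reader over bytes that are already known to be UTF-8, can be reproduced with jackson-core alone. A minimal sketch, assuming jackson-core is on the classpath; the object name is made up for the example.

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import java.nio.charset.StandardCharsets

    import com.fasterxml.jackson.core.JsonFactory

    object Utf8ReaderParserSketch {
      def main(args: Array[String]): Unit = {
        val bytes = """{"a":1,"b":"two"}""".getBytes(StandardCharsets.UTF_8)
        // Handing Jackson a Reader fixes the encoding up front, so it never runs
        // its own byte-level encoding detection on possibly malformed input.
        val reader = new InputStreamReader(new ByteArrayInputStream(bytes), StandardCharsets.UTF_8)
        val parser = new JsonFactory().createParser(reader)
        try {
          while (parser.nextToken() != null) {} // walk all tokens
          println("parsed OK")
        } finally {
          parser.close()
        }
      }
    }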
*/ -case class JsonToStruct( +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`.", + extended = """ + Examples: + > SELECT _FUNC_('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT _FUNC_('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + """) +// scalastyle:on line.size.limit +case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, @@ -493,6 +511,21 @@ case class JsonToStruct( def this(schema: DataType, options: Map[String, String], child: Expression) = this(schema, options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = Map.empty[String, String], + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = JsonExprUtils.validateSchemaLiteral(schema), + options = JsonExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + override def checkInputDataTypes(): TypeCheckResult = schema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() @@ -519,7 +552,7 @@ case class JsonToStruct( lazy val parser = new JacksonParser( rowSchema, - new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) + new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) override def dataType: DataType = schema @@ -554,7 +587,7 @@ case class JsonToStruct( CreateJacksonParser.utf8String, identity[UTF8String])) } catch { - case _: SparkSQLJsonProcessingException => null + case _: BadRecordException => null } } @@ -562,9 +595,22 @@ case class JsonToStruct( } /** - * Converts a [[StructType]] to a json output string. + * Converts a [[StructType]] or [[ArrayType]] of [[StructType]]s to a json output string. 
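A usage sketch for the SQL-facing entry points registered above. It assumes a Spark build that already contains this change (2.2 or later) and a local SparkSession; the exact show() formatting varies by version.

    import org.apache.spark.sql.SparkSession

    object JsonSqlFunctionsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("json-sql").getOrCreate()
        // The schema argument is a DDL-style string literal, parsed via parseTableSchema.
        spark.sql("""SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE') AS parsed""")
          .show(truncate = false)
        spark.sql("SELECT to_json(named_struct('a', 1, 'b', 2)) AS rendered")
          .show(truncate = false)
        spark.stop()
      }
    }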
*/ -case class StructToJson( +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + extended = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] + """) +// scalastyle:on line.size.limit +case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) @@ -573,43 +619,90 @@ case class StructToJson( def this(options: Map[String, String], child: Expression) = this(options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression) = this(Map.empty, child, None) + def this(child: Expression, options: Expression) = + this( + options = JsonExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + @transient lazy val writer = new CharArrayWriter() @transient - lazy val gen = - new JacksonGenerator( - child.dataType.asInstanceOf[StructType], - writer, - new JSONOptions(options, timeZoneId.get)) + lazy val gen = new JacksonGenerator( + rowSchema, writer, new JSONOptions(options, timeZoneId.get)) + + @transient + lazy val rowSchema = child.dataType match { + case st: StructType => st + case ArrayType(st: StructType, _) => st + } + + // This converts rows to the JSON output according to the given schema. + @transient + lazy val converter: Any => UTF8String = { + def getAndReset(): UTF8String = { + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + child.dataType match { + case _: StructType => + (row: Any) => + gen.write(row.asInstanceOf[InternalRow]) + getAndReset() + case ArrayType(_: StructType, _) => + (arr: Any) => + gen.write(arr.asInstanceOf[ArrayData]) + getAndReset() + } + } override def dataType: DataType = StringType - override def checkInputDataTypes(): TypeCheckResult = { - if (StructType.acceptsType(child.dataType)) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case _: StructType | ArrayType(_: StructType, _) => try { - JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + JacksonUtils.verifySchema(rowSchema) TypeCheckResult.TypeCheckSuccess } catch { case e: UnsupportedOperationException => TypeCheckResult.TypeCheckFailure(e.getMessage) } - } else { - TypeCheckResult.TypeCheckFailure( - s"$prettyName requires that the expression is a struct expression.") - } + case _ => TypeCheckResult.TypeCheckFailure( + s"Input type ${child.dataType.simpleString} must be a struct or array of structs.") } override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(row: Any): Any = { - gen.write(row.asInstanceOf[InternalRow]) - gen.flush() - val json = writer.toString - writer.reset() - UTF8String.fromString(json) + override def nullSafeEval(value: Any): Any = converter(value) + + override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil +} + +object JsonExprUtils { + + def validateSchemaLiteral(exp: Expression): StructType = exp match { + case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + case e => throw new AnalysisException(s"Expected a string literal instead of $e") } - override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + def 
convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 65273a77b1054..c4d47ab2084fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} +import java.util.Locale import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -68,7 +69,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) } // name of function in java.lang.Math - def funcName: String = name.toLowerCase + def funcName: String = name.toLowerCase(Locale.ROOT) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") @@ -124,7 +125,8 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + defineCodeGen(ctx, ev, (c1, c2) => + s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)") } } @@ -1024,7 +1026,7 @@ abstract class RoundBase(child: Expression, scale: Expression, child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null + decimal.toPrecision(decimal.precision, _scale, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 36bf3017d4cdb..1a202ecf745c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -224,25 +225,26 @@ case class Invoke( getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. 
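The switches to Locale.ROOT above are not cosmetic: case mapping under the default JVM locale can change which characters come back, the Turkish dotted/dotless i being the classic case. A self-contained illustration (object name made up):

    import java.util.Locale

    object LocaleCaseMappingSketch {
      def main(args: Array[String]): Unit = {
        val turkish = new Locale("tr", "TR")
        // Under the Turkish locale, 'I' lower-cases to a dotless 'ı', so a name
        // such as "SIN" would no longer resolve to java.lang.Math.sin.
        println("SIN".toLowerCase(turkish))     // sın
        println("SIN".toLowerCase(Locale.ROOT)) // sin
      }
    }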
+ val assignResult = if (!returnNullable) { + s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + } else { + s""" + if ($funcResult != null) { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ + } s""" Object $funcResult = null; ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} - if ($funcResult == null) { - ${ev.isNull} = true; - } else { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; - } + $assignResult """ } - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val postNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - val code = s""" ${obj.code} boolean ${ev.isNull} = true; @@ -253,7 +255,6 @@ case class Invoke( if (!${ev.isNull}) { $evaluate } - $postNullCheck } """ ev.copy(code = code) @@ -405,7 +406,7 @@ case class WrapOption(child: Expression, optType: DataType) } /** - * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable( @@ -420,6 +421,27 @@ case class LambdaVariable( } } +/** + * When constructing [[MapObjects]], the element type must be given, which may not be available + * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by + * [[MapObjects]] during analysis after the input data is resolved. + * Note that, ideally we should not serialize and send unresolved expressions to executors, but + * users may accidentally do this(e.g. mistakenly reference an encoder instance when implementing + * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is + * not serializable. Then even users mistakenly reference unresolved expression and serialize it, + * it's just a performance issue(more network traffic), and will not fail. + */ +case class UnresolvedMapObjects( + @transient function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -429,24 +451,36 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. 
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) */ def apply( function: Expression => Expression, inputData: Expression, - elementType: DataType): MapObjects = { - val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() - val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) + elementType: DataType, + elementNullable: Boolean = true, + customCollectionCls: Option[Class[_]] = None): MapObjects = { + val id = curId.getAndIncrement() + val loopValue = s"MapObjects_loopValue$id" + val loopIsNull = s"MapObjects_loopIsNull$id" + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) + MapObjects( + loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } } /** * Applies the given expression to every element of a collection of items, returning the result - * as an ArrayType. This is similar to a typical map operation, but where the lambda function - * is expressed using catalyst expressions. + * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda + * function is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: + * The type of the result is determined as follows: + * - ArrayType - when customCollectionCls is None + * - ObjectType(collection) - when customCollectionCls contains a collection class + * + * The following collection ObjectTypes are currently supported on input: * Seq, Array, ArrayData, java.util.List * * @param loopValue the name of the loop variable that used when iterate the collection, and used @@ -458,13 +492,16 @@ object MapObjects { * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. 
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) */ case class MapObjects private( loopValue: String, loopIsNull: String, loopVarDataType: DataType, lambdaFunction: Expression, - inputData: Expression) extends Expression with NonSQLExpression { + inputData: Expression, + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -474,7 +511,8 @@ case class MapObjects private( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def dataType: DataType = - ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable) + customCollectionCls.map(ObjectType.apply).getOrElse( + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) @@ -557,15 +595,40 @@ case class MapObjects private( case _ => s"$loopIsNull = $loopValue == null;" } + val (initCollection, addElement, getResult): (String, String => String, String) = + customCollectionCls match { + case Some(cls) => + // collection + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") + ( + s""" + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); + """, + genValue => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) $builder.result();" + ) + case None => + // array + ( + s""" + $convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor; + """, + genValue => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);" + ) + } + val code = s""" ${genInputData.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType - $convertedType[] $convertedArray = null; int $dataLength = $getLength; - $convertedArray = $arrayConstructor; + $initCollection int $loopIndex = 0; while ($loopIndex < $dataLength) { @@ -574,15 +637,15 @@ case class MapObjects private( ${genFunction.code} if (${genFunction.isNull}) { - $convertedArray[$loopIndex] = null; + ${addElement("null")} } else { - $convertedArray[$loopIndex] = $genFunctionValue; + ${addElement(genFunctionValue)} } $loopIndex += 1; } - ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); + ${ev.value} = $getResult } """ ev.copy(code = code, isNull = genInputData.isNull) @@ -935,13 +998,15 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. 
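The customCollectionCls branch above drives a Scala collection Builder from generated Java through the mangled names (for example List$.MODULE$.newBuilder(), $plus$eq and result()). The same sequence of calls written in plain Scala, as a reading aid only:

    import scala.collection.mutable.Builder

    object CollectionBuilderSketch {
      def main(args: Array[String]): Unit = {
        val builder: Builder[Int, List[Int]] = List.newBuilder[Int]
        builder.sizeHint(3)       // the size hint the generated code passes
        builder += 1              // what the $plus$eq(...) calls compile from
        builder += 2
        builder += 3
        println(builder.result()) // List(1, 2, 3)
      }
    }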
*/ -case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType override def foldable: Boolean = false override def nullable: Boolean = false + override def flatArguments: Iterator[Any] = Iterator(child) + private val errMsg = "Null value appeared in non-nullable field:" + walkedTypePath.mkString("\n", "\n", "\n") + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + @@ -951,7 +1016,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = { val result = child.eval(input) if (result == null) { - throw new RuntimeException(errMsg); + throw new NullPointerException(errMsg) } result } @@ -967,7 +1032,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException($errMsgField); + throw new NullPointerException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1b00c9e79da22..4c8b177237d23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -138,5 +138,5 @@ package object expressions { * input will result in null output). We will use this information during constructing IsNotNull * constraints. */ - trait NullIntolerant + trait NullIntolerant extends Expression } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 6552c0d8e0267..ba1e9084d0b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -92,6 +92,15 @@ trait PredicateHelper { * Returns true iff `expr` could be evaluated as a condition within join. */ protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + // Non-deterministic expressions are not allowed as join conditions. + case e if !e.deterministic => false + case _: ListQuery | _: Exists => + // A ListQuery defines the query which we want to search in an IN subquery expression. + // Currently the only way to evaluate an IN subquery is to convert it to a + // LeftSemi/LeftAnti/ExistenceJoin by `RewritePredicateSubquery` rule. + // It cannot be evaluated as part of a Join operator. + // An Exists shouldn't be push into a Join operator too. + false case e: SubqueryExpression => // non-correlated subquery will be replaced as literal e.children.isEmpty @@ -125,19 +134,44 @@ case class Not(child: Expression) */ @ExpressionDescription( usage = "expr1 _FUNC_(expr2, expr3, ...) 
- Returns true if `expr` equals to any valN.") -case class In(value: Expression, list: Seq[Expression]) extends Predicate - with ImplicitCastInputTypes { +case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + override def checkInputDataTypes(): TypeCheckResult = { + list match { + case ListQuery(sub, _, _) :: Nil => + val valExprs = value match { + case cns: CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } - override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } - override def checkInputDataTypes(): TypeCheckResult = { - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be same type") - } else { - TypeCheckResult.TypeCheckSuccess + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } + case _ => + if (list.exists(l => l.dataType != value.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be same type") + } else { + TypeCheckResult.TypeCheckSuccess + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4896a6225aa80..3fa84589e3c68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils @@ -27,8 +28,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringRegexExpression extends BinaryExpression + with ImplicitCastInputTypes with NullIntolerant { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -60,7 +61,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { } } - override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } @@ -68,9 +69,31 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ @ExpressionDescription( - usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.") -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + usage = "str _FUNC_ pattern - Returns true if str matches pattern, " + + "null if any arguments are null, false otherwise.", + extended = """ + Arguments: + str - a string 
expression + pattern - a string expression. The pattern is a string which is matched literally, with + exception to the following special symbols: + + _ matches any one character in the input (similar to . in posix regular expressions) + + % matches zero or more characters in the input (similar to .* in posix regular + expressions) + + The escape character is '\'. If an escape character precedes a special symbol or another + escape character, the following character is matched literally. It is invalid to escape + any other character. + + Examples: + > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%' + true + + See also: + Use RLIKE to match with standard regular expressions. +""") +case class Like(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = StringUtils.escapeLikeRegex(v) @@ -122,8 +145,7 @@ case class Like(left: Expression, right: Expression) @ExpressionDescription( usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 908aa44f81c97..5598a146997ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringPredicate extends Predicate with ImplicitCastInputTypes { - self: BinaryExpression => +abstract class StringPredicate extends BinaryExpression + with Predicate with ImplicitCastInputTypes with NullIntolerant { def compare(l: UTF8String, r: UTF8String): Boolean @@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { /** * A function that returns true if the string `left` contains the string `right`. */ -case class Contains(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression) /** * A function that returns true if the string `left` starts with the string `right`. 
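For readers unfamiliar with the `_` and `%` wildcards documented above, a simplified standalone translation into a Java regex behaves as follows. This is illustrative only and deliberately ignores the escape-character handling that the real escapeLikeRegex performs.

    import java.util.regex.Pattern

    object LikeToRegexSketch {
      // Simplified: '_' becomes '.', '%' becomes '.*', everything else is quoted.
      def likeToRegex(pattern: String): String =
        pattern.map {
          case '_' => "."
          case '%' => ".*"
          case c   => Pattern.quote(c.toString)
        }.mkString

      def main(args: Array[String]): Unit = {
        val regex = Pattern.compile("(?s)" + likeToRegex("SparkSQL_%"))
        println(regex.matcher("SparkSQL_is_cool").matches()) // true
        println(regex.matcher("Spark").matches())            // false
      }
    }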
*/ -case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class StartsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression) /** * A function that returns true if the string `left` ends with the string `right`. */ -case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with StringPredicate { +case class EndsWith(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression) """) // scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e2e7d98e33459..d7b493d521ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.types._ /** @@ -40,19 +43,190 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { /** * A base interface for expressions that contain a [[LogicalPlan]]. */ -abstract class SubqueryExpression extends PlanExpression[LogicalPlan] { +abstract class SubqueryExpression( + plan: LogicalPlan, + children: Seq[Expression], + exprId: ExprId) extends PlanExpression[LogicalPlan] { + override lazy val resolved: Boolean = childrenResolved && plan.resolved + override lazy val references: AttributeSet = + if (plan.resolved) super.references -- plan.outputSet else super.references override def withNewPlan(plan: LogicalPlan): SubqueryExpression + override def semanticEquals(o: Expression): Boolean = o match { + case p: SubqueryExpression => + this.getClass.getName.equals(p.getClass.getName) && plan.sameResult(p.plan) && + children.length == p.children.length && + children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) + case _ => false + } + def canonicalize(attrs: AttributeSeq): SubqueryExpression = { + // Normalize the outer references in the subquery plan. 
+ val normalizedPlan = plan.transformAllExpressions { + case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs)) + } + withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression] + } } object SubqueryExpression { + /** + * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + */ + def hasInOrExistsSubquery(e: Expression): Boolean = { + e.find { + case _: ListQuery | _: Exists => true + case _ => false + }.isDefined + } + + /** + * Returns true when an expression contains a subquery that has outer reference(s). The outer + * reference attributes are kept as children of subquery expression by + * [[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveSubquery]] + */ def hasCorrelatedSubquery(e: Expression): Boolean = { e.find { - case e: SubqueryExpression if e.children.nonEmpty => true + case s: SubqueryExpression => s.children.nonEmpty case _ => false }.isDefined } } +object SubExprUtils extends PredicateHelper { + /** + * Returns true when an expression contains correlated predicates i.e outer references and + * returns false otherwise. + */ + def containsOuter(e: Expression): Boolean = { + e.find(_.isInstanceOf[OuterReference]).isDefined + } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(condition: Expression): Boolean = { + splitConjunctivePredicates(condition).exists { + case _: Exists | Not(_: Exists) => false + case In(_, Seq(_: ListQuery)) | Not(In(_, Seq(_: ListQuery))) => false + case e => e.find { x => + x.isInstanceOf[Not] && e.find { + case In(_, Seq(_: ListQuery)) => true + case _ => false + }.isDefined + }.isDefined + } + + } + + /** + * Returns an expression after removing the OuterReference shell. + */ + def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r } + + /** + * Returns the list of expressions after removing the OuterReference shell from each of + * the expression. + */ + def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference) + + /** + * Returns the logical plan after removing the OuterReference shell from all the expressions + * of the input logical plan. + */ + def stripOuterReferences(p: LogicalPlan): LogicalPlan = { + p.transformAllExpressions { + case OuterReference(a) => a + } + } + + /** + * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. + */ + def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find { + case f: Filter => containsOuter(f.condition) + case other => false + }.isDefined + } + + /** + * Given a list of expressions, returns the expressions which have outer references. Aggregate + * expressions are treated in a special way. If the children of aggregate expression contains an + * outer reference, then the entire aggregate expression is marked as an outer reference. + * Example (SQL): + * {{{ + * SELECT a FROM l GROUP by 1 HAVING EXISTS (SELECT 1 FROM r WHERE d < min(b)) + * }}} + * In the above case, we want to mark the entire min(b) as an outer reference + * OuterReference(min(b)) instead of min(OuterReference(b)). + * TODO: Currently we don't allow deep correlation. Also, we don't allow mixing of + * outer references and local references under an aggregate expression. + * For example (SQL): + * {{{ + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... 
+ * FROM sq + * WHERE min(p1.a + p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a) + max(p2.b) = sq.c)) + * + * SELECT .. FROM p1 + * WHERE EXISTS (SELECT ... + * FROM p2 + * WHERE EXISTS (SELECT ... + * FROM sq + * WHERE min(p1.a + sq.c) > 1)) + * }}} + * The code below needs to change when we support the above cases. + */ + def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + val outerExpressions = ArrayBuffer.empty[Expression] + conditions foreach { expr => + expr transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e + } + } + outerExpressions + } + + /** + * Returns all the expressions that have outer references from a logical plan. Currently only + * Filter operator can host outer references. + */ + def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + getOuterReferences(conditions) + } + + /** + * Returns the correlated predicates from a logical plan. The OuterReference wrapper + * is removed before returning the predicate to the caller. + */ + def getCorrelatedPredicates(plan: LogicalPlan): Seq[Expression] = { + val conditions = plan.collect { case Filter(cond, _) => cond } + conditions.flatMap { e => + val (correlated, _) = splitConjunctivePredicates(e).partition(containsOuter) + stripOuterReferences(correlated) match { + case Nil => None + case xs => xs + } + } + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -63,73 +237,26 @@ case class ScalarSubquery( plan: LogicalPlan, children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved: Boolean = childrenResolved && plan.resolved - override lazy val references: AttributeSet = { - if (plan.resolved) super.references -- plan.outputSet - else super.references - } + extends SubqueryExpression(plan, children, exprId) with Unevaluable { override def dataType: DataType = plan.schema.fields.head.dataType - override def foldable: Boolean = false override def nullable: Boolean = true override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan) override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ScalarSubquery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } object ScalarSubquery { def hasCorrelatedScalarSubquery(e: Expression): Boolean = { e.find { - case e: ScalarSubquery if e.children.nonEmpty => true - case _ => false - }.isDefined - } -} - -/** - * A predicate subquery checks the existence of a value in a sub-query. We currently only allow - * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will - * be rewritten into a left semi/anti join during analysis. 
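The outer-reference utilities above exist to support rewriting correlated predicates into joins. A small end-to-end sketch with the public API, assuming a local SparkSession; the correlated EXISTS should surface as a left semi join in the explained plan.

    import org.apache.spark.sql.SparkSession

    object ExistsSubquerySketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("exists-demo").getOrCreate()
        import spark.implicits._

        Seq((1, "a"), (2, "b")).toDF("id", "v").createOrReplaceTempView("a")
        Seq(1, 3).toDF("id").createOrReplaceTempView("b")

        // A correlated EXISTS predicate over the temp views registered above.
        val q = spark.sql("SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.id = a.id)")
        q.explain() // the rewrite shows up as a left semi join
        q.show()
        spark.stop()
      }
    }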
- */ -case class PredicateSubquery( - plan: LogicalPlan, - children: Seq[Expression] = Seq.empty, - nullAware: Boolean = false, - exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = childrenResolved && plan.resolved - override lazy val references: AttributeSet = super.references -- plan.outputSet - override def nullable: Boolean = nullAware - override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan) - override def semanticEquals(o: Expression): Boolean = o match { - case p: PredicateSubquery => - plan.sameResult(p.plan) && nullAware == p.nullAware && - children.length == p.children.length && - children.zip(p.children).forall(p => p._1.semanticEquals(p._2)) - case _ => false - } - override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" -} - -object PredicateSubquery { - def hasPredicateSubquery(e: Expression): Boolean = { - e.find { - case _: PredicateSubquery | _: ListQuery | _: Exists => true + case s: ScalarSubquery => s.children.nonEmpty case _ => false }.isDefined } - - /** - * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could - * turn the null-aware predicate into not-null-aware predicate. - */ - def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { - e.find{ x => - x.isInstanceOf[Not] && e.find { - case p: PredicateSubquery => p.nullAware - case _ => false - }.isDefined - }.isDefined - } } /** @@ -144,18 +271,26 @@ object PredicateSubquery { * FROM b) * }}} */ -case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty - override def dataType: DataType = ArrayType(NullType) +case class ListQuery( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Unevaluable { + override def dataType: DataType = plan.schema.fields.head.dataType override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan) - override def toString: String = s"list#${exprId.id}" + override def toString: String = s"list#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + ListQuery( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } /** * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. 
+ * * For example (SQL): * {{{ * SELECT * @@ -165,11 +300,18 @@ case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExpr * WHERE b.id = a.id) * }}} */ -case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId) - extends SubqueryExpression with Predicate with Unevaluable { - override lazy val resolved = false - override def children: Seq[Expression] = Seq.empty +case class Exists( + plan: LogicalPlan, + children: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, children, exprId) with Predicate with Unevaluable { override def nullable: Boolean = false override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan) - override def toString: String = s"exists#${exprId.id}" + override def toString: String = s"exists#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + Exists( + plan.canonicalized, + children.map(_.canonicalized), + ExprId(0)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 07d294b108548..37190429fc423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -631,7 +633,7 @@ abstract class RankLike extends AggregateWindowFunction { override val updateExpressions = increaseRank +: increaseRowNumber +: children override val evaluateExpression: Expression = rank - override def sql: String = s"${prettyName.toUpperCase}()" + override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}()" def withOrder(order: Seq[Expression]): RankLike } @@ -695,7 +697,7 @@ case class DenseRank(children: Seq[Expression]) extends RankLike { * * This documentation has been based upon similar documentation for the Hive and Presto projects. * - * @param children to base the rank on; a change in the value of one the children will trigger a + * @param children to base the rank on; a change in the value of one of the children will trigger a * change in rank. This is an internal parameter and will be assigned by the * Analyser. 
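As a concrete illustration of the rank semantics described above (the rank only advances when one of the ordering children changes value), a minimal DataFrame example using the public window API; it assumes a local SparkSession.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{dense_rank, rank}

    object RankWindowSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("rank-demo").getOrCreate()
        import spark.implicits._

        val df = Seq(("x", 10), ("x", 10), ("x", 20)).toDF("k", "v")
        val w = Window.partitionBy("k").orderBy("v")
        // rank() leaves a gap after ties (1, 1, 3); dense_rank() does not (1, 1, 2).
        df.select($"k", $"v", rank().over(w).as("rank"), dense_rank().over(w).as("dense_rank"))
          .show()
        spark.stop()
      }
    }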
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index e0ed03a68981a..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text @@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable { val bb = record.getByteBuffer assert(bb.hasArray) - jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5a91f9c1939aa..23ba5ed4d50dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -23,7 +23,7 @@ import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ /** * Options for parsing JSON data into Spark SQL rows. @@ -65,11 +65,13 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = @@ -77,19 +79,10 @@ private[sql] class JSONOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. 
Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) - /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index dec55279c9fc5..1d302aea6fd16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -37,6 +37,10 @@ private[sql] class JacksonGenerator( // `ValueWriter`s for all fields of the schema private val rootFieldWriters: Array[ValueWriter] = schema.map(_.dataType).map(makeWriter).toArray + // `ValueWriter` for array data storing rows of the schema. + private val arrElementWriter: ValueWriter = (arr: SpecializedGetters, i: Int) => { + writeObject(writeFields(arr.getStruct(i, schema.length), schema, rootFieldWriters)) + } private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) @@ -185,17 +189,18 @@ private[sql] class JacksonGenerator( def flush(): Unit = gen.flush() /** - * Transforms a single InternalRow to JSON using Jackson + * Transforms a single `InternalRow` to JSON object using Jackson * * @param row The row to convert */ - def write(row: InternalRow): Unit = { - writeObject { - writeFields(row, schema, rootFieldWriters) - } - } + def write(row: InternalRow): Unit = writeObject(writeFields(row, schema, rootFieldWriters)) - def writeLineEnding(): Unit = { - gen.writeRaw('\n') - } + /** + * Transforms multiple `InternalRow`s to JSON array using Jackson + * + * @param array The array of rows to convert + */ + def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter)) + + def writeLineEnding(): Unit = gen.writeRaw('\n') } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 9b80c0fc87c93..ff6c93ae9815c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream +import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -32,17 +33,14 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) - /** * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. 
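The parse-mode handling above is what users reach through the reader's mode option. A sketch, assuming Spark 2.2 or later where spark.read.json accepts a Dataset[String]:

    import org.apache.spark.sql.SparkSession

    object JsonParseModeSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("parse-mode").getOrCreate()
        import spark.implicits._

        val records = Seq("""{"a": 1}""", """{"a": """).toDS() // the second record is malformed

        // PERMISSIVE (the default) nulls out the bad record, DROPMALFORMED drops it,
        // and FAILFAST throws as soon as the bad record is reached.
        spark.read.option("mode", "DROPMALFORMED").json(records).show()
        spark.stop()
      }
    }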
*/ class JacksonParser( schema: StructType, - options: JSONOptions) extends Logging { + val options: JSONOptions) extends Logging { import JacksonUtils._ - import ParseModes._ import com.fasterxml.jackson.core.JsonToken._ // A `ValueConverter` is responsible for converting a value from `JsonParser` @@ -55,108 +53,6 @@ class JacksonParser( private val factory = new JsonFactory() options.setJacksonOptions(factory) - private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) - - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - @transient - private[this] var isWarningPrinted: Boolean = false - - @transient - private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { - def sampleRecord: String = { - if (options.wholeFile) { - "" - } else { - s"Sample record: ${record()}\n" - } - } - - def footer: String = { - s"""Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin - } - - if (options.permissive) { - logWarning( - s"""Found at least one malformed record. The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } else if (options.dropMalformed) { - logWarning( - s"""Found at least one malformed record. The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. - | - |${sampleRecord ++ footer} - | - """.stripMargin) - } - } - - @transient - private def printWarningIfWholeFile(): Unit = { - if (options.wholeFile && corruptFieldIndex.isDefined) { - logWarning( - s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result - |in very large allocations or OutOfMemoryExceptions being raised. - | - """.stripMargin) - } - } - - /** - * This function deals with the cases it fails to parse. This function will be called - * when exceptions are caught during converting. This functions also deals with `mode` option. 
- */ - private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { - corruptFieldIndex match { - case _ if options.failFast => - if (options.wholeFile) { - throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") - } else { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") - } - - case _ if options.dropMalformed => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - Nil - - case None => - if (!isWarningPrinted) { - printWarningForMalformedRecord(record) - isWarningPrinted = true - } - emptyRow - - case Some(corruptIndex) => - if (!isWarningPrinted) { - printWarningIfWholeFile() - isWarningPrinted = true - } - val row = new GenericInternalRow(schema.length) - row.update(corruptIndex, record()) - Seq(row) - } - } - /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. This is a wrapper for the method @@ -231,7 +127,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || @@ -239,7 +135,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toFloat } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.") + throw new RuntimeException(s"Cannot parse $value as FloatType.") } } @@ -251,7 +147,7 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. val value = parser.getText - val lowerCaseValue = value.toLowerCase + val lowerCaseValue = value.toLowerCase(Locale.ROOT) if (lowerCaseValue.equals("nan") || lowerCaseValue.equals("infinity") || lowerCaseValue.equals("-infinity") || @@ -259,7 +155,7 @@ class JacksonParser( lowerCaseValue.equals("-inf")) { value.toDouble } else { - throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.") + throw new RuntimeException(s"Cannot parse $value as DoubleType.") } } @@ -391,9 +287,8 @@ class JacksonParser( case token => // We cannot parse this token based on the given data type. So, we throw a - // SparkSQLJsonProcessingException and this exception will be caught by - // `parse` method. - throw new SparkSQLJsonProcessingException( + // RuntimeException and this exception will be caught by `parse` method. 
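The rethrow above packages the failure together with a by-name view of the offending record, so the record text is only materialized if whoever handles the malformed record actually needs it. The shape of that pattern, with a hypothetical exception class standing in for BadRecordException:

    // Hypothetical stand-in; only the by-name record and wrapped cause matter here.
    final case class BadRecordSketch(record: () => String, cause: Throwable)
      extends Exception(cause)

    object WrapBadRecordSketch {
      def parse(json: String): Int =
        try {
          if (!json.trim.startsWith("{")) throw new RuntimeException("not a JSON object")
          1 // pretend one row was produced
        } catch {
          case e: RuntimeException => throw BadRecordSketch(() => json, e)
        }

      def main(args: Array[String]): Unit = {
        try parse("oops") catch {
          case BadRecordSketch(rec, cause) =>
            println(s"bad record '${rec()}' handled: ${cause.getMessage}")
        }
      }
    }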
+ throw new RuntimeException( s"Failed to parse a value for data type $dataType (current token: $token).") } @@ -466,14 +361,14 @@ class JacksonParser( parser.nextToken() match { case null => Nil case _ => rootConverter.apply(parser) match { - case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case null => throw new RuntimeException("Root converter returned null") case rows => rows } } } } catch { - case _: JsonProcessingException | _: SparkSQLJsonProcessingException => - failedRecord(() => recordLiteral(record)) + case e @ (_: RuntimeException | _: JsonProcessingException) => + throw BadRecordException(() => recordLiteral(record), () => None, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala new file mode 100644 index 0000000000000..51eca6ca33760 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} +import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType} +import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + + +/** + * Cost-based join reorder. + * We may have several join reorder algorithms in the future. This class is the entry of these + * algorithms, and chooses which one to use. + */ +case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.cboEnabled || !conf.joinReorderEnabled) { + plan + } else { + val result = plan transformDown { + // Start reordering with a joinable item, which is an InnerLike join with conditions. 
+ case j @ Join(_, _, _: InnerLike, Some(cond)) => + reorder(j, j.output) + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + reorder(p, p.output) + } + // After reordering is finished, convert OrderedJoin back to Join + result transformDown { + case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond) + } + } + } + + private def reorder(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { + val (items, conditions) = extractInnerJoins(plan) + val result = + // Do reordering if the number of items is appropriate and join conditions exist. + // We also need to check if costs of all items can be evaluated. + if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && + items.forall(_.stats(conf).rowCount.isDefined)) { + JoinReorderDP.search(conf, items, conditions, output) + } else { + plan + } + // Set consecutive join nodes ordered. + replaceWithOrderedJoin(result) + } + + /** + * Extracts items of consecutive inner joins and join conditions. + * This method works for bushy trees and left/right deep trees. + */ + private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + plan match { + case Join(left, right, _: InnerLike, Some(cond)) => + val (leftPlans, leftConditions) = extractInnerJoins(left) + val (rightPlans, rightConditions) = extractInnerJoins(right) + (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ + leftConditions ++ rightConditions) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) + if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(j) + case _ => + (Seq(plan), Set()) + } + } + + private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { + case j @ Join(left, right, jt: InnerLike, Some(cond)) => + val replacedLeft = replaceWithOrderedJoin(left) + val replacedRight = replaceWithOrderedJoin(right) + OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) => + p.copy(child = replaceWithOrderedJoin(j)) + case _ => + plan + } +} + +/** This is a mimic class for a join node that has been ordered. */ +case class OrderedJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output +} + +/** + * Reorder the joins using a dynamic programming algorithm. This implementation is based on the + * paper: Access Path Selection in a Relational Database Management System. + * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf + * + * First we put all items (basic joined nodes) into level 0, then we build all two-way joins + * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans + * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we + * build all n-way joins and pick the best plan among them. + * + * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set + * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among + * plans (A J B) J C, (A J C) J B and (B J C) J A. + * We also prune cartesian product candidates when building a new plan if there exists no join + * condition involving references from both left and right. This pruning strategy significantly + * reduces the search space. 
+ * E.g., given A J B J C J D with join conditions A.k1 = B.k1 and B.k2 = C.k2 and C.k3 = D.k3, + * plans maintained for each level are as follows: + * level 0: p({A}), p({B}), p({C}), p({D}) + * level 1: p({A, B}), p({B, C}), p({C, D}) + * level 2: p({A, B, C}), p({B, C, D}) + * level 3: p({A, B, C, D}) + * where p({A, B, C, D}) is the final output plan. + * + * For cost evaluation, since physical costs for operators are not available currently, we use + * cardinalities and sizes to compute costs. + */ +object JoinReorderDP extends PredicateHelper with Logging { + + def search( + conf: SQLConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + output: Seq[Attribute]): LogicalPlan = { + + val startTime = System.nanoTime() + // Level i maintains all found plans for i + 1 items. + // Create the initial plans: each plan is a single item with zero cost. + val itemIndex = items.zipWithIndex + val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map { + case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) + }.toMap) + + // Build filters from the join graph to be used by the search algorithm. + val filters = JoinReorderDPFilters.buildJoinGraphInfo(conf, items, conditions, itemIndex) + + // Build plans for next levels until the last level has only one plan. This plan contains + // all items that can be joined, so there's no need to continue. + val topOutputSet = AttributeSet(output) + while (foundPlans.size < items.length) { + // Build plans for the next level. + foundPlans += searchLevel(foundPlans, conf, conditions, topOutputSet, filters) + } + + val durationInMs = (System.nanoTime() - startTime) / (1000 * 1000) + logDebug(s"Join reordering finished. Duration: $durationInMs ms, number of items: " + + s"${items.length}, number of plans in memo: ${foundPlans.map(_.size).sum}") + + // The last level must have one and only one plan, because all items are joinable. + assert(foundPlans.size == items.length && foundPlans.last.size == 1) + foundPlans.last.head._2.plan match { + case p @ Project(projectList, j: Join) if projectList != output => + assert(topOutputSet == p.outputSet) + // Keep the same order of final output attributes. + p.copy(projectList = output) + case finalPlan => + finalPlan + } + } + + /** Find all possible plans at the next level, based on existing levels. */ + private def searchLevel( + existingLevels: Seq[JoinPlanMap], + conf: SQLConf, + conditions: Set[Expression], + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): JoinPlanMap = { + + val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] + var k = 0 + val lev = existingLevels.length - 1 + // Build plans for the next level from plans at level k (one side of the join) and level + // lev - k (the other side of the join). + // For the lower level k, we only need to search from 0 to lev - k, because when building + // a join from A and B, both A J B and B J A are handled. + while (k <= lev - k) { + val oneSideCandidates = existingLevels(k).values.toSeq + for (i <- oneSideCandidates.indices) { + val oneSidePlan = oneSideCandidates(i) + val otherSideCandidates = if (k == lev - k) { + // Both sides of a join are at the same level, no need to repeat for previous ones. 
+ oneSideCandidates.drop(i) + } else { + existingLevels(lev - k).values.toSeq + } + + otherSideCandidates.foreach { otherSidePlan => + buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput, filters) match { + case Some(newJoinPlan) => + // Check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. + val existingPlan = nextLevel.get(newJoinPlan.itemIds) + if (existingPlan.isEmpty || newJoinPlan.betterThan(existingPlan.get, conf)) { + nextLevel.update(newJoinPlan.itemIds, newJoinPlan) + } + case None => + } + } + } + k += 1 + } + nextLevel.toMap + } + + /** + * Builds a new JoinPlan if the following conditions hold: + * - the sets of items contained in left and right sides do not overlap. + * - there exists at least one join condition involving references from both sides. + * - if star-join filter is enabled, allow the following combinations: + * 1) (oneJoinPlan U otherJoinPlan) is a subset of star-join + * 2) star-join is a subset of (oneJoinPlan U otherJoinPlan) + * 3) (oneJoinPlan U otherJoinPlan) is a subset of non star-join + * + * @param oneJoinPlan One side JoinPlan for building a new JoinPlan. + * @param otherJoinPlan The other side JoinPlan for building a new join node. + * @param conf SQLConf for statistics computation. + * @param conditions The overall set of join conditions. + * @param topOutput The output attributes of the final plan. + * @param filters Join graph info to be used as filters by the search algorithm. + * @return Builds and returns a new JoinPlan if both conditions hold. Otherwise, returns None. + */ + private def buildJoin( + oneJoinPlan: JoinPlan, + otherJoinPlan: JoinPlan, + conf: SQLConf, + conditions: Set[Expression], + topOutput: AttributeSet, + filters: Option[JoinGraphInfo]): Option[JoinPlan] = { + + if (oneJoinPlan.itemIds.intersect(otherJoinPlan.itemIds).nonEmpty) { + // Should not join two overlapping item sets. + return None + } + + if (filters.isDefined) { + // Apply star-join filter, which ensures that tables in a star schema relationship + // are planned together. The star-filter will eliminate joins among star and non-star + // tables until the star joins are built. The following combinations are allowed: + // 1. (oneJoinPlan U otherJoinPlan) is a subset of star-join + // 2. star-join is a subset of (oneJoinPlan U otherJoinPlan) + // 3. (oneJoinPlan U otherJoinPlan) is a subset of non star-join + val isValidJoinCombination = + JoinReorderDPFilters.starJoinFilter(oneJoinPlan.itemIds, otherJoinPlan.itemIds, + filters.get) + if (!isValidJoinCombination) return None + } + + val onePlan = oneJoinPlan.plan + val otherPlan = otherJoinPlan.plan + val joinConds = conditions + .filterNot(l => canEvaluate(l, onePlan)) + .filterNot(r => canEvaluate(r, otherPlan)) + .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet)) + if (joinConds.isEmpty) { + // Cartesian product is very expensive, so we exclude them from candidate plans. + // This also significantly reduces the search space. + return None + } + + // Put the deeper side on the left, tend to build a left-deep tree. 
+ val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) + } else { + (otherPlan, onePlan) + } + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.output.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin, newJoin) + } else { + newJoin + } + + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + // Now the root node of onePlan/otherPlan becomes an intermediate join (if it's a non-leaf + // item), so the cost of the new join should also include its own cost. + val newPlanCost = oneJoinPlan.planCost + oneJoinPlan.rootCost(conf) + + otherJoinPlan.planCost + otherJoinPlan.rootCost(conf) + Some(JoinPlan(itemIds, newPlan, collectedJoinConds, newPlanCost)) + } + + /** Map[set of item ids, join plan for these items] */ + type JoinPlanMap = Map[Set[Int], JoinPlan] + + /** + * Partial join order in a specific level. + * + * @param itemIds Set of item ids participating in this partial plan. + * @param plan The plan tree with the lowest cost for these items found so far. + * @param joinConds Join conditions included in the plan. + * @param planCost The cost of this plan tree is the sum of costs of all intermediate joins. + */ + case class JoinPlan( + itemIds: Set[Int], + plan: LogicalPlan, + joinConds: Set[Expression], + planCost: Cost) { + + /** Get the cost of the root node of this plan tree. */ + def rootCost(conf: SQLConf): Cost = { + if (itemIds.size > 1) { + val rootStats = plan.stats(conf) + Cost(rootStats.rowCount.get, rootStats.sizeInBytes) + } else { + // If the plan is a leaf item, it has zero cost. + Cost(0, 0) + } + } + + def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { + if (other.planCost.card == 0 || other.planCost.size == 0) { + false + } else { + val relativeRows = BigDecimal(this.planCost.card) / BigDecimal(other.planCost.card) + val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) + relativeRows * conf.joinReorderCardWeight + + relativeSize * (1 - conf.joinReorderCardWeight) < 1 + } + } + } +} + +/** + * This class defines the cost model for a plan. + * @param card Cardinality (number of rows). + * @param size Size in bytes. + */ +case class Cost(card: BigInt, size: BigInt) { + def +(other: Cost): Cost = Cost(this.card + other.card, this.size + other.size) +} + +/** + * Implements optional filters to reduce the search space for join enumeration. + * + * 1) Star-join filters: Plan star-joins together since they are assumed + * to have an optimal execution based on their RI relationship. + * 2) Cartesian products: Defer their planning later in the graph to avoid + * large intermediate results (expanding joins, in general). + * 3) Composite inners: Don't generate "bushy tree" plans to avoid materializing + * intermediate results. + * + * Filters (2) and (3) are not implemented. + */ +object JoinReorderDPFilters extends PredicateHelper { + /** + * Builds join graph information to be used by the filtering strategies. + * Currently, it builds the sets of star/non-star joins. + * It can be extended with the sets of connected/unconnected joins, which + * can be used to filter Cartesian products. 
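Editor's note (not part of the patch): the DP search, JoinPlan and Cost defined above are spread across several methods, so here is a compact sketch of the same level-by-level enumeration in plain Scala. It keeps only the cheapest plan per item set and prunes pairs with no connecting join condition. The single-number cost and the Plan/joinable names are simplifications for illustration, not Spark's classes; Spark's Cost combines cardinality and size from statistics.

object JoinReorderDPSketch {
  case class Plan(items: Set[String], tree: String, cost: BigInt)

  // Assumes a connected join graph, mirroring the precondition that all items are joinable.
  def search(rows: Map[String, BigInt], conditions: Set[(String, String)]): Plan = {
    val connected = conditions.flatMap { case (a, b) => Set(a -> b, b -> a) }
    def joinable(l: Set[String], r: Set[String]): Boolean =
      connected.exists { case (a, b) => l.contains(a) && r.contains(b) }

    // Level 0: single items with zero cost (leaf items contribute no intermediate cost).
    var levels = Vector(rows.map { case (t, _) => Set(t) -> Plan(Set(t), t, 0) })
    while (levels.size < rows.size) {
      val next = scala.collection.mutable.Map.empty[Set[String], Plan]
      for {
        k <- 0 until levels.size if k <= levels.size - 1 - k
        left <- levels(k).values
        right <- levels(levels.size - 1 - k).values
        if left.items.intersect(right.items).isEmpty && joinable(left.items, right.items)
      } {
        // Toy cardinality estimate for the new join; Spark uses plan statistics instead.
        val outRows = left.items.union(right.items).map(rows).max
        val plan = Plan(left.items ++ right.items, s"(${left.tree} J ${right.tree})",
          left.cost + right.cost + outRows)
        // Keep only the cheapest plan for each item set, as the memo in searchLevel does.
        if (next.get(plan.items).forall(_.cost > plan.cost)) next(plan.items) = plan
      }
      levels :+= next.toMap
    }
    levels.last.values.head
  }

  def main(args: Array[String]): Unit = {
    val rows = Map("A" -> BigInt(1000), "B" -> BigInt(100), "C" -> BigInt(10), "D" -> BigInt(1000000))
    val conds = Set("A" -> "B", "B" -> "C", "C" -> "D")
    println(search(rows, conds)) // prints the cheapest join tree over {A, B, C, D}
  }
}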
+ */ + def buildJoinGraphInfo( + conf: SQLConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = { + + if (conf.joinReorderDPStarFilter) { + // Compute the tables in a star-schema relationship. + val starJoin = StarSchemaDetection(conf).findStarJoins(items, conditions.toSeq) + val nonStarJoin = items.filterNot(starJoin.contains(_)) + + if (starJoin.nonEmpty && nonStarJoin.nonEmpty) { + val itemMap = itemIndex.toMap + Some(JoinGraphInfo(starJoin.map(itemMap).toSet, nonStarJoin.map(itemMap).toSet)) + } else { + // Nothing interesting to return. + None + } + } else { + // Star schema filter is not enabled. + None + } + } + + /** + * Applies the star-join filter that eliminates join combinations among star + * and non-star tables until the star join is built. + * + * Given the oneSideJoinPlan/otherSideJoinPlan, which represent all the plan + * permutations generated by the DP join enumeration, and the star/non-star plans, + * the following plan combinations are allowed: + * 1. (oneSideJoinPlan U otherSideJoinPlan) is a subset of star-join + * 2. star-join is a subset of (oneSideJoinPlan U otherSideJoinPlan) + * 3. (oneSideJoinPlan U otherSideJoinPlan) is a subset of non star-join + * + * It assumes the sets are disjoint. + * + * Example query graph: + * + * t1 d1 - t2 - t3 + * \ / + * f1 + * | + * d2 + * + * star: {d1, f1, d2} + * non-star: {t2, t1, t3} + * + * level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + * level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + * level 2: {d2 f1 d1 } + * level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + * level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + * level 5: {d1 t3 t2 f1 t1 d2 } + * + * @param oneSideJoinPlan One side of the join represented as a set of plan ids. + * @param otherSideJoinPlan The other side of the join represented as a set of plan ids. + * @param filters Star and non-star plans represented as sets of plan ids + */ + def starJoinFilter( + oneSideJoinPlan: Set[Int], + otherSideJoinPlan: Set[Int], + filters: JoinGraphInfo) : Boolean = { + val starJoins = filters.starJoins + val nonStarJoins = filters.nonStarJoins + val join = oneSideJoinPlan.union(otherSideJoinPlan) + + // Disjoint sets + oneSideJoinPlan.intersect(otherSideJoinPlan).isEmpty && + // Either star or non-star is empty + (starJoins.isEmpty || nonStarJoins.isEmpty || + // Join is a subset of the star-join + join.subsetOf(starJoins) || + // Star-join is a subset of join + starJoins.subsetOf(join) || + // Join is a subset of non-star + join.subsetOf(nonStarJoins)) + } +} + +/** + * Helper class that keeps information about the join graph as sets of item/plan ids. + * It currently stores the star/non-star plans. It can be + * extended with the set of connected/unconnected plans. 
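Editor's note (not part of the patch): starJoinFilter above is pure set arithmetic, so it can be exercised standalone. A sketch using the example query graph from the comment (star = {d1, f1, d2}, non-star = {t1, t2, t3}); the integer ids are illustrative placeholders for plan ids.

object StarFilterSketch {
  def starJoinFilter(one: Set[Int], other: Set[Int], star: Set[Int], nonStar: Set[Int]): Boolean = {
    val join = one.union(other)
    one.intersect(other).isEmpty &&
      (star.isEmpty || nonStar.isEmpty ||
        join.subsetOf(star) || star.subsetOf(join) || join.subsetOf(nonStar))
  }

  def main(args: Array[String]): Unit = {
    val (d1, f1, d2, t1, t2, t3) = (0, 1, 2, 3, 4, 5)
    val star = Set(d1, f1, d2)
    val nonStar = Set(t1, t2, t3)
    // Allowed: grows the star join.
    println(starJoinFilter(Set(f1, d1), Set(d2), star, nonStar)) // true
    // Allowed: pure non-star combination.
    println(starJoinFilter(Set(t2), Set(t3), star, nonStar))     // true
    // Rejected: mixes star and non-star tables before the star join is complete.
    println(starJoinFilter(Set(f1, d1), Set(t1), star, nonStar)) // false
  }
}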
+ */ +case class JoinGraphInfo (starJoins: Set[Int], nonStarJoins: Set[Int]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 036da3ad2062f..f2b9764b0f088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,30 +17,24 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.annotation.tailrec -import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) +abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) extends RuleExecutor[LogicalPlan] { protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) @@ -68,6 +62,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // since the other rules might make two separate Unions operators adjacent. 
Batch("Union", Once, CombineUnions) :: + Batch("Pullup Correlated Expressions", Once, + PullupCorrelatedPredicates) :: Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, @@ -77,16 +73,16 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: - Batch("Operator Optimizations", fixedPoint, + Batch("Operator Optimizations", fixedPoint, Seq( // Operator push down PushProjectionThroughUnion, - ReorderJoin, - EliminateOuterJoin, + ReorderJoin(conf), + EliminateOuterJoin(conf), PushPredicateThroughJoin, PushDownPredicate, LimitPushDown(conf), ColumnPruning, - InferFiltersFromConstraints, + InferFiltersFromConstraints(conf), // Operator combine CollapseRepartition, CollapseProject, @@ -105,7 +101,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - PruneFilters, + PruneFilters(conf), EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, @@ -115,12 +111,16 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) RemoveRedundantProject, SimplifyCreateStructOps, SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: + SimplifyCreateMapOps) ++ + extendedOperatorOptimizationRules: _*) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, @@ -138,9 +138,15 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) object OptimizeSubqueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => - s.withNewPlan(Optimizer.this.execute(s.plan)) + val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) + s.withNewPlan(newPlan) } } + + /** + * Override to provide additional rules for the operator optimization batch. + */ + def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil } /** @@ -155,8 +161,8 @@ class SimpleTestOptimizer extends Optimizer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)), - new SimpleCatalystConf(caseSensitiveAnalysis = true)) + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)), + new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) /** * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change @@ -183,7 +189,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // If the alias name is different from attribute name, we can't strip it either, or we // may accidentally change the output schema name of the root plan. case a @ Alias(attr: Attribute, name) - if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) => + if a.metadata == Metadata.empty && + name == attr.name && + !blacklist.contains(attr) && + !blacklist.contains(a) => attr case a => a } @@ -191,10 +200,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove redundant alias expression from a LogicalPlan and its subtree. 
A blacklist is used to * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) - * join. + * join or to prevent the removal of top-level subquery attributes. */ private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { plan match { + // We want to keep the same output attributes for subqueries. This means we cannot remove + // the aliases that produce these attributes + case Subquery(child) => + Subquery(removeRedundantAliases(child, blacklist ++ child.outputSet)) + // A join has to be treated differently, because the left and the right side of the join are // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child @@ -257,7 +271,7 @@ object RemoveRedundantProject extends Rule[LogicalPlan] { /** * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. */ -case class LimitPushDown(conf: CatalystConf) extends Rule[LogicalPlan] { +case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { @@ -427,8 +441,7 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) - if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. @@ -562,38 +575,36 @@ object CollapseProject extends Rule[LogicalPlan] { } /** - * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations - * by keeping only the one. - * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]]. - * 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]]. - * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single - * [[RepartitionByExpression]] with the expression and last number of partition. + * Combines adjacent [[RepartitionOperation]] operators */ object CollapseRepartition extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // Case 1 - case Repartition(numPartitions, shuffle, Repartition(_, _, child)) => - Repartition(numPartitions, shuffle, child) - // Case 2 - case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) - // Case 3 - case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = numPartitions) - // Case 3 - case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) + // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression, + // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child + // enables the shuffle. Returns the child node if the last numPartitions is bigger; + // otherwise, keep unchanged. 
+ // 2) In the other cases, returns the top node with the child's child + case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match { + case (false, true) => if (r.numPartitions >= child.numPartitions) child else r + case _ => r.copy(child = child.child) + } + // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression + // we can remove the child. + case r @ RepartitionByExpression(_, child: RepartitionOperation, _) => + r.copy(child = child.child) } } /** * Collapse Adjacent Window Expression. - * - If the partition specs and order specs are the same, collapse into the parent. + * - If the partition specs and order specs are the same and the window expression are + * independent, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case w @ Window(we1, ps1, os1, Window(we2, ps2, os2, grandChild)) if ps1 == ps2 && os1 == os2 => - w.copy(windowExpressions = we2 ++ we1, child = grandChild) + case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } @@ -606,8 +617,16 @@ object CollapseWindow extends Rule[LogicalPlan] { * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and * LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { +case class InferFiltersFromConstraints(conf: SQLConf) + extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = if (conf.constraintPropagationEnabled) { + inferFilters(plan) + } else { + plan + } + + + private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, child) => val newFilters = filter.constraints -- (child.constraints ++ splitConjunctivePredicates(condition)) @@ -696,7 +715,7 @@ object EliminateSorts extends Rule[LogicalPlan] { * 2) by substituting a dummy empty relation when the filter will always evaluate to `false`. * 3) by eliminating the always-true conditions given the constraints on the child's output. */ -object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { +case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child @@ -709,7 +728,7 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(fc, p: LogicalPlan) => val (prunedPredicates, remainingPredicates) = splitConjunctivePredicates(fc).partition { cond => - cond.deterministic && p.constraints.contains(cond) + cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond) } if (prunedPredicates.isEmpty) { f @@ -736,7 +755,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // implies that, for a given input row, the output are determined by the expression's initial // state and all the input rows processed before. In another word, the order of input rows // matters for non-deterministic expressions, while pushing down predicates changes the order. 
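Editor's note (not part of the patch): the comment above explains why PushDownPredicate only pushes filters below deterministic projections (and, with this change, deterministic aggregate expressions). A plain-Scala sketch of the hazard, using a stateful id generator as a stand-in for a non-deterministic expression: evaluating the filter before or after the projection yields different output, so the rewrite would not be semantics-preserving.

object PushdownDeterminismSketch {
  def main(args: Array[String]): Unit = {
    val input = Seq("a", "b", "c", "d")

    // Stand-in for a non-deterministic expression: each call advances internal state.
    def idGenerator(): Iterator[Long] = Iterator.from(0).map(_.toLong)

    // Plan 1: Project(value, id) then Filter. The projection sees all four rows.
    val ids1 = idGenerator()
    val projectedThenFiltered =
      input.map(v => (v, ids1.next())).filter { case (v, _) => Set("c", "d")(v) }

    // Plan 2: Filter first, then Project(value, id). The projection now sees only two rows.
    val ids2 = idGenerator()
    val filteredThenProjected = input.filter(Set("c", "d")).map(v => (v, ids2.next()))

    println(projectedThenFiltered) // List((c,2), (d,3))
    println(filteredThenProjected) // List((c,0), (d,1))
  }
}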
- case filter @ Filter(condition, project @ Project(fields, grandChild)) + // This also applies to Aggregate. + case Filter(condition, project @ Project(fields, grandChild)) if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) => // Create a map of Aliases to their values from the child projection. @@ -747,33 +767,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following conditions: - // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic. - // 3. Placed before any non-deterministic predicates. - case filter @ Filter(condition, w: Window) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - - val (candidates, containingNonDeterministic) = - splitConjunctivePredicates(condition).span(_.deterministic) - - val (pushDown, rest) = candidates.partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - val stayUp = rest ++ containingNonDeterministic - - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) - if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) - } else { - filter - } - - case filter @ Filter(condition, aggregate: Aggregate) => + case filter @ Filter(condition, aggregate: Aggregate) + if aggregate.aggregateExpressions.forall(_.deterministic) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { @@ -804,6 +799,32 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be + // pushed beneath must satisfy the following conditions: + // 1. All the expressions are part of window partitioning key. The expressions can be compound. + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. 
+ case filter @ Filter(condition, w: Window) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) + if (stayUp.isEmpty) newWindow else Filter(stayUp.reduce(And), newWindow) + } else { + filter + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) @@ -829,7 +850,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(condition, u: UnaryNode) + case filter @ Filter(_, u: UnaryNode) if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) @@ -887,7 +908,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { private def canPushThroughCondition(plan: LogicalPlan, condition: Expression): Boolean = { val attributes = plan.outputSet val matched = condition.find { - case PredicateSubquery(p, _, _, _) => p.outputSet.intersect(attributes).nonEmpty + case s: SubqueryExpression => s.plan.outputSet.intersect(attributes).nonEmpty case _ => false } matched.isEmpty @@ -1038,7 +1059,7 @@ object CombineLimits extends Rule[LogicalPlan] { * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. */ -case class CheckCartesianProducts(conf: CatalystConf) +case class CheckCartesianProducts(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Check if a join is a cartesian product. Returns true if @@ -1073,7 +1094,7 @@ case class CheckCartesianProducts(conf: CatalystConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -case class DecimalAggregates(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DecimalAggregates(conf: SQLConf) extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala new file mode 100644 index 0000000000000..97ee9988386dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Encapsulates star-schema detection logic. + */ +case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper { + + /** + * Star schema consists of one or more fact tables referencing a number of dimension + * tables. In general, star-schema joins are detected using the following conditions: + * 1. Informational RI constraints (reliable detection) + * + Dimension contains a primary key that is being joined to the fact table. + * + Fact table contains foreign keys referencing multiple dimension tables. + * 2. Cardinality based heuristics + * + Usually, the table with the highest cardinality is the fact table. + * + Table being joined with the most number of tables is the fact table. + * + * To detect star joins, the algorithm uses a combination of the above two conditions. + * The fact table is chosen based on the cardinality heuristics, and the dimension + * tables are chosen based on the RI constraints. A star join will consist of the largest + * fact table joined with the dimension tables on their primary keys. To detect that a + * column is a primary key, the algorithm uses table and column statistics. + * + * The algorithm currently returns only the star join with the largest fact table. + * Choosing the largest fact table on the driving arm to avoid large inners is in + * general a good heuristic. This restriction will be lifted to observe multiple + * star joins. + * + * The highlights of the algorithm are the following: + * + * Given a set of joined tables/plans, the algorithm first verifies if they are eligible + * for star join detection. An eligible plan is a base table access with valid statistics. + * A base table access represents Project or Filter operators above a LeafNode. Conservatively, + * the algorithm only considers base table access as part of a star join since they provide + * reliable statistics. This restriction can be lifted with the CBO enablement by default. + * + * If some of the plans are not base table access, or statistics are not available, the algorithm + * returns an empty star join plan since, in the absence of statistics, it cannot make + * good planning decisions. Otherwise, the algorithm finds the table with the largest cardinality + * (number of rows), which is assumed to be a fact table. + * + * Next, it computes the set of dimension tables for the current fact table. A dimension table + * is assumed to be in a RI relationship with a fact table. To infer column uniqueness, + * the algorithm compares the number of distinct values with the total number of rows in the + * table. If their relative difference is within certain limits (i.e. ndvMaxError * 2, adjusted + * based on 1TB TPC-DS data), the column is assumed to be unique. 
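Editor's note (not part of the patch): the column-uniqueness test described above boils down to comparing a column's distinct count against the table's row count within an error bound. A hedged sketch of just that arithmetic; the 0.05 value stands in for the ndv max-error setting and is an illustrative assumption.

object UniqueColumnSketch {
  def looksUnique(distinctCount: BigInt, rowCount: BigInt, nullCount: BigInt, ndvMaxError: Double): Boolean =
    nullCount == 0 && rowCount > 0 && {
      // Relative difference between the number of distinct values and the number of rows.
      val relDiff = math.abs(distinctCount.toDouble / rowCount.toDouble - 1.0)
      relDiff <= ndvMaxError * 2
    }

  def main(args: Array[String]): Unit = {
    println(looksUnique(distinctCount = 998000, rowCount = 1000000, nullCount = 0, ndvMaxError = 0.05)) // true
    println(looksUnique(distinctCount = 120000, rowCount = 1000000, nullCount = 0, ndvMaxError = 0.05)) // false
  }
}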
+ */ + def findStarJoins( + input: Seq[LogicalPlan], + conditions: Seq[Expression]): Seq[LogicalPlan] = { + + val emptyStarJoinPlan = Seq.empty[LogicalPlan] + + if (input.size < 2) { + emptyStarJoinPlan + } else { + // Find if the input plans are eligible for star join detection. + // An eligible plan is a base table access with valid statistics. + val foundEligibleJoin = input.forall { + case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true + case _ => false + } + + if (!foundEligibleJoin) { + // Some plans don't have stats or are complex plans. Conservatively, + // return an empty star join. This restriction can be lifted + // once statistics are propagated in the plan. + emptyStarJoinPlan + } else { + // Find the fact table using cardinality based heuristics i.e. + // the table with the largest number of rows. + val sortedFactTables = input.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.collect { case t @ TableAccessCardinality(_, Some(_)) => + t + }.sortBy(_.size)(implicitly[Ordering[Option[BigInt]]].reverse) + + sortedFactTables match { + case Nil => + emptyStarJoinPlan + case table1 :: table2 :: _ + if table2.size.get.toDouble > conf.starSchemaFTRatio * table1.size.get.toDouble => + // If the top largest tables have comparable number of rows, return an empty star plan. + // This restriction will be lifted when the algorithm is generalized + // to return multiple star plans. + emptyStarJoinPlan + case TableAccessCardinality(factTable, _) :: rest => + // Find the fact table joins. + val allFactJoins = rest.collect { case TableAccessCardinality(plan, _) + if findJoinConditions(factTable, plan, conditions).nonEmpty => + plan + } + + // Find the corresponding join conditions. + val allFactJoinCond = allFactJoins.flatMap { plan => + val joinCond = findJoinConditions(factTable, plan, conditions) + joinCond + } + + // Verify if the join columns have valid statistics. + // Allow any relational comparison between the tables. Later + // we will heuristically choose a subset of equi-join + // tables. + val areStatsAvailable = allFactJoins.forall { dimTable => + allFactJoinCond.exists { + case BinaryComparison(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + val factCol = if (factTable.outputSet.contains(lhs)) lhs else rhs + hasStatistics(dimCol, dimTable) && hasStatistics(factCol, factTable) + case _ => false + } + } + + if (!areStatsAvailable) { + emptyStarJoinPlan + } else { + // Find the subset of dimension tables. A dimension table is assumed to be in a + // RI relationship with the fact table. Only consider equi-joins + // between a fact and a dimension table to avoid expanding joins. + val eligibleDimPlans = allFactJoins.filter { dimTable => + allFactJoinCond.exists { + case cond @ Equality(lhs: AttributeReference, rhs: AttributeReference) => + val dimCol = if (dimTable.outputSet.contains(lhs)) lhs else rhs + isUnique(dimCol, dimTable) + case _ => false + } + } + + if (eligibleDimPlans.isEmpty || eligibleDimPlans.size < 2) { + // An eligible star join was not found since the join is not + // an RI join, or the star join is an expanding join. + // Also, a star would involve more than one dimension table. + emptyStarJoinPlan + } else { + factTable +: eligibleDimPlans + } + } + } + } + } + } + + /** + * Determines if a column referenced by a base table access is a primary key. + * A column is a PK if it is not nullable and has unique values. 
+ * To determine if a column has unique values in the absence of informational + * RI constraints, the number of distinct values is compared to the total + * number of rows in the table. If their relative difference + * is within the expected limits (i.e. 2 * spark.sql.statistics.ndv.maxError based + * on TPC-DS data results), the column is assumed to have unique values. + */ + private def isUnique( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.rowCount match { + case Some(rowCount) if rowCount >= 0 => + if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { + val colStats = stats.attributeStats.get(col) + if (colStats.get.nullCount > 0) { + false + } else { + val distinctCount = colStats.get.distinctCount + val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) + // ndvMaxErr adjusted based on TPCDS 1TB data results + relDiff <= conf.ndvMaxError * 2 + } + } else { + false + } + case None => false + } + case None => false + } + case _ => false + } + + /** + * Given a column over a base table access, it returns + * the leaf node column from which the input column is derived. + */ + @tailrec + private def findLeafNodeCol( + column: Attribute, + plan: LogicalPlan): Option[Attribute] = plan match { + case pl @ PhysicalOperation(_, _, _: LeafNode) => + pl match { + case t: LeafNode if t.outputSet.contains(column) => + Option(column) + case p: Project if p.outputSet.exists(_.semanticEquals(column)) => + val col = p.outputSet.find(_.semanticEquals(column)).get + findLeafNodeCol(col, p.child) + case f: Filter => + findLeafNodeCol(column, f.child) + case _ => None + } + case _ => None + } + + /** + * Checks if a column has statistics. + * The column is assumed to be over a base table access. + */ + private def hasStatistics( + column: Attribute, + plan: LogicalPlan): Boolean = plan match { + case PhysicalOperation(_, _, t: LeafNode) => + val leafCol = findLeafNodeCol(column, plan) + leafCol match { + case Some(col) if t.outputSet.contains(col) => + val stats = t.stats(conf) + stats.attributeStats.nonEmpty && stats.attributeStats.contains(col) + case None => false + } + case _ => false + } + + /** + * Returns the join predicates between two input plans. It only + * considers basic comparison operators. + */ + @inline + private def findJoinConditions( + plan1: LogicalPlan, + plan2: LogicalPlan, + conditions: Seq[Expression]): Seq[Expression] = { + val refs = plan1.outputSet ++ plan2.outputSet + conditions.filter { + case BinaryComparison(_, _) => true + case _ => false + }.filterNot(canEvaluate(_, plan1)) + .filterNot(canEvaluate(_, plan2)) + .filter(_.references.subsetOf(refs)) + } + + /** + * Checks if a star join is a selective join. A star join is assumed + * to be selective if there are local predicates on the dimension + * tables. + */ + private def isSelectiveStarJoin( + dimTables: Seq[LogicalPlan], + conditions: Seq[Expression]): Boolean = dimTables.exists { + case plan @ PhysicalOperation(_, p, _: LeafNode) => + // Checks if any condition applies to the dimension tables. + // Exclude the IsNotNull predicates until predicate selectivity is available. + // In most cases, this predicate is artificially introduced by the Optimizer + // to enforce nullability constraints. 
+ val localPredicates = conditions.filterNot(_.isInstanceOf[IsNotNull]) + .exists(canEvaluate(_, plan)) + + // Checks if there are any predicates pushed down to the base table access. + val pushedDownPredicates = p.nonEmpty && !p.forall(_.isInstanceOf[IsNotNull]) + + localPredicates || pushedDownPredicates + case _ => false + } + + /** + * Helper case class to hold (plan, rowCount) pairs. + */ + private case class TableAccessCardinality(plan: LogicalPlan, size: Option[BigInt]) + + /** + * Returns the cardinality of a base table access. A base table access represents + * a LeafNode, or Project or Filter operators above a LeafNode. + */ + private def getTableAccessCardinality( + input: LogicalPlan): Option[BigInt] = input match { + case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined => + if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) { + Option(input.stats(conf).rowCount.get) + } else { + Option(t.stats(conf).rowCount.get) + } + case _ => None + } + + /** + * Reorders a star join based on heuristics. It is called from ReorderJoin if CBO is disabled. + * 1) Finds the star join with the largest fact table. + * 2) Places the fact table the driving arm of the left-deep tree. + * This plan avoids large table access on the inner, and thus favor hash joins. + * 3) Applies the most selective dimensions early in the plan to reduce the amount of + * data flow. + */ + def reorderStarJoins( + input: Seq[(LogicalPlan, InnerLike)], + conditions: Seq[Expression]): Seq[(LogicalPlan, InnerLike)] = { + assert(input.size >= 2) + + val emptyStarJoinPlan = Seq.empty[(LogicalPlan, InnerLike)] + + // Find the eligible star plans. Currently, it only returns + // the star join with the largest fact table. + val eligibleJoins = input.collect{ case (plan, Inner) => plan } + val starPlan = findStarJoins(eligibleJoins, conditions) + + if (starPlan.isEmpty) { + emptyStarJoinPlan + } else { + val (factTable, dimTables) = (starPlan.head, starPlan.tail) + + // Only consider selective joins. This case is detected by observing local predicates + // on the dimension tables. In a star schema relationship, the join between the fact and the + // dimension table is a FK-PK join. Heuristically, a selective dimension may reduce + // the result of a join. 
+ if (isSelectiveStarJoin(dimTables, conditions)) { + val reorderDimTables = dimTables.map { plan => + TableAccessCardinality(plan, getTableAccessCardinality(plan)) + }.sortBy(_.size).map { + case TableAccessCardinality(p1, _) => p1 + } + + val reorderStarPlan = factTable +: reorderDimTables + reorderStarPlan.map(plan => (plan, Inner)) + } else { + emptyStarJoinPlan + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 21d1cd5932620..34382bd272406 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /* @@ -115,7 +116,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { * 2. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ -case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeIn(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { case expr @ In(v, list) if expr.inSetConvertible => @@ -153,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral + case a And b if Not(a).semanticEquals(b) => FalseLiteral + case a Or b if Not(a).semanticEquals(b) => TrueLiteral + case a And b if a.semanticEquals(Not(b)) => FalseLiteral + case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a @@ -346,36 +352,33 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. 
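Editor's note (not part of the patch): the rewritten NullPropagation below collapses any null-intolerant expression as soon as one of its children is a null literal, and strips null literals out of Coalesce. A simplified model with a toy expression tree, not Catalyst's Expression API; the real rule also handles the single-child and empty Coalesce cases, omitted here.

object NullPropagationSketch {
  sealed trait Expr { def nullIntolerant: Boolean = true }
  case class Lit(value: Option[Int]) extends Expr { override def nullIntolerant = false }
  case class Add(left: Expr, right: Expr) extends Expr
  case class Coalesce(children: Seq[Expr]) extends Expr { override def nullIntolerant = false }

  def children(e: Expr): Seq[Expr] = e match {
    case Add(l, r)    => Seq(l, r)
    case Coalesce(cs) => cs
    case _: Lit       => Nil
  }

  def propagateNulls(e: Expr): Expr = {
    val folded = e match {
      case Add(l, r)    => Add(propagateNulls(l), propagateNulls(r))
      case Coalesce(cs) => Coalesce(cs.map(propagateNulls).filterNot(_ == Lit(None))) // drop null literals
      case lit          => lit
    }
    // A null-intolerant operator returns null whenever one input is null, so fold it eagerly.
    if (folded.nullIntolerant && children(folded).contains(Lit(None))) Lit(None) else folded
  }

  def main(args: Array[String]): Unit = {
    println(propagateNulls(Add(Lit(Some(1)), Lit(None))))           // Lit(None)
    println(propagateNulls(Coalesce(Seq(Lit(None), Lit(Some(7)))))) // Coalesce(List(Lit(Some(7))))
  }
}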
*/ -case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { - private def nonNullLiteral(e: Expression): Boolean = e match { - case Literal(null, _) => false - case _ => true +case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { + private def isNullLiteral(e: Expression): Boolean = e match { + case Literal(null, _) => true + case _ => false } def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) - case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) - case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) - case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) - case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => - Literal.create(null, e.dataType) - case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) - case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. ae.copy(aggregateFunction = Count(Literal(1))) + case IsNull(c) if !c.nullable => Literal.create(false, BooleanType) + case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) + + case EqualNullSafe(Literal(null, _), r) => IsNull(r) + case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + + case AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. 
case e @ Coalesce(children) => - val newChildren = children.filter(nonNullLiteral) + val newChildren = children.filterNot(isNullLiteral) if (newChildren.isEmpty) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { @@ -384,33 +387,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { Coalesce(newChildren) } - case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) - - // Put exceptional cases above if any - case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType) - case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType) - - case e: StringRegexExpression => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - case e: StringPredicate => e.children match { - case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) - case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) - case _ => e - } - - // If the value expression is NULL then transform the In expression to - // Literal(null) - case In(Literal(null, _), list) => Literal.create(null, BooleanType) + // If the value expression is NULL then transform the In expression to null literal. + case In(Literal(null, _), _) => Literal.create(null, BooleanType) + // Non-leaf NullIntolerant expressions will return null, if at least one of its children is + // a null literal. + case e: NullIntolerant if e.children.exists(isNullLiteral) => + Literal.create(null, e.dataType) } } } @@ -507,7 +490,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { /** * Optimizes expressions by replacing according to CodeGen configuration. */ -case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { +case class OptimizeCodegen(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: CaseWhen if canCodegen(e) => e.toCodegen() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index bfe529e21e9ad..2fe3039774423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -24,15 +24,17 @@ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Reorder the joins and push all the conditions into join, so that the bottom ones have at least * one condition. * * The order of joins will not be changed if all of them already have at least one condition. + * + * If star schema detection is enabled, reorder the star join plans based on heuristics. 
*/ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - +case class ReorderJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Join a list of plans together and push down the conditions into them. * @@ -42,7 +44,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * @param conditions a list of condition for join. */ @tailrec - def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) + final def createOrderedJoin(input: Seq[(LogicalPlan, InnerLike)], conditions: Seq[Expression]) : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { @@ -83,9 +85,19 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) + case ExtractFiltersAndInnerJoins(input, conditions) if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) + if (conf.starSchemaDetection && !conf.cboEnabled) { + val starJoinPlan = StarSchemaDetection(conf).reorderStarJoins(input, conditions) + if (starJoinPlan.nonEmpty) { + val rest = input.filterNot(starJoinPlan.contains(_)) + createOrderedJoin(starJoinPlan ++ rest, conditions) + } else { + createOrderedJoin(input, conditions) + } + } else { + createOrderedJoin(input, conditions) + } } } @@ -101,7 +113,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * * This rule should be executed before pushing down the Filter */ -object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { +case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper { /** * Returns whether the expression returns null or false when all inputs are nulls. 
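Editor's note (not part of the patch): before the next hunk, a plain-Scala sketch of the intuition behind canFilterOutNull and EliminateOuterJoin. If the filter above the join rejects the all-null row that an outer join would pad in for the null-supplying side, those padded rows are discarded anyway, so the outer join can be planned as an inner join. None models SQL NULL here; the predicate and helper names are illustrative.

object OuterJoinEliminationSketch {
  // Three-valued logic: None models SQL NULL.
  type Pred = Option[Int] => Option[Boolean]

  // A predicate filters out the null-extended row if it evaluates to false or NULL on null input.
  def filtersOutNullRow(p: Pred): Boolean = !p(None).getOrElse(false)

  def main(args: Array[String]): Unit = {
    val greaterThan3: Pred = v => v.map(_ > 3)    // NULL > 3 evaluates to NULL, row is dropped
    val isNullCheck: Pred  = v => Some(v.isEmpty) // IS NULL evaluates to true, row is kept
    println(filtersOutNullRow(greaterThan3)) // true: such a filter lets the outer join become inner
    println(filtersOutNullRow(isNullCheck))  // false: the outer join must stay
  }
}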
@@ -117,12 +129,13 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints + val conditions = splitConjunctivePredicates(filter.condition) ++ + filter.getConstraints(conf.constraintPropagationEnabled) val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) join.joinType match { case RightOuter if leftHasNonNullPredicate => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 174d546e22809..8cdc6425bcad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -65,7 +66,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { /** * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one, - * mering the filter functions into one conjunctive function. + * merging the filter functions into one conjunctive function. */ object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] { } } } + +/** + * Removes MapObjects when the following conditions are satisfied: + * 1. MapObjects(... LambdaVariable(..., false) ...), which means the input and output + * types are non-nullable primitive types + * 2. no custom collection class is specified as the representation of the data items. + */ +object EliminateMapObjects extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 4d62cce9da0ac..2a3e07aebe709 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -41,10 +42,17 @@ import org.apache.spark.sql.types._ * condition. 
*/ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + private def getValueExpression(e: Expression): Seq[Expression] = { + e match { + case cns : CreateNamedStruct => cns.valExprs + case expr => Seq(expr) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { @@ -54,27 +62,37 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { - case (p, PredicateSubquery(sub, conditions, _, _)) => + case (p, Exists(sub, conditions, _)) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftSemi, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, false, _))) => + case (p, Not(Exists(sub, conditions, _))) => val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) Join(outerPlan, sub, LeftAnti, joinCond) - case (p, Not(PredicateSubquery(sub, conditions, true, _))) => + case (p, In(value, Seq(ListQuery(sub, conditions, _)))) => + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) + case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) => // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. // Note that will almost certainly be planned as a Broadcast Nested Loop join. // Use EXISTS if performance matters to you. - val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p) // Expand the NOT IN expression with the NULL-aware semantic // to its full form. That is from: - // (a1,b1,...) = (a2,b2,...) + // (a1,a2,...) = (b1,b2,...) // to - // (a1=a2 OR isnull(a1=a2)) AND (b1=b2 OR isnull(b1=b2)) AND ... + // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... val joinConds = splitConjunctivePredicates(joinCond.get) - val pairs = joinConds.map(c => Or(c, IsNull(c))).reduceLeft(And) + // After that, add back the correlated join predicate(s) in the subquery + // Example: + // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) + // will have the final conditions in the LEFT ANTI as + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) + val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) Join(outerPlan, sub, LeftAnti, Option(pairs)) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) @@ -83,11 +101,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } /** - * Given a predicate expression and an input plan, it rewrites - * any embedded existential sub-query into an existential join. - * It returns the rewritten expression together with the updated plan. - * Currently, it does not support null-aware joins. 
Embedded NOT IN predicates - * are blocked in the Analyzer. + * Given a predicate expression and an input plan, it rewrites any embedded existential sub-query + * into an existential join. It returns the rewritten expression together with the updated plan. + * Currently, it does not support NOT IN nested inside a NOT expression. This case is blocked in + * the Analyzer. */ private def rewriteExistentialExpr( exprs: Seq[Expression], @@ -95,17 +112,138 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { var newPlan = plan val newExprs = exprs.map { e => e transformUp { - case PredicateSubquery(sub, conditions, nullAware, _) => - // TODO: support null-aware join + case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) exists - } + case In(value, Seq(ListQuery(sub, conditions, _))) => + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) + val newConditions = (inConditions ++ conditions).reduceLeftOption(And) + newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions) + exists + } } (newExprs.reduceOption(And), newPlan) } } + /** + * Pull out all (outer) correlated predicates from a given subquery. This method removes the + * correlated predicates from subquery [[Filter]]s and adds the references of these predicates + * to all intermediate [[Project]] and [[Aggregate]] clauses (if they are missing) in order to + * be able to evaluate the predicates at the top level. + * + * TODO: Look to merge this rule with RewritePredicateSubquery. + */ +object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper { + /** + * Returns the correlated predicates and an updated plan that removes the outer references. + */ + private def pullOutCorrelatedPredicates( + sub: LogicalPlan, + outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] + + /** Determine which correlated predicate references are missing from this plan. */ + def missingReferences(p: LogicalPlan): AttributeSet = { + val localPredicateReferences = p.collect(predicateMap) + .flatten + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + localPredicateReferences -- p.outputSet + } + + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { + case f @ Filter(cond, child) => + val (correlated, local) = + splitConjunctivePredicates(cond).partition(containsOuter) + + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> xs + newFilter + case xs => + predicateMap += child -> xs + child + } + case p @ Project(expressions, child) => + val referencesToAdd = missingReferences(p) + if (referencesToAdd.nonEmpty) { + Project(expressions ++ referencesToAdd, child) + } else { + p + } + case a @ Aggregate(grouping, expressions, child) => + val referencesToAdd = missingReferences(a) + if (referencesToAdd.nonEmpty) { + Aggregate(grouping ++ referencesToAdd, expressions ++ referencesToAdd, child) + } else { + a + } + case p => + p + } + + // Make sure the inner and the outer query attributes do not collide. 
+ // In case of a collision, change the subquery plan's output to use + // different attributes by creating aliases. + val baseConditions = predicateMap.values.flatten.toSeq + val (newPlan, newCond) = if (outer.nonEmpty) { + val outputSet = outer.map(_.outputSet).reduce(_ ++ _) + val duplicates = transformed.outputSet.intersect(outputSet) + val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = transformed.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, transformed) + val aliasedConditions = baseConditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (transformed, baseConditions) + } + (plan, stripOuterReferences(deDuplicatedConditions)) + } else { + (transformed, stripOuterReferences(baseConditions)) + } + (newPlan, newCond) + } + + private def rewriteSubQueries(plan: LogicalPlan, outerPlans: Seq[LogicalPlan]): LogicalPlan = { + plan transformExpressions { + case ScalarSubquery(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ScalarSubquery(newPlan, newCond, exprId) + case Exists(sub, children, exprId) if children.nonEmpty => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + Exists(newPlan, newCond, exprId) + case ListQuery(sub, _, exprId) => + val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + ListQuery(newPlan, newCond, exprId) + } + } + + /** + * Pull up the correlated predicates and rewrite all subqueries in an operator tree. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case f @ Filter(_, a: Aggregate) => + rewriteSubQueries(f, Seq(a, a.child)) + // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. + case q: UnaryNode => + rewriteSubQueries(q, q.children) + } +} /** * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. @@ -169,7 +307,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. 
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { - case SubqueryAlias(_, child, _) => evalPlan(child) + case SubqueryAlias(_, child) => evalPlan(child) case Filter(condition, child) => val bindings = evalPlan(child) if (bindings.isEmpty) bindings @@ -227,7 +365,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart += p bottomPart = child - case s @ SubqueryAlias(_, child, _) => + case s @ SubqueryAlias(_, child) => topPart += s bottomPart = child @@ -298,8 +436,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart.reverse.foreach { case Project(projList, _) => subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s @ SubqueryAlias(alias, _, None) => - subqueryRoot = SubqueryAlias(alias, subqueryRoot, None) + case s @ SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) case op => sys.error(s"Unexpected operator $op in corelated subquery") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index 105cdf52500c6..f9c88d496e899 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -28,5 +28,4 @@ package object catalyst { * 2.10.* builds. See SI-6240 for more details. */ protected[sql] object ScalaReflectionLock - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2e091f4dda69..d2a9b4a9a9f59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} +import java.util.Locale import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ @@ -31,6 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -75,6 +77,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { visitTableIdentifier(ctx.tableIdentifier) } + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { visitSparkDataType(ctx.dataType) } @@ -108,7 +115,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * This is only used for Common Table Expressions. 
*/ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.query), None) + SubqueryAlias(ctx.name.getText, plan(ctx.query)) } /** @@ -208,7 +215,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ protected def visitNonOptionalPartitionSpec( ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } } /** @@ -492,7 +502,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add an [[Aggregate]] to a logical plan. + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. */ private def withAggregation( ctx: AggregationContext, @@ -519,7 +529,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add a Hint to a logical plan. + * Add a [[Hint]] to a logical plan. */ private def withHints( ctx: HintContext, @@ -545,7 +555,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a single relation referenced in a FROM claused. This method is used when a part of the + * Create a single relation referenced in a FROM clause. This method is used when a part of the * join condition is nested, for example: * {{{ * select * from t1 join (t2 cross join t3) on col1 = col2 @@ -666,7 +676,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { case Some(strictIdentifier) => - SubqueryAlias(strictIdentifier, table, None) + SubqueryAlias(strictIdentifier, table) case _ => table } tableWithAlias.optionalMap(ctx.sample)(withSample) @@ -731,7 +741,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create an alias (SubqueryAlias) for a LogicalPlan. */ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { - SubqueryAlias(alias.getText, plan, None) + SubqueryAlias(alias.getText, plan) } /** @@ -759,6 +769,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) } + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + /* ******************************************************************************************** * Expression parsing * ******************************************************************************************** */ @@ -917,6 +935,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * - (NOT) LIKE * - (NOT) RLIKE * - IS (NOT) NULL. + * - IS (NOT) DISTINCT FROM */ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { // Invert a predicate if it has a valid NOT clause. 
@@ -944,6 +963,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { IsNotNull(e) case SqlBaseParser.NULL => IsNull(e) + case SqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case SqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) } } @@ -1009,6 +1032,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + /** * Create a (windowed) Function expression. */ @@ -1016,8 +1055,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + val arguments = ctx.namedExpression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) case expressions => @@ -1127,7 +1167,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[CreateStruct]] expression. */ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) + CreateStruct(ctx.namedExpression().asScala.map(expression)) } /** @@ -1229,7 +1269,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering) + SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) } /** @@ -1241,7 +1281,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { val value = string(ctx.STRING) - val valueType = ctx.identifier.getText.toUpperCase + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) try { valueType match { case "DATE" => @@ -1397,7 +1437,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { import ctx._ val s = value.getText try { - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + val unitText = unit.getText.toLowerCase(Locale.ROOT) + val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { case (u, None) if u.endsWith("s") => // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) @@ -1435,7 +1476,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Resolve/create a primitive type. 
*/ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { case ("boolean", Nil) => BooleanType case ("tinyint" | "byte", Nil) => ByteType case ("smallint" | "short", Nil) => ShortType @@ -1454,8 +1496,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("decimal", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index d687a85c18b63..dcccbd0ed8d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -22,11 +22,11 @@ import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} /** * Base SQL parsing infrastructure. @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.DataType abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -49,6 +48,21 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) } + /** Creates FunctionIdentifier for a given SQL string. */ + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } + } + + /** + * Creates StructType for a given SQL string, which is a comma separated list of field + * definitions which will preserve the correct Hive metadata. + */ + override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser => + StructType(astBuilder.visitColTypeList(parser.colTypeList())) + } + /** Creates LogicalPlan for a given SQL string. 
*/ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 7f35d650b9571..75240d2196222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,20 +17,51 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier + + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") + def parseFunctionIdentifier(sqlText: String): FunctionIdentifier + + /** + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. + */ + @throws[ParseException]("Text cannot be parsed to a schema") + def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. 
+ */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 0893af26738bf..d39b0ef7e1d8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -167,8 +167,8 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { : (Seq[(LogicalPlan, InnerLike)], Seq[Expression]) = plan match { case Join(left, right, joinType: InnerLike, cond) => val (plans, conditions) = flattenJoin(left, joinType) - (plans ++ Seq((right, joinType)), conditions ++ cond.toSeq) - + (plans ++ Seq((right, joinType)), conditions ++ + cond.toSeq.flatMap(splitConjunctivePredicates)) case Filter(filterCondition, j @ Join(left, right, _: InnerLike, joinCondition)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a5761703fd655..2fb65bd435507 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -186,6 +186,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + /** + * Returns [[constraints]] depending on whether constraint propagation is enabled. If the + * flag is disabled, simply returns an empty constraint set. + */ + private[spark] def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet = + if (constraintPropagationEnabled) { + constraints + } else { + ExpressionSet(Set.empty) + } + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -219,14 +230,15 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT def producedAttributes: AttributeSet = AttributeSet.empty /** - * Attributes that are referenced by expressions but not provided by this nodes children. + * Attributes that are referenced by expressions but not provided by this node's children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. */ def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** - * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Runs [[transformExpressionsDown]] with `rule` on all expressions present + * in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, * transformExpressionsDown or transformExpressionsUp should be used. * @@ -347,9 +359,43 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** - * Canonicalized copy of this query plan. 
+ * Returns a plan where a best effort attempt has been made to transform `this` in a way + * that preserves the result but removes cosmetic variations (case sensitivity, ordering for + * commutative operations, expression id, etc.) + * + * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same + * result. + * + * Some nodes should overwrite this to provide proper canonicalize logic. + */ + lazy val canonicalized: PlanType = { + val canonicalizedChildren = children.map(_.canonicalized) + var id = -1 + preCanonicalized.mapExpressions { + case a: Alias => + id += 1 + // As the root of the expression, Alias will always take an arbitrary exprId, we need to + // normalize that for equality testing, by assigning expr id from 0 incrementally. The + // alias name doesn't matter and should be erased. + val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes) + Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated) + + case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 => + // Top level `AttributeReference` may also be used for output like `Alias`, we should + // normalize the exprId too. + id += 1 + ar.withExprId(ExprId(id)) + + case other => QueryPlan.normalizeExprId(other, allAttributes) + }.withNewChildren(canonicalizedChildren) + } + + /** + * Do some simple transformation on this plan before canonicalizing. Implementations can override + * this method to provide customized canonicalize logic without rewriting the whole logic. */ - protected lazy val canonicalized: PlanType = this + protected def preCanonicalized: PlanType = this + /** * Returns true when the given query plan will return the same results as this query plan. @@ -360,49 +406,40 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT * enhancements like caching. However, it is not acceptable to return true if the results could * possibly be different. * - * By default this function performs a modified version of equality that is tolerant of cosmetic - * differences like attribute naming and or expression id differences. Operators that - * can do better should override this function. + * This function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and or expression id differences. */ - def sameResult(plan: PlanType): Boolean = { - val left = this.canonicalized - val right = plan.canonicalized - left.getClass == right.getClass && - left.children.size == right.children.size && - left.cleanArgs == right.cleanArgs && - (left.children, right.children).zipped.forall(_ sameResult _) - } + final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized + + /** + * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard + * `hashCode`, an attempt has been made to eliminate cosmetic differences. + */ + final def semanticHash(): Int = canonicalized.hashCode() /** * All the attributes that are used for this plan. */ lazy val allAttributes: AttributeSeq = children.flatMap(_.output) +} - protected def cleanExpression(e: Expression): Expression = e match { - case a: Alias => - // As the root of the expression, Alias will always take an arbitrary exprId, we need - // to erase that for equality testing. 
- val cleanedExprId = - Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated) - BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true) - case other => - BindReferences.bindReference(other, allAttributes, allowFailures = true) - } - - /** Args that have cleaned such that differences in expression id should not affect equality */ - protected lazy val cleanArgs: Seq[Any] = { - def cleanArg(arg: Any): Any = arg match { - // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanExpression(e).canonicalized - case other => other - } - - mapProductIterator { - case s: Option[_] => s.map(cleanArg) - case s: Seq[_] => s.map(cleanArg) - case m: Map[_, _] => m.mapValues(cleanArg) - case other => cleanArg(other) - }.toSeq +object QueryPlan { + /** + * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` + * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we + * do not use `BindReferences` here as the plan may take the expression as a parameter with type + * `Attribute`, and replace it with `BoundReference` will cause error. + */ + def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = { + e.transformUp { + case s: SubqueryExpression => s.canonicalize(input) + case ar: AttributeReference => + val ordinal = input.indexOf(ar.exprId) + if (ordinal == -1) { + ar + } else { + ar.withExprId(ExprId(ordinal)) + } + }.canonicalized.asInstanceOf[T] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 818f4e5ed2ae5..90d11d6d91512 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.plans +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { - def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + def apply(typ: String): JoinType = typ.toLowerCase(Locale.ROOT).replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter case "leftouter" | "left" => LeftOuter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 77309ce391a1a..06196b5afb031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -24,6 +24,12 @@ import org.apache.spark.unsafe.types.CalendarInterval object EventTimeWatermark { /** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */ val delayKey = "spark.watermarkDelayMs" + + def getDelayMs(delay: CalendarInterval): Long = { + // We define month as `31 days` to simplify calculation. + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + delay.milliseconds + delay.months * millisPerMonth + } } /** @@ -37,9 +43,17 @@ case class EventTimeWatermark( // Update the metadata on the eventTime column to include the desired delay. 
override val output: Seq[Attribute] = child.output.map { a => if (a semanticEquals eventTime) { + val delayMs = EventTimeWatermark.getDelayMs(delay) + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delayMs) + .build() + a.withMetadata(updatedMetadata) + } else if (a.metadata.contains(EventTimeWatermark.delayKey)) { + // Remove existing watermark val updatedMetadata = new MetadataBuilder() .withMetadata(a.metadata) - .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .remove(EventTimeWatermark.delayKey) .build() a.withMetadata(updatedMetadata) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1faabcfcb73b5..9cd5dfd21b160 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -66,15 +67,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false - } - } - - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e22b429aec68b..6bdcf490ca5c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -32,7 +32,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { private var _analyzed: Boolean = false /** - * Marks this plan as already analyzed. This should only be called by CheckAnalysis. + * Marks this plan as already analyzed. This should only be called by [[CheckAnalysis]]. 
*/ private[catalyst] def setAnalyzed(): Unit = { _analyzed = true } @@ -90,7 +90,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * first time. If the configuration changes, the cache can be invalidated by calling * [[invalidateStatsCache()]]. */ - final def stats(conf: CatalystConf): Statistics = statsCache.getOrElse { + final def stats(conf: SQLConf): Statistics = statsCache.getOrElse { statsCache = Some(computeStats(conf)) statsCache.get } @@ -108,7 +108,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * * [[LeafNode]]s must override this. */ - protected def computeStats(conf: CatalystConf): Statistics = { + protected def computeStats(conf: SQLConf): Statistics = { if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } @@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = children.forall(_.resolved) - override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) - /** * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function * should only be called on analyzed plans since it will throw [[AnalysisException]] for @@ -335,7 +333,7 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f24b240956a61..3d4efef953a64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -74,11 +75,10 @@ case class Statistics( * Statistics collected for a column. * * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the external data type (used in Row) for the - * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for - * TimestampType we store java.sql.Timestamp. - * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. - * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * Catalyst data type. For example, the internal type of DateType is Int, and that the internal + * type of TimestampType is Long. + * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. 
* * @param distinctCount number of distinct values @@ -104,22 +104,43 @@ case class ColumnStat( /** * Returns a map from string to string that can be used to serialize the column stats. * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * representation for the value. min/max values are converted to the external data type. For + * example, for DateType we store java.sql.Date, and for TimestampType we store + * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. * * As part of the protocol, the returned map always contains a key called "version". * In the case min/max values are null (None), they won't appear in the map. */ - def toMap: Map[String, String] = { + def toMap(colName: String, dataType: DataType): Map[String, String] = { val map = new scala.collection.mutable.HashMap[String, String] map.put(ColumnStat.KEY_VERSION, "1") map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } map.toMap } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + private def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + } @@ -150,28 +171,15 @@ object ColumnStat extends Logging { * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. */ - def fromMap(table: String, field: StructField, map: Map[String, String]) - : Option[ColumnStat] = { - val str2val: (String => Any) = field.dataType match { - case _: IntegralType => _.toLong - case _: DecimalType => new java.math.BigDecimal(_) - case DoubleType | FloatType => _.toDouble - case BooleanType => _.toBoolean - case DateType => java.sql.Date.valueOf - case TimestampType => java.sql.Timestamp.valueOf - // This version of Spark does not use min/max for binary/string types so we ignore it. 
- case BinaryType | StringType => _ => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column ${field.name} of data type: ${field.dataType}.") - } - + def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { try { Some(ColumnStat( distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + min = map.get(KEY_MIN_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong @@ -183,6 +191,30 @@ object ColumnStat extends Logging { } } + /** + * Converts from string representation of external data type to the corresponding Catalyst data + * type. + */ + private def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + /** * Constructs an expression to compute column statistics for a given column. * @@ -232,11 +264,14 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. 
*/ - def rowToColumnStat(row: Row): ColumnStat = { + def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), - min = Option(row.get(1)), // for string/binary min/max, get should return null - max = Option(row.get(2)), + // for string/binary min/max, get should return null + min = Option(row.get(1)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + max = Option(row.get(2)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc2701..f663d7b8a8f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{CatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -38,6 +38,14 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * This node is inserted at the top of a subquery when it is optimized. This makes sure we can + * recognize a subquery as such, and it allows us to write subquery aware transformations. + */ +case class Subquery(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -56,7 +64,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) } else { @@ -75,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) * @param generatorOutput The output schema of the Generator. 
* @param child Children logical plan node @@ -130,7 +138,7 @@ case class Filter(condition: Expression, child: LogicalPlan) child.constraints.union(predicates.toSet) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (conf.cboEnabled) { FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf)) } else { @@ -183,7 +191,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val leftSize = left.stats(conf).sizeInBytes val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize @@ -200,7 +208,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected def validConstraints: Set[Expression] = leftConstraints - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { left.stats(conf).copy() } } @@ -239,7 +247,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -348,7 +356,7 @@ case class Join( case _ => resolvedExceptNatural } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left @@ -374,8 +382,8 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output // set isBroadcastable to true so the child will be broadcasted - override def computeStats(conf: CatalystConf): Statistics = - super.computeStats(conf).copy(isBroadcastable = true) + override def computeStats(conf: SQLConf): Statistics = + child.stats(conf).copy(isBroadcastable = true) } /** @@ -530,7 +538,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -563,7 +571,7 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { Statistics( @@ -679,7 +687,7 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -750,16 +758,15 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] 
- val sizeInBytes = if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - 1 - } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum - } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val childStats = child.stats(conf) + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after a limit operation + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + isBroadcastable = childStats.isBroadcastable) } } @@ -771,25 +778,33 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { + val childStats = child.stats(conf) + if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - 1 + Statistics( + sizeInBytes = 1, + rowCount = Some(0), + isBroadcastable = childStats.isBroadcastable) } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. + childStats.copy(attributeStats = AttributeMap(Nil)) } - child.stats(conf).copy(sizeInBytes = sizeInBytes) } } case class SubqueryAlias( alias: String, - child: LogicalPlan, - view: Option[TableIdentifier]) + child: LogicalPlan) extends UnaryNode { + override lazy val canonicalized: LogicalPlan = child.canonicalized + override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) } @@ -814,14 +829,16 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { val ratio = upperBound - lowerBound - // BigInt can't multiply with Double - var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 + val childStats = child.stats(conf) + var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil @@ -835,6 +852,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * A base interface for [[RepartitionByExpression]] and [[Repartition]] + */ +abstract class RepartitionOperation extends UnaryNode { + def shuffle: Boolean + def numPartitions: Int + override def output: Seq[Attribute] = child.output +} + /** * Returns a new RDD that has exactly `numPartitions` partitions. 
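// A small sketch of the new limit estimation above: the row count after a global limit is
// min(limit, child row count), the size is derived from that row count, and per-column
// statistics are dropped because the value distribution after a limit (or sample) is
// unknown. The 16-byte row width below is an assumed stand-in for
// EstimationUtils.getOutputSize, not a value taken from the patch.
object LimitEstimateSketch {
  def globalLimitRowCount(limit: BigInt, childRowCount: Option[BigInt]): BigInt =
    childRowCount.map(_.min(limit)).getOrElse(limit)

  def main(args: Array[String]): Unit = {
    val assumedRowWidth = 16   // bytes per output row (illustrative)
    val rows = globalLimitRowCount(limit = 10, childRowCount = Some(BigInt(1000)))
    println(s"rowCount = $rows, sizeInBytes = ${rows * assumedRowWidth}")
    // prints: rowCount = 10, sizeInBytes = 160
  }
}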
Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user @@ -842,9 +868,8 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { * of the output requires some specific ordering or distribution of the data. */ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) - extends UnaryNode { + extends RepartitionOperation { require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") - override def output: Seq[Attribute] = child.output } /** @@ -856,12 +881,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Int) extends UnaryNode { + numPartitions: Int) extends RepartitionOperation { require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def shuffle: Boolean = true } /** @@ -870,7 +895,7 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) + override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1) } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0be4823bbc895..bfb70c2ef4c89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,9 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { @@ -39,7 +41,10 @@ object CatalystSerde { } def generateObjAttr[T : Encoder]: Attribute = { - AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + val enc = encoderFor[T] + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() } } @@ -210,13 +215,48 @@ case class TypedFilter( def typedCondition(input: Expression): Expression = { val (funcClass, methodName) = func match { case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" - case _ => classOf[Any => Boolean] -> "apply" + case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) Invoke(funcObj, methodName, BooleanType, input :: Nil) } } +object FunctionUtils { + private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = { + dt match { + case BooleanType if isOutput => Some("Z") + case IntegerType => Some("I") + case LongType => Some("J") + case FloatType => Some("F") + case DoubleType => Some("D") + case _ => None + } + } + + def getFunctionOneName(outputDT: DataType, 
inputDT: DataType): (Class[_], String) = { + // load "scala.Function1" using Java API to avoid requirements of type parameters + Utils.classForName("scala.Function1") -> { + // if a pair of an argument and return types is one of specific types + // whose specialized method (apply$mc..$sp) is generated by scalac, + // Catalyst generated a direct method call to the specialized method. + // The followings are references for this specialization: + // http://www.scala-lang.org/api/2.12.0/scala/Function1.html + // https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/ + // SpecializeTypes.scala + // http://www.cakesolutions.net/teamblogs/scala-dissection-functions + // http://axel22.github.io/2013/11/03/specialization-quirks.html + val inputType = getMethodType(inputDT, false) + val outputType = getMethodType(outputDT, true) + if (inputType.isDefined && outputType.isDefined) { + s"apply$$mc${outputType.get}${inputType.get}$$sp" + } else { + "apply" + } + } + } +} + /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { def apply[T : Encoder, U : Encoder]( @@ -314,24 +354,36 @@ case class MapGroups( child: LogicalPlan) extends UnaryNode with ObjectProducer /** Internal class representing State */ -trait LogicalKeyedState[S] +trait LogicalGroupState[S] + +/** Types of timeouts used in FlatMapGroupsWithState */ +case object NoTimeout extends GroupStateTimeout +case object ProcessingTimeTimeout extends GroupStateTimeout +case object EventTimeTimeout extends GroupStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. */ -object MapGroupsWithState { +object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], + outputMode: OutputMode, + isMapGroupsWithState: Boolean, + timeout: GroupStateTimeout, child: LogicalPlan): LogicalPlan = { - val mapped = new MapGroupsWithState( + val encoder = encoderFor[S] + + val mapped = new FlatMapGroupsWithState( func, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], - encoderFor[S].resolveAndBind().deserializer, - encoderFor[S].namedExpressions, + encoder.asInstanceOf[ExpressionEncoder[Any]], + outputMode, + isMapGroupsWithState, + timeout, child) CatalystSerde.serialize[U](mapped) } @@ -343,24 +395,34 @@ object MapGroupsWithState { * Func is invoked with an object representation of the grouping key an iterator containing the * object representation of all the rows with that key. * + * @param func function called on each group * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. 
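// The getFunctionOneName helper above picks the JVM name of the specialized Function1
// apply method (return-type letter first, then argument letter), so generated code can call
// e.g. apply$mcZI$sp for an Int => Boolean predicate instead of the boxing apply. A
// standalone re-derivation of just the name mangling; the string-keyed maps are an
// illustrative simplification of the DataType match in the patch.
object SpecializedApplyNameSketch {
  private val outLetters =
    Map("Boolean" -> "Z", "Int" -> "I", "Long" -> "J", "Float" -> "F", "Double" -> "D")
  private val inLetters = outLetters - "Boolean"  // Boolean is specialized only as a result type here

  def functionOneMethod(inputType: String, outputType: String): String =
    (outLetters.get(outputType), inLetters.get(inputType)) match {
      case (Some(out), Some(in)) => s"apply$$mc$out$in$$sp"
      case _ => "apply"   // fall back to the generic, boxing apply
    }

  def main(args: Array[String]): Unit = {
    println(functionOneMethod("Int", "Boolean"))     // apply$mcZI$sp
    println(functionOneMethod("String", "Boolean"))  // apply
  }
}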
* @param groupingAttributes used to group the data * @param dataAttributes used to read the data * @param outputObjAttr used to define the output object - * @param stateDeserializer used to deserialize state before calling `func` - * @param stateSerializer used to serialize updated state after calling `func` + * @param stateEncoder used to serialize/deserialize state before calling `func` + * @param outputMode the output mode of `func` + * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method + * @param timeout used to timeout groups that have not received data in a while */ -case class MapGroupsWithState( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], +case class FlatMapGroupsWithState( + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectProducer + stateEncoder: ExpressionEncoder[Any], + outputMode: OutputMode, + isMapGroupsWithState: Boolean = false, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode with ObjectProducer { + + if (isMapGroupsWithState) { + assert(outputMode == OutputMode.Update) + } +} /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index ce74554c17010..48b5fbb03ef1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} +import org.apache.spark.sql.internal.SQLConf object AggregateEstimation { @@ -29,7 +29,7 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(conf: CatalystConf, agg: Aggregate): Option[Statistics] = { + def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = { val childStats = agg.child.stats(conf) // Check if we have column stats for all group-by columns. 
val colStatsExist = agg.groupingExpressions.forall { e => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 4d18b28be8663..f1aff62cb6af0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.math.BigDecimal.RoundingMode -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean = + def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean = plans.forall(_.stats(conf).rowCount.isDefined) /** Check if each attribute has column stat in the corresponding statistics. */ @@ -75,4 +75,32 @@ object EstimationUtils { // (simple computation of statistics returns product of children). if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } + + /** + * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). + * The two methods below are the contract of conversion. 
+ */ + def toDecimal(value: Any, dataType: DataType): Decimal = { + dataType match { + case _: NumericType | DateType | TimestampType => Decimal(value.toString) + case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + } + } + + def fromDecimal(dec: Decimal, dataType: DataType): Any = { + dataType match { + case BooleanType => dec.toLong == 1 + case DateType => dec.toInt + case TimestampType => dec.toLong + case ByteType => dec.toByte + case ShortType => dec.toShort + case IntegerType => dec.toInt + case LongType => dec.toLong + case FloatType => dec.toFloat + case DoubleType => dec.toDouble + case _: DecimalType => dec + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala old mode 100644 new mode 100755 index 0c928832d7d22..4b6b3b14d9ac8 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Logging { +case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging { private val childStats = plan.child.stats(catalystConf) @@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def estimate: Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // save a mutable copy of colStats so that we can later change it recursively + // Save a mutable copy of colStats so that we can later change it recursively. colStatsMap.setInitValues(childStats.attributeStats) - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } + // Estimate selectivity of this filter predicate, and update column stats if needed. + // For not-supported condition, set filter selectivity to a conservative estimate 100% + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) - val newColStats = colStatsMap.toColumnStats + val newColStats = if (filterSelectivity == 0) { + // The output is empty, we don't need to keep column stats. 
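// The toDecimal/fromDecimal pair above is the whole conversion contract: every orderable
// value (numbers, dates as day counts, timestamps as microsecond counts, booleans as 0/1)
// is mapped into one decimal domain so range arithmetic needs a single code path, then
// mapped back to its external type. A self-contained analogue using scala.math.BigDecimal
// in place of Spark's Decimal; the string type tags are illustrative only.
object DecimalContractSketch {
  def toDec(value: Any): BigDecimal = value match {
    case b: Boolean => if (b) BigDecimal(1) else BigDecimal(0)
    case i: Int     => BigDecimal(i)   // e.g. DateType stored as days since epoch
    case l: Long    => BigDecimal(l)   // e.g. TimestampType stored as microseconds
    case d: Double  => BigDecimal(d)
  }

  def fromDec(dec: BigDecimal, dataType: String): Any = dataType match {
    case "boolean" => dec.toLong == 1
    case "int"     => dec.toInt
    case "long"    => dec.toLong
    case "double"  => dec.toDouble
  }

  def main(args: Array[String]): Unit = {
    println(toDec(false) <= toDec(true))    // true: booleans become comparable as numbers
    println(fromDec(toDec(17000), "int"))   // 17000: round-trips back to the external value
  }
}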
+ AttributeMap[ColumnStat](Nil) + } else { + colStatsMap.toColumnStats + } val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) @@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * Returns a percentage of rows meeting a condition in Filter node. + * If it's a single condition, we calculate the percentage directly. + * If it's a compound condition, it is decomposed into multiple single conditions linked with + * AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. This is needed for * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. + * For logical OR and NOT conditions, we do not update stats after a condition estimation. * * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column @@ -90,34 +95,40 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update) - val percent2 = calculateFilterSelectivity(cond2, update) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) + + // Not-operator pushdown + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + // Not-operator pushdown + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) + + // Collapse two consecutive Not operators which could be generated after Not-operator pushdown + case Not(Not(cond)) => + calculateFilterSelectivity(cond, update = false) + + // The foldable Not has been processed in the ConstantFolding rule + // This is a top-down traversal. The Not could be pushed down by the above two cases. 
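// The reworked estimate() above reduces to: derive one selectivity in [0, 1] for the whole
// predicate (unsupported conditions default to 1.0), scale the child's row count by it, and
// clear the per-column statistics entirely when the selectivity is 0. The core arithmetic in
// isolation; ceil mirrors EstimationUtils.ceil and the input numbers are made up.
object FilterRowCountSketch {
  import scala.math.BigDecimal.RoundingMode

  def ceil(bd: BigDecimal): BigInt = bd.setScale(0, RoundingMode.CEILING).toBigInt

  def main(args: Array[String]): Unit = {
    val childRowCount = BigInt(1000)
    val selectivity = 0.35   // e.g. the result of calculateFilterSelectivity
    println(ceil(BigDecimal(childRowCount) * BigDecimal(selectivity)))   // 350
    // A selectivity of 0 would give 0 rows, in which case the patch also drops the column
    // stats rather than carrying stale min/max values forward.
  }
}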
+ case Not(l @ Literal(null, _)) => + calculateSingleCondition(l, update = false) + + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None } - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } - - case _ => calculateSingleCondition(condition, update) + case _ => + calculateSingleCondition(condition, update) } } @@ -134,13 +145,16 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo */ def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { + case l: Literal => + evaluateLiteral(l) + // For evaluateBinary method, we assume the literal on the right side of an operator. // So we will change the order if not. // EqualTo/EqualNullSafe does not care about the order - case op @ Equality(ar: Attribute, l: Literal) => + case Equality(ar: Attribute, l: Literal) => evaluateEquality(ar, l, update) - case op @ Equality(l: Literal, ar: Attribute) => + case Equality(l: Literal, ar: Attribute) => evaluateEquality(ar, l, update) case op @ LessThan(ar: Attribute, l: Literal) => @@ -174,12 +188,33 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo case InSet(ar: Attribute, set) => evaluateInSet(ar, set, update) - case IsNull(ar: Attribute) => + // In current stage, we don't have advanced statistics such as sketches or histograms. + // As a result, some operator can't estimate `nullCount` accurately. E.g. left outer join + // estimation does not accurately update `nullCount` currently. + // So for IsNull and IsNotNull predicates, we only estimate them when the child is a leaf + // node, whose `nullCount` is accurate. + // This is a limitation due to lack of advanced stats. We should remove it in the future. + case IsNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = true, update) - case IsNotNull(ar: Attribute) => + case IsNotNull(ar: Attribute) if plan.child.isInstanceOf[LeafNode] => evaluateNullCheck(ar, isNull = false, update) + case op @ Equality(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ LessThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThan(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + + case op @ GreaterThanOrEqual(attrLeft: Attribute, attrRight: Attribute) => + evaluateBinaryForTwoColumns(op, attrLeft, attrRight, update) + case _ => // TODO: it's difficult to support string operators without advanced statistics. // Hence, these string operators Like(_, _) | Contains(_, _) | StartsWith(_, _) @@ -225,18 +260,18 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = if (isNull) { - nullPercent.toDouble + nullPercent } else { - 1.0 - nullPercent.toDouble + 1.0 - nullPercent } - Some(percent) + Some(percent.toDouble) } /** * Returns a percentage of rows meeting a binary comparison expression. 
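// The compound-predicate handling above is ordinary probability algebra under an
// independence assumption: AND multiplies selectivities, OR uses inclusion-exclusion
// (p1 + p2 - p1*p2), NOT takes 1 - p, and Not over And/Or is first rewritten with
// De Morgan's laws; any unsupported leg defaults to 1.0. Worked numbers (the two input
// selectivities are made up):
object SelectivityAlgebraSketch {
  def and(p1: Double, p2: Double): Double = p1 * p2
  def or(p1: Double, p2: Double): Double  = p1 + p2 - p1 * p2
  def not(p: Double): Double              = 1.0 - p

  def main(args: Array[String]): Unit = {
    val p1 = 0.25   // assumed selectivity of, say, a > 40
    val p2 = 0.5    // assumed selectivity of, say, b = 'x'
    println(and(p1, p2))        // 0.125
    println(or(p1, p2))         // 0.625
    println(not(and(p1, p2)))   // 0.875, same as or(not(p1), not(p2)) by De Morgan
  }
}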
* - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -249,41 +284,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo attr: Attribute, literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + attr.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType | DateType | TimestampType | BooleanType => evaluateBinaryForNumeric(op, attr, literal, update) case StringType | BinaryType => // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) None - case _ => - // TODO: support boolean type. - None - } - } - - /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { - attrDataType match { - case DateType => - Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) - case TimestampType => - Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) - case _: DecimalType => - Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) - case StringType | BinaryType => - None - case _ => - Some(litValue) } } @@ -291,6 +304,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * + * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo; + * if the literal is null, the condition will be changed to IsNull after optimization. + * So we don't need specific logic for EqualNullSafe here. + * * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -314,22 +331,46 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val statsRange = Range(colStat.min, colStat.max, attr.dataType) if (statsRange.contains(literal)) { if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) - val newStats = colStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) + // We update ColumnStat structure after apply this equality predicate: + // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal + // value. 
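// evaluateEquality (continued just below) boils down to: if the literal lies inside the
// column's [min, max] range, assume a uniform distribution and return 1 / ndv, otherwise 0;
// and with update = true the surviving column is pinned to a single value
// (distinctCount = 1, min = max = literal, nullCount = 0). The arithmetic on toy numbers:
object EqualitySelectivitySketch {
  def equalitySelectivity(ndv: BigInt, min: BigDecimal, max: BigDecimal,
      literal: BigDecimal): Double =
    if (literal >= min && literal <= max) (BigDecimal(1) / BigDecimal(ndv)).toDouble
    else 0.0

  def main(args: Array[String]): Unit = {
    // e.g. WHERE col = 30 on a column with ndv = 10, min = 1, max = 100
    println(equalitySelectivity(BigInt(10), 1, 100, 30))    // 0.1
    println(equalitySelectivity(BigInt(10), 1, 100, 500))   // 0.0, the literal is out of range
  }
}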
+ val newStats = attr.dataType match { + case StringType | BinaryType => + colStat.copy(distinctCount = 1, nullCount = 0) + case _ => + colStat.copy(distinctCount = 1, min = Some(literal.value), + max = Some(literal.value), nullCount = 0) + } colStatsMap(attr) = newStats } - Some(1.0 / ndv.toDouble) + Some((1.0 / BigDecimal(ndv)).toDouble) } else { Some(0.0) } } + /** + * Returns a percentage of rows meeting a Literal expression. + * This method evaluates all the possible literal cases in Filter. + * + * FalseLiteral and TrueLiteral should be eliminated by optimizer, but null literal might be added + * by optimizer rule NullPropagation. For safety, we handle all the cases here. + * + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateLiteral(literal: Literal): Option[Double] = { + literal match { + case Literal(null, _) => Some(0.0) + case FalseLiteral => Some(0.0) + case TrueLiteral => Some(1.0) + // Ideally, we should not hit the following branch + case _ => None + } + } + /** * Returns a percentage of rows meeting "IN" operator expression. * This method evaluates the equality predicate for all data types. @@ -368,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo return Some(0.0) } - // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue( - attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) - val newMin = convertBoundValue( - attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) - + val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), + max = Some(newMax), nullCount = 0) colStatsMap(attr) = newStats } @@ -394,14 +431,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)) } /** * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. + * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. 
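// evaluateInSet above keeps only the IN-list values that fall inside the column's
// [min, max] range, caps the new distinct count at the old one, and returns newNdv / ndv
// under the same uniform-distribution assumption. A standalone version of that arithmetic
// with made-up numbers:
object InSetSelectivitySketch {
  def inSetSelectivity(ndv: BigInt, min: BigDecimal, max: BigDecimal,
      inList: Set[BigDecimal]): Double = {
    val valid = inList.filter(v => v >= min && v <= max)
    if (valid.isEmpty) 0.0
    else {
      val newNdv = ndv.min(BigInt(valid.size))
      math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)
    }
  }

  def main(args: Array[String]): Unit = {
    // e.g. WHERE col IN (10, 20, 30, 40, 999) on a column with ndv = 50, min = 1, max = 100
    println(inSetSelectivity(BigInt(50), 1, 100, Set[BigDecimal](10, 20, 30, 40, 999)))
    // prints 0.08: 999 is discarded, and 4 of the 50 distinct values can match
  }
}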
* - * @param op a binary comparison operator uch as =, <, <=, >, >= + * @param op a binary comparison operator such as =, <, <=, >, >= * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -414,77 +451,306 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo literal: Literal, update: Boolean): Option[Double] = { - var percent = 1.0 val colStat = colStatsMap(attr) - val statsRange = - Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val max = statsRange.max.toBigDecimal + val min = statsRange.min.toBigDecimal + val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range - val literalValueBD = BigDecimal(literal.value.toString) + val numericLiteral = if (literal.dataType == BooleanType) { + if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) + } else { + BigDecimal(literal.value.toString) + } val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + (numericLiteral <= min, numericLiteral > max) case _: LessThanOrEqual => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + (numericLiteral < min, numericLiteral >= max) case _: GreaterThan => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + (numericLiteral >= max, numericLiteral < min) case _: GreaterThanOrEqual => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + (numericLiteral > max, numericLiteral <= min) } + var percent = BigDecimal(1.0) if (noOverlap) { percent = 0.0 } else if (completeOverlap) { percent = 1.0 } else { - // this is partial overlap case - val literalDouble = literalValueBD.toDouble - val maxDouble = BigDecimal(statsRange.max).toDouble - val minDouble = BigDecimal(statsRange.min).toDouble - + // This is the partial overlap case: // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. - // For ease of computation, we convert all relevant numeric values to Double. + assert(max > min) percent = op match { case _: LessThan => - (literalDouble - minDouble) / (maxDouble - minDouble) + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == min) { + // The boundary value is the only satisfying value. 
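// For the partial-overlap case, evaluateBinaryForNumeric prorates the predicate's range
// over the column's [min, max] range under a uniform-distribution assumption, and only the
// exact boundary values get the special 1/ndv (or 1 - 1/ndv) treatment shown above. A
// standalone sketch of the "col < literal" branch only, with made-up column stats:
object RangeSelectivitySketch {
  def lessThanSelectivity(min: BigDecimal, max: BigDecimal, ndv: BigInt,
      literal: BigDecimal): BigDecimal =
    if (literal <= min) BigDecimal(0)                             // no overlap
    else if (literal > max) BigDecimal(1)                         // complete overlap
    else if (literal == max) BigDecimal(1) - BigDecimal(1) / BigDecimal(ndv)
    else (literal - min) / (max - min)                            // prorate the range

  def main(args: Array[String]): Unit = {
    // column with min = 0, max = 100, ndv = 100
    println(lessThanSelectivity(0, 100, BigInt(100), 25))    // 0.25
    println(lessThanSelectivity(0, 100, BigInt(100), 100))   // 0.99, i.e. 1 - 1/ndv
    println(lessThanSelectivity(0, 100, BigInt(100), -5))    // 0
  }
}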
+ 1.0 / ndv } else { - (literalDouble - minDouble) / (maxDouble - minDouble) + (numericLiteral - min) / (max - min) } case _: GreaterThan => - (maxDouble - literalDouble) / (maxDouble - minDouble) + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == max) { + 1.0 / ndv } else { - (maxDouble - literalDouble) / (maxDouble - minDouble) + (max - numericLiteral) / (max - min) } } if (update) { - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) + val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min + var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdv < 1) newNdv = 1 + op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue + case _: GreaterThan | _: GreaterThanOrEqual => + // If new ndv is 1, then new max must be equal to new min. + newMin = if (newNdv == 1) newMax else newValue + case _: LessThan | _: LessThanOrEqual => + newMax = if (newNdv == 1) newMin else newValue } - val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = + colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) colStatsMap(attr) = newStats } } - Some(percent) + Some(percent.toDouble) + } + + /** + * Returns a percentage of rows meeting a binary comparison expression containing two columns. + * In SQL queries, we also see predicate expressions involving two columns + * such as "column-1 (op) column-2" where column-1 and column-2 belong to same table. + * Note that, if column-1 and column-2 belong to different tables, then it is a join + * operator's work, NOT a filter operator's work. + * + * @param op a binary comparison operator, including =, <=>, <, <=, >, >= + * @param attrLeft the left Attribute (or a column) + * @param attrRight the right Attribute (or a column) + * @param update a boolean flag to specify if we need to update ColumnStat of the given columns + * for subsequent conditions + * @return an optional double value to show the percentage of rows meeting a given condition + */ + def evaluateBinaryForTwoColumns( + op: BinaryComparison, + attrLeft: Attribute, + attrRight: Attribute, + update: Boolean): Option[Double] = { + + if (!colStatsMap.contains(attrLeft)) { + logDebug("[CBO] No statistics for " + attrLeft) + return None + } + if (!colStatsMap.contains(attrRight)) { + logDebug("[CBO] No statistics for " + attrRight) + return None + } + + attrLeft.dataType match { + case StringType | BinaryType => + // TODO: It is difficult to support other binary comparisons for String/Binary + // type without min/max and advanced statistics like histogram. 
+ logDebug("[CBO] No range comparison statistics for String/Binary type " + attrLeft) + return None + case _ => + } + + val colStatLeft = colStatsMap(attrLeft) + val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) + .asInstanceOf[NumericRange] + val maxLeft = statsRangeLeft.max + val minLeft = statsRangeLeft.min + + val colStatRight = colStatsMap(attrRight) + val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) + .asInstanceOf[NumericRange] + val maxRight = statsRangeRight.max + val minRight = statsRangeRight.min + + // determine the overlapping degree between predicate range and column's range + val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val (noOverlap: Boolean, completeOverlap: Boolean) = op match { + // Left < Right or Left <= Right + // - no overlap: + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + case _: LessThan => + (minLeft >= maxRight, (maxLeft < minRight) && allNotNull) + case _: LessThanOrEqual => + (minLeft > maxRight, (maxLeft <= minRight) && allNotNull) + + // Left > Right or Left >= Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // - complete overlap: (If null values exists, we set it to partial overlap.) + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + case _: GreaterThan => + (maxLeft <= minRight, (minLeft > maxRight) && allNotNull) + case _: GreaterThanOrEqual => + (maxLeft < minRight, (minLeft >= maxRight) && allNotNull) + + // Left = Right or Left <=> Right + // - no overlap: + // minLeft maxLeft minRight maxRight + // --------+------------------+------------+-------------+-------> + // minRight maxRight minLeft maxLeft + // --------+------------------+------------+-------------+-------> + // - complete overlap: + // minLeft maxLeft + // minRight maxRight + // --------+------------------+-------> + case _: EqualTo => + ((maxLeft < minRight) || (maxRight < minLeft), + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + case _: EqualNullSafe => + // For null-safe equality, we use a very restrictive condition to evaluate its overlap. + // If null values exists, we set it to partial overlap. + (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, + (minLeft == minRight) && (maxLeft == maxRight) && allNotNull + && (colStatLeft.distinctCount == colStatRight.distinctCount) + ) + } + + var percent = BigDecimal(1.0) + if (noOverlap) { + percent = 0.0 + } else if (completeOverlap) { + percent = 1.0 + } else { + // For partial overlap, we use an empirical value 1/3 as suggested by the book + // "Database Systems, the complete book". 
+ percent = 1.0 / 3.0 + + if (update) { + // Need to adjust new min/max after the filter condition is applied + + val ndvLeft = BigDecimal(colStatLeft.distinctCount) + var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvLeft < 1) newNdvLeft = 1 + val ndvRight = BigDecimal(colStatRight.distinctCount) + var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdvRight < 1) newNdvRight = 1 + + var newMaxLeft = colStatLeft.max + var newMinLeft = colStatLeft.min + var newMaxRight = colStatRight.max + var newMinRight = colStatRight.min + + op match { + case _: LessThan | _: LessThanOrEqual => + // the left side should be less than the right side. + // If not, we need to adjust it to narrow the range. + // Left < Right or Left <= Right + // minRight < minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxRight < maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft > minRight) newMinRight = colStatLeft.min + if (maxLeft > maxRight) newMaxLeft = colStatRight.max + + case _: GreaterThan | _: GreaterThanOrEqual => + // the left side should be greater than the right side. + // If not, we need to adjust it to narrow the range. + // Left > Right or Left >= Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + if (minLeft < minRight) newMinLeft = colStatRight.min + if (maxLeft < maxRight) newMaxRight = colStatLeft.max + + case _: EqualTo | _: EqualNullSafe => + // need to set new min to the larger min value, and + // set the new max to the smaller max value. 
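// evaluateBinaryForTwoColumns compares the two columns' [min, max] ranges: when the ranges
// make the predicate impossible the selectivity is 0, when they always satisfy it (and
// neither column has nulls) it is 1, and any partial overlap falls back to the 1/3 guess
// cited above. A sketch of the "colA < colB" case only; the null-count check and the
// min/max adjustments from the patch are omitted for brevity.
object TwoColumnSelectivitySketch {
  def lessThanSelectivity(minA: BigDecimal, maxA: BigDecimal,
      minB: BigDecimal, maxB: BigDecimal): Double =
    if (minA >= maxB) 0.0        // no value of A can be smaller than any value of B
    else if (maxA < minB) 1.0    // every value of A is smaller than every value of B
    else 1.0 / 3.0               // ranges overlap: empirical estimate

  def main(args: Array[String]): Unit = {
    println(lessThanSelectivity(0, 10, 20, 30))   // 1.0
    println(lessThanSelectivity(20, 30, 0, 10))   // 0.0
    println(lessThanSelectivity(0, 10, 5, 15))    // 0.3333333333333333
  }
}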
+ // Left = Right or Left <=> Right + // minLeft < minRight + // --------+******************+-------> + // filtered ^ + // | + // newMinLeft + // + // minRight <= minLeft + // --------+******************+-------> + // filtered ^ + // | + // newMinRight + // + // maxLeft < maxRight + // --------+******************+-------> + // ^ filtered + // | + // newMaxRight + // + // maxRight <= maxLeft + // --------+******************+-------> + // ^ filtered + // | + // newMaxLeft + if (minLeft < minRight) { + newMinLeft = colStatRight.min + } else { + newMinRight = colStatLeft.min + } + if (maxLeft < maxRight) { + newMaxRight = colStatLeft.max + } else { + newMaxLeft = colStatRight.max + } + } + + val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + max = newMaxLeft) + colStatsMap(attrLeft) = newStatsLeft + val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + max = newMaxRight) + colStatsMap(attrRight) = newStatsRight + } + } + + Some(percent.toDouble) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 9782c0bb0a939..3245a73c8a2eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -21,12 +21,12 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf object JoinEstimation extends Logging { @@ -34,7 +34,7 @@ object JoinEstimation extends Logging { * Estimate statistics after join. Return `None` if the join type is not supported, or we don't * have enough statistics for estimation. */ - def estimate(conf: CatalystConf, join: Join): Option[Statistics] = { + def estimate(conf: SQLConf, join: Join): Option[Statistics] = { join.joinType match { case Inner | Cross | LeftOuter | RightOuter | FullOuter => InnerOuterEstimation(conf, join).doEstimate() @@ -47,7 +47,7 @@ object JoinEstimation extends Logging { } } -case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging { +case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { private val leftStats = join.left.stats(conf) private val rightStats = join.right.stats(conf) @@ -288,7 +288,7 @@ case class InnerOuterEstimation(conf: CatalystConf, join: Join) extends Logging } } -case class LeftSemiAntiEstimation(conf: CatalystConf, join: Join) { +case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) { def doEstimate(): Option[Statistics] = { // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic // column stats. Now we just propagate the statistics from left side. 
We should do more diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index e9084ad8b859c..d700cd3b20f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} +import org.apache.spark.sql.internal.SQLConf object ProjectEstimation { import EstimationUtils._ - def estimate(conf: CatalystConf, project: Project): Option[Statistics] = { + def estimate(conf: SQLConf, project: Project): Option[Statistics] = { if (rowCountsExist(conf, project.child)) { val childStats = project.child.stats(conf) val inputAttrStats = childStats.attributeStats diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 3d13967cb62a4..4ac5ba5689f82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.math.{BigDecimal => JDecimal} -import java.sql.{Date, Timestamp} - import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} +import org.apache.spark.sql.types._ /** Value range of a column. */ @@ -31,13 +27,10 @@ trait Range { } /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range { +case class NumericRange(min: Decimal, max: Decimal) extends Range { override def contains(l: Literal): Boolean = { - val decimal = l.dataType match { - case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - case _ => new JDecimal(l.value.toString) - } - min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + val lit = EstimationUtils.toDecimal(l.value, l.dataType) + min <= lit && max >= lit } } @@ -58,7 +51,10 @@ object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { case StringType | BinaryType => new DefaultRange() case _ if min.isEmpty || max.isEmpty => new NullRange() - case _ => toNumericRange(min.get, max.get, dataType) + case _ => + NumericRange( + min = EstimationUtils.toDecimal(min.get, dataType), + max = EstimationUtils.toDecimal(max.get, dataType)) } def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { @@ -82,51 +78,11 @@ object Range { // binary/string types don't support intersecting. 
(None, None) case (n1: NumericRange, n2: NumericRange) => - val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max)) - val (newMin, newMax) = fromNumericRange(newRange, dt) - (Some(newMin), Some(newMax)) + // Choose the maximum of two min values, and the minimum of two max values. + val newMin = if (n1.min <= n2.min) n2.min else n1.min + val newMax = if (n1.max <= n2.max) n1.max else n2.max + (Some(EstimationUtils.fromDecimal(newMin, dt)), + Some(EstimationUtils.fromDecimal(newMax, dt))) } } - - /** - * For simplicity we use decimal to unify operations of numeric types, the two methods below - * are the contract of conversion. - */ - private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { - dataType match { - case _: NumericType => - NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) - case BooleanType => - val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 - val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case DateType => - val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) - val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case TimestampType => - val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) - val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - } - } - - private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = { - dataType match { - case _: IntegralType => - (n.min.longValue(), n.max.longValue()) - case FloatType | DoubleType => - (n.min.doubleValue(), n.max.doubleValue()) - case _: DecimalType => - (n.min, n.max) - case BooleanType => - (n.min.longValue() == 1, n.max.longValue() == 1) - case DateType => - (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue())) - case TimestampType => - (DateTimeUtils.toJavaTimestamp(n.min.longValue()), - DateTimeUtils.toJavaTimestamp(n.max.longValue())) - } - } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9dfdf4da78ff6..2ab46dc8330aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any - /** - * Returns true iff this [[BroadcastMode]] generates the same result as `other`. - */ - def compatibleWith(other: BroadcastMode): Boolean + def canonicalized: BroadcastMode } /** @@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. 
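// The NumericRange intersection above now stays entirely in the decimal domain: the new
// lower bound is the larger of the two mins, the new upper bound is the smaller of the two
// maxes, and only the final pair is mapped back to the column's external type. A
// self-contained analogue with scala.math.BigDecimal; unlike the Catalyst code, which checks
// isIntersected beforehand, this sketch returns None for disjoint ranges.
object RangeIntersectSketch {
  def intersect(min1: BigDecimal, max1: BigDecimal,
      min2: BigDecimal, max2: BigDecimal): Option[(BigDecimal, BigDecimal)] = {
    val newMin = if (min1 <= min2) min2 else min1
    val newMax = if (max1 <= max2) max1 else max2
    if (newMin <= newMax) Some((newMin, newMax)) else None
  }

  def main(args: Array[String]): Unit = {
    println(intersect(0, 50, 30, 100))   // Some((30,50))
    println(intersect(0, 10, 20, 30))    // None
  }
}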
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows - override def compatibleWith(other: BroadcastMode): Boolean = { - this eq other - } + override def canonicalized: BroadcastMode = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index 351bd6fff4adf..3cd6970ebefbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.streaming +import java.util.Locale + import org.apache.spark.sql.streaming.OutputMode /** @@ -44,4 +46,19 @@ private[sql] object InternalOutputModes { * aggregations, it will be equivalent to `Append` mode. */ case object Update extends OutputMode + + + def apply(outputMode: String): OutputMode = { + outputMode.toLowerCase(Locale.ROOT) match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case "update" => + OutputMode.Update + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + + "Accepted output modes are 'append', 'complete', 'update'") + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954ba..2109c1c23b706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,6 +444,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil + case table: CatalogTable => + table.storage.serde match { + case Some(serde) => table.identifier :: serde :: Nil + case _ => table.identifier :: Nil + } case other => other :: Nil }.mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala similarity index 54% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala index 0e466962b4678..985f0dc1cd60e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala @@ -17,25 +17,17 @@ package org.apache.spark.sql.catalyst.util -object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" - val DROP_MALFORMED_MODE = "DROPMALFORMED" - val FAIL_FAST_MODE = "FAILFAST" +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String - val DEFAULT = PERMISSIVE_MODE - - def isValidMode(mode: String): Boolean = { - mode.toUpperCase match { - case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true - case _ => false - } - } - - def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE - def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE - def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { - mode.toUpperCase == PERMISSIVE_MODE - } else { - true // We default to permissive is the mode 
string is not valid - } -} +/** + * Exception thrown when the underlying parser meet a bad record and can't parse it. + * @param record a function to return the record that cause the parser to fail + * @param partialResult a function that returns an optional row, which is the partial result of + * parsing this bad record. + * @param cause the actual exception about why the record is bad and can't be parsed. + */ +case class BadRecordException( + record: () => UTF8String, + partialResult: () => Option[InternalRow], + cause: Throwable) extends Exception(cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index 66dd093bbb691..bb2c5926ae9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Builds a map in which keys are case insensitive. Input map can be accessed for cases where * case-sensitive information is required. The primary constructor is marked private to avoid @@ -26,11 +28,12 @@ package org.apache.spark.sql.catalyst.util class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase)) + val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) - override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase) + override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) - override def contains(k: String): Boolean = keyLowerCasedMap.contains(k.toLowerCase) + override def contains(k: String): Boolean = + keyLowerCasedMap.contains(k.toLowerCase(Locale.ROOT)) override def +[B1 >: T](kv: (String, B1)): Map[String, B1] = { new CaseInsensitiveMap(originalMap + kv) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala index 435fba9d8851c..1377a03d93b7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CompressionCodecs.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress._ @@ -38,7 +40,7 @@ object CompressionCodecs { * If it is already a class name, just return it. 
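// Several hunks in this patch replace bare toLowerCase/toUpperCase calls with the
// Locale.ROOT overloads. The no-argument overloads use the JVM's default locale, so a key
// such as "TITLE" can stop matching "title" under, for example, a Turkish locale, where an
// upper-case 'I' lower-cases to a dotless 'ı'. A minimal demonstration of the difference:
import java.util.Locale

object LocaleRootSketch {
  def main(args: Array[String]): Unit = {
    val turkish = new Locale("tr", "TR")
    println("TITLE".toLowerCase(turkish))      // tıtle -- breaks case-insensitive lookups
    println("TITLE".toLowerCase(Locale.ROOT))  // title -- stable whatever the default locale is
  }
}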
*/ def getCodecClassName(name: String): String = { - val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase(Locale.ROOT), name) try { // Validate the codec name if (codecName != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 9e1de0fd2f3d6..eb6aad5b2d2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -44,6 +44,7 @@ object DateTimeUtils { final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L + final val MILLIS_PER_SECOND = 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY @@ -60,6 +61,8 @@ object DateTimeUtils { final val TimeZoneGMT = TimeZone.getTimeZone("GMT") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) + val TIMEZONE_OPTION = "timeZone" + def defaultTimeZone(): TimeZone = TimeZone.getDefault() // Reuse the Calendar object in each thread as it is expensive to create in each method call. @@ -235,6 +238,24 @@ object DateTimeUtils { (day.toInt, micros * 1000L) } + /* + * Converts the timestamp to milliseconds since epoch. In spark timestamp values have microseconds + * precision, so this conversion is lossy. + */ + def toMillis(us: SQLTimestamp): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the millseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floor(us.toDouble / MILLIS_PER_SECOND).toLong + } + + /* + * Converts millseconds since epoch to SQLTimestamp. + */ + def fromMillis(millis: Long): SQLTimestamp = { + millis * 1000L + } + /** * Parses a given UTF8 date string to the corresponding a corresponding [[Long]] value. * The return type is [[Option]] in order to distinguish between 0L and null. The following @@ -873,7 +894,7 @@ object DateTimeUtils { * (Because 1970-01-01 is Thursday). */ def getDayOfWeekFromString(string: UTF8String): Int = { - val dowString = string.toString.toUpperCase + val dowString = string.toString.toUpperCase(Locale.ROOT) dowString match { case "SU" | "SUN" | "SUNDAY" => 3 case "MO" | "MON" | "MONDAY" => 4 @@ -930,7 +951,7 @@ object DateTimeUtils { if (format == null) { TRUNC_INVALID } else { - format.toString.toUpperCase match { + format.toString.toUpperCase(Locale.ROOT) match { case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH case _ => TRUNC_INVALID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala new file mode 100644 index 0000000000000..2beb875d1751d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ParseMode.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
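Editor's note: the new toMillis/fromMillis helpers in DateTimeUtils make the micros-to-millis truncation explicit, flooring (rounding toward negative infinity) so that pre-1970 values match the example given in the comment. A worked sketch of the arithmetic:

// 1965-01-01 10:11:12.123456 as microseconds since the epoch:
val us = -157700927876544L
// toMillis floors, giving the value quoted in the comment:
val millis = math.floor(us.toDouble / 1000L).toLong
assert(millis == -157700927877L)
// fromMillis simply scales back up to microsecond precision (the trailing 544 micros are lost):
assert(millis * 1000L == -157700927877000L)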
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.Locale + +import org.apache.spark.internal.Logging + +sealed trait ParseMode { + /** + * String name of the parse mode. + */ + def name: String +} + +/** + * This mode permissively parses the records. + */ +case object PermissiveMode extends ParseMode { val name = "PERMISSIVE" } + +/** + * This mode ignores the whole corrupted records. + */ +case object DropMalformedMode extends ParseMode { val name = "DROPMALFORMED" } + +/** + * This mode throws an exception when it meets corrupted records. + */ +case object FailFastMode extends ParseMode { val name = "FAILFAST" } + +object ParseMode extends Logging { + /** + * Returns the parse mode from the given string. + */ + def fromString(mode: String): ParseMode = mode.toUpperCase(Locale.ROOT) match { + case PermissiveMode.name => PermissiveMode + case DropMalformedMode.name => DropMalformedMode + case FailFastMode.name => FailFastMode + case _ => + logWarning(s"$mode is not a valid parse mode. Using ${PermissiveMode.name}.") + PermissiveMode + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index 04f4ff2a92247..af543b04ba780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -176,17 +176,19 @@ class QuantileSummaries( * @param quantile the target quantile * @return */ - def query(quantile: Double): Double = { + def query(quantile: Double): Option[Double] = { require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]") require(headSampled.isEmpty, "Cannot operate on an uncompressed summary, call compress() first") + if (sampled.isEmpty) return None + if (quantile <= relativeError) { - return sampled.head.value + return Some(sampled.head.value) } if (quantile >= 1 - relativeError) { - return sampled.last.value + return Some(sampled.last.value) } // Target rank @@ -200,11 +202,11 @@ class QuantileSummaries( minRank += curSample.g val maxRank = minRank + curSample.delta if (maxRank - targetError <= rank && rank <= minRank + targetError) { - return curSample.value + return Some(curSample.value) } i += 1 } - sampled.last.value + Some(sampled.last.value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index a7ac6136835a7..812d5ded4bf0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.util.Locale + /** * Build a map with String type of key, and 
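Editor's note: ParseMode replaces the string-based ParseModes helper with a small ADT, so callers can match on PermissiveMode / DropMalformedMode / FailFastMode instead of comparing upper-cased strings. A usage sketch (internal catalyst API, shown only for illustration):

import org.apache.spark.sql.catalyst.util._

assert(ParseMode.fromString("failfast") == FailFastMode)
assert(ParseMode.fromString("DropMalformed") == DropMalformedMode)
// Unknown values still fall back to PermissiveMode, but now with a logged warning:
assert(ParseMode.fromString("bogus") == PermissiveMode)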
it also supports either key case * sensitive or insensitive. @@ -25,7 +27,7 @@ object StringKeyHashMap { def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] = if (caseSensitive) { new StringKeyHashMap[T](identity) } else { - new StringKeyHashMap[T](_.toLowerCase) + new StringKeyHashMap[T](_.toLowerCase(Locale.ROOT)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index cde8bd5b9614c..ca22ea24207e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,32 +19,44 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.{Pattern, PatternSyntaxException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.unsafe.types.UTF8String object StringUtils { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - def escapeLikeRegex(v: String): String = { - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => + /** + * Validate and convert SQL 'like' pattern to a Java regular expression. + * + * Underscores (_) are converted to '.' and percent signs (%) are converted to '.*', other + * characters are quoted literally. Escaping is done according to the rules specified in + * [[org.apache.spark.sql.catalyst.expressions.Like]] usage documentation. An invalid pattern will + * throw an [[AnalysisException]]. + * + * @param pattern the SQL pattern to convert + * @return the equivalent Java regular expression of the pattern + */ + def escapeLikeRegex(pattern: String): String = { + val in = pattern.toIterator + val out = new StringBuilder() + + def fail(message: String) = throw new AnalysisException( + s"the pattern '$pattern' is invalid, $message") + + while (in.hasNext) { + in.next match { + case '\\' if in.hasNext => + val c = in.next c match { - case '_' => "." - case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) + case '_' | '%' | '\\' => out ++= Pattern.quote(Character.toString(c)) + case _ => fail(s"the escape character is not allowed to precede '$c'") } - }.mkString - } else { - v + case '\\' => fail("it is not allowed to end with the escape character") + case '_' => out ++= "." + case '%' => out ++= ".*" + case c => out ++= Pattern.quote(Character.toString(c)) + } } + "(?s)" + out.result() // (?s) enables dotall mode, causing "." 
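Editor's note: the rewritten escapeLikeRegex walks the pattern once, quoting literals with Pattern.quote and rejecting dangling or invalid escapes with an AnalysisException instead of producing a surprising regex. A few illustrative inputs and outputs (a sketch, assuming the internal StringUtils object is accessible):

import org.apache.spark.sql.catalyst.util.StringUtils.escapeLikeRegex

assert(escapeLikeRegex("a%b_c") == """(?s)\Qa\E.*\Qb\E.\Qc\E""")       // % -> .*, _ -> .
assert(escapeLikeRegex("""100\%""") == """(?s)\Q1\E\Q0\E\Q0\E\Q%\E""") // escaped % is a literal
// escapeLikeRegex("""a\b""")  // AnalysisException: the escape character is not allowed to precede 'b'
// escapeLikeRegex("a\\")      // AnalysisException: it is not allowed to end with the escape character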
to match new lines } private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 19e68c1e18b8e..a03aa5849c417 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -17,22 +17,19 @@ package org.apache.spark.sql.internal -import java.util.{NoSuchElementException, Properties, TimeZone} +import java.util.{Locale, NoSuchElementException, Properties, TimeZone} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.immutable import org.apache.hadoop.fs.Path -import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol -import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol -import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -190,6 +187,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") + .internal() + .doc("When true, the query optimizer will infer and propagate data constraints in the query " + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "for certain kinds of query plans (such as those with a large number of predicates and " + + "aliases) which might negatively impact overall runtime.") + .booleanConf + .createWithDefault(true) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -226,6 +232,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PARQUET_INT64_AS_TIMESTAMP_MILLIS = buildConf("spark.sql.parquet.int64AsTimestampMillis") + .doc("When true, timestamp values will be stored as INT64 with TIMESTAMP_MILLIS as the " + + "extended type. In this mode, the microsecond portion of the timestamp value will be" + + "truncated.") + .booleanConf + .createWithDefault(false) + val PARQUET_CACHE_METADATA = buildConf("spark.sql.parquet.cacheMetadata") .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.") .booleanConf @@ -235,7 +248,7 @@ object SQLConf { .doc("Sets the compression codec use when writing Parquet files. 
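Editor's note: spark.sql.constraintPropagation.enabled is a new escape hatch for plans where inferring constraints costs more than the optimization saves. A minimal sketch of toggling it on a bare SQLConf (the matching constraintPropagationEnabled getter is added further down in this file):

import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
assert(conf.constraintPropagationEnabled)          // defaults to true
conf.setConfString("spark.sql.constraintPropagation.enabled", "false")
assert(!conf.constraintPropagationEnabled)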
Acceptable values include: " + "uncompressed, snappy, gzip, lzo.") .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") @@ -256,7 +269,7 @@ object SQLConf { "of org.apache.parquet.hadoop.ParquetOutputCommitter.") .internal() .stringConf - .createWithDefault(classOf[ParquetOutputCommitter].getName) + .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") val PARQUET_VECTORIZED_READER_ENABLED = buildConf("spark.sql.parquet.enableVectorizedReader") @@ -306,6 +319,25 @@ object SQLConf { .longConf .createWithDefault(250 * 1024 * 1024) + object HiveCaseSensitiveInferenceMode extends Enumeration { + val INFER_AND_SAVE, INFER_ONLY, NEVER_INFER = Value + } + + val HIVE_CASE_SENSITIVE_INFERENCE = buildConf("spark.sql.hive.caseSensitiveInferenceMode") + .doc("Sets the action to take when a case-sensitive schema cannot be read from a Hive " + + "table's properties. Although Spark SQL itself is not case-sensitive, Hive compatible file " + + "formats such as Parquet are. Spark SQL must use a case-preserving schema when querying " + + "any table backed by files containing case-sensitive field names or queries may not return " + + "accurate results. Valid options include INFER_AND_SAVE (the default mode-- infer the " + + "case-sensitive schema from the underlying data files and write it back to the table " + + "properties), INFER_ONLY (infer the schema but don't attempt to write it to the table " + + "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " + + "instead of inferring).") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) + .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -399,6 +431,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -408,7 +446,8 @@ object SQLConf { buildConf("spark.sql.sources.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[SQLHadoopMapReduceCommitProtocol].getName) + .createWithDefault( + "org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol") val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") @@ -565,11 +604,24 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_NESTED_VIEW_DEPTH = + buildConf("spark.sql.view.maxNestedViewDepth") + .internal() + .doc("The maximum depth of a view reference in a nested view. A nested view may reference " + + "other nested views, the dependencies are organized in a directed acyclic graph (DAG). " + + "However the DAG depth may become too large and cause unexpected behavior. 
This " + + "configuration puts a limit on this: when the depth of a view exceeds this value during " + + "analysis, we terminate the resolution to avoid potential errors.") + .intConf + .checkValue(depth => depth > 0, "The maximum depth of a view reference in a nested view " + + "must be positive.") + .createWithDefault(100) + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = buildConf("spark.sql.streaming.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[ManifestFileCommitProtocol].getName) + .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") @@ -678,15 +730,80 @@ object SQLConf { .booleanConf .createWithDefault(false) + val JOIN_REORDER_ENABLED = + buildConf("spark.sql.cbo.joinReorder.enabled") + .doc("Enables join reorder in CBO.") + .booleanConf + .createWithDefault(false) + + val JOIN_REORDER_DP_THRESHOLD = + buildConf("spark.sql.cbo.joinReorder.dp.threshold") + .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") + .intConf + .checkValue(number => number > 0, "The maximum number must be a positive integer.") + .createWithDefault(12) + + val JOIN_REORDER_CARD_WEIGHT = + buildConf("spark.sql.cbo.joinReorder.card.weight") + .internal() + .doc("The weight of cardinality (number of rows) for plan cost comparison in join reorder: " + + "rows * weight + size * (1 - weight).") + .doubleConf + .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") + .createWithDefault(0.7) + + val JOIN_REORDER_DP_STAR_FILTER = + buildConf("spark.sql.cbo.joinReorder.dp.star.filter") + .doc("Applies star-join filter heuristics to cost based join enumeration.") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") + .doc("When true, it enables join reordering based on star schema detection. ") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio") + .internal() + .doc("Specifies the upper limit of the ratio between the largest fact tables" + + " for a star join to be considered. ") + .doubleConf + .createWithDefault(0.9) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. 
"GMT", "America/Los_Angeles", etc.""") .stringConf - .createWithDefault(TimeZone.getDefault().getID()) + .createWithDefaultFunction(() => TimeZone.getDefault.getID) + + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in window operator") + .intConf + .createWithDefault(4096) + + val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in sort merge join operator") + .intConf + .createWithDefault(Int.MaxValue) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows buffered in cartesian product operator") + .intConf + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } + + object Replaced { + val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" + } } /** @@ -698,7 +815,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf with Logging { +class SQLConf extends Serializable with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -788,6 +905,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) + def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = + HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) @@ -807,6 +927,20 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. 
+ */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) @@ -832,6 +966,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + def isParquetINT64AsTimestampMillis: Boolean = getConf(PARQUET_INT64_AS_TIMESTAMP_MILLIS) + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) def isParquetTimestampAsINT96: Boolean = getConf(PARQUET_TIMESTAMP_AS_INT96) @@ -866,7 +1002,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) - override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) + def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) @@ -883,17 +1019,41 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def hiveThriftServerSingleSession: Boolean = getConf(StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION) - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) - override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) - override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) - override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) def ndvMaxError: Double = getConf(NDV_MAX_ERROR) - override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + + def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + + def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + + def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + + def sortMergeJoinExecBufferSpillThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + + def cartesianProductExecBufferSpillThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + + def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) + + def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) /** ********************** SQLConf functionality methods ************ */ @@ -1013,71 +1173,21 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def clear(): Unit = { settings.clear() } -} - -/** - * Static SQL configuration is a cross-session, immutable Spark configuration. External users can - * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. 
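Editor's note: with CatalystConf gone, SQLConf.resolver is where analysis now gets its name-equality function. A quick sketch of what the two resolutions do:

import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf                      // spark.sql.caseSensitive defaults to false
assert(conf.resolver("COL", "col"))         // caseInsensitiveResolution
conf.setConfString("spark.sql.caseSensitive", "true")
assert(!conf.resolver("COL", "col"))        // caseSensitiveResolution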
- */ -object StaticSQLConf { - - import SQLConf.buildStaticConf - - val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") - .doc("The default location for managed databases and tables.") - .stringConf - .createWithDefault(Utils.resolveURI("spark-warehouse").toString) - val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") - .internal() - .stringConf - .createWithDefault("in-memory") - - val SESSION_STATE_IMPLEMENTATION = buildStaticConf("spark.sql.sessionStateImplementation") - .internal() - .stringConf - .createWithDefault(CATALOG_IMPLEMENTATION.defaultValueString) - - val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") - .internal() - .stringConf - .createWithDefault("global_temp") - - // This is used to control when we will split a schema's JSON string to multiple pieces - // in order to fit the JSON string in metastore's table property (by default, the value has - // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default - // value of this property). We will split the JSON string of a schema to its length exceeds the - // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, - // that's why this conf has to be a static SQL conf. - val SCHEMA_STRING_LENGTH_THRESHOLD = - buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") - .doc("The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.") - .internal() - .intConf - .createWithDefault(4000) - - val FILESOURCE_TABLE_RELATION_CACHE_SIZE = - buildStaticConf("spark.sql.filesourceTableRelationCacheSize") - .internal() - .doc("The maximum size of the cache that maps qualified table names to table relation plans.") - .intConf - .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") - .createWithDefault(1000) - - // When enabling the debug, Spark SQL internal table properties are not filtered out; however, - // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. - val DEBUG_MODE = buildStaticConf("spark.sql.debug") - .internal() - .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") - .booleanConf - .createWithDefault(false) + override def clone(): SQLConf = { + val result = new SQLConf + getAllConfs.foreach { + case(k, v) => if (v ne null) result.setConfString(k, v) + } + result + } - val HIVE_THRIFT_SERVER_SINGLESESSION = - buildStaticConf("spark.sql.hive.thriftServer.singleSession") - .doc("When set to true, Hive Thrift server is running in a single session mode. " + - "All the JDBC/ODBC connections share the temporary views, function registries, " + - "SQL configuration and the current database.") - .booleanConf - .createWithDefault(false) + // For test only + def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + val cloned = clone() + entries.foreach { + case (entry, value) => cloned.setConfString(entry.key, value.toString) + } + cloned + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala new file mode 100644 index 0000000000000..94cacfe7777b3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
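Editor's note: clone() and the test-only copy(...) added at the end of SQLConf are what let the test suites further down swap SimpleCatalystConf for a tweaked SQLConf. A sketch:

import org.apache.spark.sql.internal.SQLConf

val base = new SQLConf
val caseSensitive = base.copy(SQLConf.CASE_SENSITIVE -> true)   // copies all settings, then overrides
assert(caseSensitive.caseSensitiveAnalysis)
assert(!base.caseSensitiveAnalysis)                             // the original conf is untouched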
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.util.Utils + + +/** + * Static SQL configuration is a cross-session, immutable Spark configuration. External users can + * see the static sql configs via `SparkSession.conf`, but can NOT set/unset them. + */ +object StaticSQLConf { + + import SQLConf.buildStaticConf + + val WAREHOUSE_PATH = buildStaticConf("spark.sql.warehouse.dir") + .doc("The default location for managed databases and tables.") + .stringConf + .createWithDefault(Utils.resolveURI("spark-warehouse").toString) + + val CATALOG_IMPLEMENTATION = buildStaticConf("spark.sql.catalogImplementation") + .internal() + .stringConf + .createWithDefault("in-memory") + + val SESSION_STATE_IMPLEMENTATION = buildStaticConf("spark.sql.sessionStateImplementation") + .internal() + .stringConf + .createWithDefault(CATALOG_IMPLEMENTATION.defaultValueString) + + val GLOBAL_TEMP_DATABASE = buildStaticConf("spark.sql.globalTempDatabase") + .internal() + .stringConf + .createWithDefault("global_temp") + + // This is used to control when we will split a schema's JSON string to multiple pieces + // in order to fit the JSON string in metastore's table property (by default, the value has + // a length restriction of 4000 characters, so do not use a value larger than 4000 as the default + // value of this property). We will split the JSON string of a schema to its length exceeds the + // threshold. Note that, this conf is only read in HiveExternalCatalog which is cross-session, + // that's why this conf has to be a static SQL conf. + val SCHEMA_STRING_LENGTH_THRESHOLD = + buildStaticConf("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val FILESOURCE_TABLE_RELATION_CACHE_SIZE = + buildStaticConf("spark.sql.filesourceTableRelationCacheSize") + .internal() + .doc("The maximum size of the cache that maps qualified table names to table relation plans.") + .intConf + .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") + .createWithDefault(1000) + + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, + // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. + val DEBUG_MODE = buildStaticConf("spark.sql.debug") + .internal() + .doc("Only used for internal debugging. Not all functions are supported when it is enabled.") + .booleanConf + .createWithDefault(false) + + val HIVE_THRIFT_SERVER_SINGLESESSION = + buildStaticConf("spark.sql.hive.thriftServer.singleSession") + .doc("When set to true, Hive Thrift server is running in a single session mode. 
" + + "All the JDBC/ODBC connections share the temporary views, function registries, " + + "SQL configuration and the current database.") + .booleanConf + .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2642d9395ba88..30745c6a9d42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -49,7 +51,9 @@ abstract class DataType extends AbstractDataType { /** Name of the type used in JSON serialization. */ def typeName: String = { - this.getClass.getSimpleName.stripSuffix("$").stripSuffix("Type").stripSuffix("UDT").toLowerCase + this.getClass.getSimpleName + .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT") + .toLowerCase(Locale.ROOT) } private[sql] def jsonValue: JValue = typeName @@ -69,7 +73,7 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type with truncation */ private[sql] def simpleString(maxNumberFields: Int): String = simpleString - def sql: String = simpleString.toUpperCase + def sql: String = simpleString.toUpperCase(Locale.ROOT) /** * Check if `this` and `other` are the same data type when ignoring nullability @@ -115,7 +119,10 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) + case other => nonDecimalNameToType.getOrElse( + other, + throw new IllegalArgumentException( + s"Failed to convert the JSON string '$name' to a data type.")) } } @@ -164,6 +171,10 @@ object DataType { ("sqlType", v: JValue), ("type", JString("udt"))) => new PythonUserDefinedType(parseDataType(v), pyClass, serialized) + + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a data type.") } private def parseStructField(json: JValue): StructField = json match { @@ -179,6 +190,9 @@ object DataType { ("nullable", JBool(nullable)), ("type", dataType: JValue)) => StructField(name, parseDataType(dataType), nullable) + case other => + throw new IllegalArgumentException( + s"Failed to convert the JSON string '${compact(render(other))}' to a field.") } protected[types] def buildFormattedString( @@ -274,4 +288,30 @@ object DataType { case (fromDataType, toDataType) => fromDataType == toDataType } } + + /** + * Returns true if the two data types share the same "shape", i.e. the types (including + * nullability) are the same, but the field names don't need to be the same. 
+ */ + def equalsStructurally(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (left: ArrayType, right: ArrayType) => + equalsStructurally(left.elementType, right.elementType) && + left.containsNull == right.containsNull + + case (left: MapType, right: MapType) => + equalsStructurally(left.keyType, right.keyType) && + equalsStructurally(left.valueType, right.valueType) && + left.valueContainsNull == right.valueContainsNull + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields) + .forall { case (l, r) => + equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 089c84d5f7736..80916ee9c5379 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,6 +21,7 @@ import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.AnalysisException /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -131,14 +132,22 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + * If the value is not in the range of long, convert it to BigDecimal and + * the precision and scale are based on the converted value. + * + * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ def set(bigintval: BigInteger): Decimal = { - this.decimalVal = null - this.longVal = bigintval.longValueExact() - this._precision = DecimalType.MAX_PRECISION - this._scale = 0 - this + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } catch { + case _: ArithmeticException => + set(BigDecimal(bigintval)) + } } /** @@ -222,6 +231,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) } + /** + * Create new `Decimal` with given precision and scale. + * + * @return `Some(decimal)` if successful or `None` if overflow would occur + */ + private[sql] def toPrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + val copy = clone() + if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + } + /** * Update precision and scale while keeping our value the same, and return true if successful. 
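Editor's note: DataType.equalsStructurally compares two types by shape only — field names are ignored, but nullability still matters. For example:

import org.apache.spark.sql.types._

val a = new StructType().add("a", IntegerType).add("b", StringType)
val b = new StructType().add("x", IntegerType).add("y", StringType)
assert(DataType.equalsStructurally(a, b))          // names differ, shape matches

val c = new StructType().add("x", IntegerType, nullable = false).add("y", StringType)
assert(!DataType.equalsStructurally(a, c))         // a nullability difference breaks it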
* @@ -362,17 +384,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this def floor: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } def ceil: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 4dc06fc9cf09b..5c4bc5e33c53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Locale + import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability @@ -65,7 +67,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" - override def sql: String = typeName.toUpperCase + override def sql: String = typeName.toUpperCase(Locale.ROOT) /** * Returns whether this DecimalType is wider than `other`. If yes, it means `other` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index b18fba29af0f9..2d49fe076786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -44,4 +44,9 @@ case class ObjectType(cls: Class[_]) extends DataType { def asNullable: DataType = this override def simpleString: String = cls.getName + + override def acceptsType(other: DataType): Boolean = other match { + case ObjectType(otherCls) => cls.isAssignableFrom(otherCls) + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 8d8b5b86d5aa1..54006e20a3eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -417,6 +417,12 @@ object StructType extends AbstractDataType { } } + /** + * Creates StructType for a given DDL-formatted string, which is a comma separated list of field + * definitions, e.g., a INT, b STRING. 
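Editor's note: floor and ceil now route through the new private[sql] toPrecision helper, which returns None on overflow; the public methods turn that into an AnalysisException instead of returning a silently wrong value. The happy path is unchanged, as this sketch shows:

import org.apache.spark.sql.types.Decimal

assert(Decimal(BigDecimal("3.7")).floor == Decimal(3))    // ROUND_FLOOR at scale 0
assert(Decimal(BigDecimal("-3.2")).ceil == Decimal(-3))   // ROUND_CEILING at scale 0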
+ */ + def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) def apply(fields: java.util.List[StructField]): StructType = { diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java new file mode 100644 index 0000000000000..2e8f2e3fd9f47 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaGroupStateTimeoutSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.sql.catalyst.plans.logical.EventTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.junit.Test; + +public class JavaGroupStateTimeoutSuite { + + @Test + public void testTimeouts() { + assert (GroupStateTimeout.ProcessingTimeTimeout() == ProcessingTimeTimeout$.MODULE$); + assert (GroupStateTimeout.EventTimeTimeout() == EventTimeTimeout$.MODULE$); + assert (GroupStateTimeout.NoTimeout() == NoTimeout$.MODULE$); + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java index e0a54fe30ac7d..d8845e0c838ff 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming; +import java.util.Locale; + import org.junit.Test; public class JavaOutputModeSuite { @@ -24,8 +26,8 @@ public class JavaOutputModeSuite { @Test public void testOutputModes() { OutputMode o1 = OutputMode.Append(); - assert(o1.toString().toLowerCase().contains("append")); + assert(o1.toString().toLowerCase(Locale.ROOT).contains("append")); OutputMode o2 = OutputMode.Complete(); - assert (o2.toString().toLowerCase().contains("complete")); + assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete")); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507f..8ae3ff5043e68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -117,11 +117,11 @@ object RandomDataGenerator { } /** - * Returns a function which generates random values for the given [[DataType]], or `None` if no + * Returns a function which 
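Editor's note: StructType.fromDDL is a thin convenience wrapper over CatalystSqlParser.parseTableSchema, so a schema can be declared as a DDL string:

import org.apache.spark.sql.types._

val schema = StructType.fromDDL("a INT, b STRING")
assert(schema.fieldNames.toSeq == Seq("a", "b"))
assert(schema("a").dataType == IntegerType)
assert(schema("b").dataType == StringType)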
generates random values for the given `DataType`, or `None` if no * random data generator is defined for that data type. The generated values will use an external - * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. - * For a [[UserDefinedType]] for a class X, an instance of class X is returned. + * representation of the data type; for example, the random generator for `DateType` will return + * instances of [[java.sql.Date]] and the generator for `StructType` will return a [[Row]]. + * For a `UserDefinedType` for a class X, an instance of class X is returned. * * @param dataType the type to generate values for * @param nullable whether null values should be generated diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index a6d90409382e5..769addf3b29e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Benchmark /** - * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields. + * Benchmark `UnsafeProjection` for fixed-length/primitive-type fields. */ object UnsafeProjectionBenchmark { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 650a35398f3e8..70ad064f93ebc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite { ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) - - // Check whether conversion is skipped when using WrappedArray[_] supertype - // (would otherwise needlessly add overhead) - import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke - val seqDeserializer = deserializerFor[Seq[Int]] - assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject == - scala.collection.mutable.WrappedArray.getClass) - assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make") } private val dataTypeForComplexData = dataTypeFor[ComplexData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index c5e877d12811c..d2ebca5a83dd3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -530,7 +530,7 @@ class AnalysisErrorSuite extends AnalysisTest { Exists( Join( LocalRelation(b), - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LeftOuter, Option(EqualTo(b, c)))), LocalRelation(a)) @@ -539,7 +539,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Filter( Exists( Join( - Filter(EqualTo(OuterReference(a), c), LocalRelation(c)), + 
Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)), LocalRelation(b), RightOuter, Option(EqualTo(b, c)))), @@ -547,14 +547,15 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) val plan3 = Filter( - Exists(Union(LocalRelation(b), Filter(EqualTo(OuterReference(a), c), LocalRelation(c)))), + Exists(Union(LocalRelation(b), + Filter(EqualTo(UnresolvedAttribute("a"), c), LocalRelation(c)))), LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) val plan4 = Filter( Exists( Limit(1, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) ), LocalRelation(a)) assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) @@ -562,7 +563,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan5 = Filter( Exists( Sample(0.0, 0.5, false, 1L, - Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))().select('b) + Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))().select('b) ), LocalRelation(a)) assertAnalysisError(plan5, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 01737e0a17341..893bb1b74cea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -62,23 +62,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation)) assertAnalysisError( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Seq("cannot resolve")) checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) } @@ -374,8 +374,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val query = Project(Seq($"x.key", $"y.key"), Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - Project(Seq($"y.key"), SubqueryAlias("y", input, None)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + Project(Seq($"y.key"), SubqueryAlias("y", input)), Cross, None)) assertAnalysisSuccess(query) @@ -435,10 +435,10 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { test("resolve as with an already existed alias") { checkAnalysis( Project(Seq(UnresolvedAttribute("tbl2.a")), - SubqueryAlias("tbl", testRelation, None).as("tbl2")), + SubqueryAlias("tbl", testRelation).as("tbl2")), Project(testRelation.output, testRelation), caseSensitive = false) - checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), 
testRelation) + checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 0f059b9591460..82015b1e0671c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf trait AnalysisTest extends PlanTest { @@ -29,7 +31,7 @@ trait AnalysisTest extends PlanTest { protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false) private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { - val conf = new SimpleCatalystConf(caseSensitive) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) @@ -79,7 +81,8 @@ trait AnalysisTest extends PlanTest { analyzer.checkAnalysis(analyzer.execute(inputPlan)) } - if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall( + e.getMessage.toLowerCase(Locale.ROOT).contains)) { fail( s"""Exception message should contain the following substrings: | diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 6995faebfa862..8f43171f309a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f45a826869842..d0fe815052256 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} 
/** @@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) - val converted = ResolveInlineTables(conf).convert(table) + val withTimeZone = ResolveTimeZone(conf).apply(table) + val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] - assert(converted.output.map(_.dataType) == Seq(TimestampType)) - assert(converted.data.size == 1) - assert(converted.data(0).getLong(0) == correct) + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 1) + assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 4aafb2b83fb69..55693121431a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -33,7 +33,7 @@ class ResolveSubquerySuite extends AnalysisTest { val t2 = LocalRelation(b) test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { - val expr = Filter(In(a, Seq(ListQuery(Project(Seq(OuterReference(a)), t2)))), t1) + val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { SimpleAnalyzer.ResolveSubquery(expr) }.getMessage diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index 88f68ebadc72a..2331346f325aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.internal.SQLConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { private lazy val a = testRelation2.output(0) @@ -44,7 +44,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // order by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.ORDER_BY_ORDINAL -> false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) } @@ -60,7 +60,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { // group by ordinal can be turned off by config comparePlans( - new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), + new SubstituteUnresolvedOrdinals(conf.copy(SQLConf.GROUP_BY_ORDINAL -> false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3e0c357b6de42..2624f5586fd5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -656,14 +657,20 @@ class TypeCoercionSuite extends PlanTest { test("nanvl casts") { ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), - NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) + NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), + NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, - NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), - NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) + NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) + ruleTest(TypeCoercion.FunctionArgumentConversion, + NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), + NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { @@ -781,6 +788,12 @@ class TypeCoercionSuite extends PlanTest { } } + private val timeZoneResolver = ResolveTimeZone(new SQLConf) + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + } + test("WidenSetOperationTypes for except and intersect") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), @@ -793,11 +806,10 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -832,10 +844,9 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", 
DoubleType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val unionRelation = wt( + val unionRelation = widenSetOperationTypes( Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] assert(unionRelation.children.length == 4) checkOutput(unionRelation.children.head, expectedTypes) @@ -856,17 +867,15 @@ class TypeCoercionSuite extends PlanTest { } } - val dp = TypeCoercion.WidenSetOperationTypes - val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] - val r2 = dp(Except(left1, right1)).asInstanceOf[Except] - val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -885,17 +894,17 @@ class TypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 82be69a0f7d7b..c39e372c272b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import 
org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} +import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} -import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -138,29 +139,228 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) - // MapGroupsWithState: Not supported after a streaming aggregation val att = new AttributeReference(name = "a", dataType = LongType)() - assertSupportedInBatchPlan( - "mapGroupsWithState - mapGroupsWithState on batch relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) + // FlatMapGroupsWithState: Both function modes equivalent and supported in batch. + for (funcMode <- Seq(Append, Update)) { + assertSupportedInBatchPlan( + s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, + batchRelation)) + + assertSupportedInBatchPlan( + s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, null, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation))) + } + + // FlatMapGroupsWithState(Update) in streaming without aggregation + assertSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in update mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Update) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Append, + expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append")) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) " + + "on streaming relation without aggregation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Complete, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + + // FlatMapGroupsWithState(Update) in streaming with aggregation + for (outputMode <- Seq(Append, Update, Complete)) { + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " + + s"with aggregation in $outputMode mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode, + expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation")) + } + + // FlatMapGroupsWithState(Append) in streaming without aggregation + assertSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + "on streaming relation without aggregation in update mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Update, + expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update")) + + // FlatMapGroupsWithState(Append) in streaming with aggregation + for (outputMode <- Seq(Append, Update, Complete)) { + assertSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation before aggregation in $outputMode mode", + Aggregate( + Seq(attributeWithWatermark), + aggExprs("c"), + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), + outputMode = outputMode) + } + + for (outputMode <- Seq(Append, Update)) { + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation after aggregation in $outputMode mode", + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode, + expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation")) + } + + assertNotSupportedInStreamingPlan( + "flatMapGroupsWithState - " + + "flatMapGroupsWithState(Update) on streaming relation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation), + outputMode = Complete, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + // FlatMapGroupsWithState inside batch relation should always be allowed + for (funcMode <- Seq(Append, Update)) { + for (outputMode <- Seq(Append, Update)) { // Complete is not supported without aggregation + assertSupportedInStreamingPlan( + s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " + + s"streaming relation in $outputMode output mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, funcMode, isMapGroupsWithState = false, + null, batchRelation), + outputMode = outputMode + ) + } + } + + // multiple FlatMapGroupsWithStates assertSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), + "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " + + "in append mode", + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, streamRelation)), outputMode = Append) assertNotSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), - Aggregate(Nil, aggExprs("c"), streamRelation)), + "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" + + " are not in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, + streamRelation)), + outputMode = Append, + expectedMsgs = Seq("multiple flatMapGroupsWithState", "append")) + + // mapGroupsWithState + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState " + + "on streaming relation without aggregation in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), + outputMode = Append, + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. + expectedMsgs = Seq("mapGroupsWithState", "append")) + + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState " + + "on streaming relation without aggregation in complete mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation), outputMode = Complete, - expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + // Disallowed by the aggregation check but let's still keep this test in case it's broken in + // future. 
+ expectedMsgs = Seq("Complete")) + + for (outputMode <- Seq(Append, Update, Complete)) { + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation " + + s"with aggregation in $outputMode mode", + FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode, + expectedMsgs = Seq("mapGroupsWithState", "with aggregation")) + } + + // multiple mapGroupsWithStates + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + + "in append mode", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + streamRelation)), + outputMode = Append, + expectedMsgs = Seq("multiple mapGroupsWithStates")) + + // mixing mapGroupsWithStates and flatMapGroupsWithStates + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - " + + "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, null, + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = false, null, + streamRelation) + ), + outputMode = Append, + expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) + + // mapGroupsWithState with event time timeout + watermark + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout without watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, streamRelation), + outputMode = Update, + expectedMsgs = Seq("watermark")) assertSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), - outputMode = Append - ) + "mapGroupsWithState - mapGroupsWithState with event time timeout with watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)), + outputMode = Update) // Deduplicate assertSupportedInStreamingPlan( @@ -497,7 +697,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBody } expectedMsgs.foreach { m => - if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + if (!e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) { fail(s"Exception message should contain: '$m', " + s"actual exception message:\n\t'${e.getMessage}'") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala new file mode 100644 index 0000000000000..2539ea615ff92 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.nio.file.{Files, Path} + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types.StructType + +/** + * Test Suite for external catalog events + */ +class ExternalCatalogEventSuite extends SparkFunSuite { + + protected def newCatalog: ExternalCatalog = new InMemoryCatalog() + + private def testWithCatalog( + name: String)( + f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { + val catalog = newCatalog + val recorder = mutable.Buffer.empty[ExternalCatalogEvent] + catalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + recorder += event + } + }) + f(catalog, (expected: Seq[ExternalCatalogEvent]) => { + val actual = recorder.clone() + recorder.clear() + assert(expected === actual) + }) + } + + private def createDbDefinition(uri: URI): CatalogDatabase = { + CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty) + } + + private def createDbDefinition(): CatalogDatabase = { + createDbDefinition(preparePath(Files.createTempDirectory("db_"))) + } + + private def preparePath(path: Path): URI = path.normalize().toUri + + testWithCatalog("database") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createDatabase(dbDefinition, ignoreIfExists = true) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + intercept[AnalysisException] { + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + } + checkEvents(CreateDatabasePreEvent("db5") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) + } + checkEvents(DropDatabasePreEvent("db4") :: Nil) + + catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false) + checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil) + + catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false) + checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil) + } + + testWithCatalog("table") { (catalog, checkEvents) => + val path1 = Files.createTempDirectory("db_") + val path2 = Files.createTempDirectory(path1, "tbl_") + val uri1 = preparePath(path1) + val uri2 = preparePath(path2) + + // CREATE + val dbDefinition = createDbDefinition(uri1) + + val storage = CatalogStorageFormat.empty.copy( + locationUri = Option(uri2)) + val tableDefinition = CatalogTable( + 
identifier = TableIdentifier("tbl1", Some("db5")), + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = new StructType().add("id", "long")) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = false) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = true) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + intercept[AnalysisException] { + catalog.createTable(tableDefinition, ignoreIfExists = false) + } + checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + + // RENAME + catalog.renameTable("db5", "tbl1", "tbl2") + checkEvents( + RenameTablePreEvent("db5", "tbl1", "tbl2") :: + RenameTableEvent("db5", "tbl1", "tbl2") :: Nil) + + intercept[AnalysisException] { + catalog.renameTable("db5", "tbl1", "tbl2") + } + checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true) + } + checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + } + + testWithCatalog("function") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + val functionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn7", Some("db5")), + className = "", + resources = Seq.empty) + + val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4") + val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createFunction("db5", functionDefinition) + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil) + + intercept[AnalysisException] { + catalog.createFunction("db5", functionDefinition) + } + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil) + + // RENAME + catalog.renameFunction("db5", "fn7", "fn4") + checkEvents( + RenameFunctionPreEvent("db5", "fn7", "fn4") :: + RenameFunctionEvent("db5", "fn7", "fn4") :: Nil) + intercept[AnalysisException] { + catalog.renameFunction("db5", "fn7", "fn4") + } + checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropFunction("db5", "fn7") + } + checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil) + + catalog.dropFunction("db5", "fn4") + checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index a5d399a065589..42db4398e5072 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,6 +17,9 @@ package 
org.apache.spark.sql.catalyst.catalog +import java.net.URI +import java.util.TimeZone + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -26,7 +29,9 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -238,6 +243,19 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("alter table schema") { + val catalog = newBasicCatalog() + val tbl1 = catalog.getTable("db2", "tbl1") + val newSchema = StructType(Seq( + StructField("new_field_1", IntegerType), + StructField("new_field_2", StringType), + StructField("a", IntegerType), + StructField("b", StringType))) + catalog.alterTableSchema("db2", "tbl1", newSchema) + val newTbl1 = catalog.getTable("db2", "tbl1") + assert(newTbl1.schema == newSchema) + } + test("get table") { assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } @@ -340,7 +358,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac "db1", "tbl", Map("partCol1" -> "1", "partCol2" -> "2")).location - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") assert(new Path(partitionLocation) == defaultPartitionLocation) } @@ -421,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty) } + test("list partitions by filter") { + val tz = TimeZone.getDefault.getID + val catalog = newBasicCatalog() + + def checkAnswer( + table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition]) + : Unit = { + + assertResult(expected.map(_.spec)) { + catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz) + .map(_.spec).toSet + } + } + + val tbl2 = catalog.getTable("db2", "tbl2") + + checkAnswer(tbl2, Seq.empty, Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 2), Set.empty) + checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) + checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) + checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1)) + + intercept[AnalysisException] { + try { + checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty) + } catch { + // HiveExternalCatalog may be the first one to notice and throw an exception, which will + // then be caught and converted to a RuntimeException with a descriptive message. 
+ case ex: RuntimeException if ex.getMessage.contains("MetaException") => + throw new AnalysisException(ex.getMessage) + } + } + } + test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) @@ -508,7 +564,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val mixedCasePart1 = CatalogTablePartition( Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) @@ -699,7 +755,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // File System operations // -------------------------------------------------------------------------- - private def exists(uri: String, children: String*): Boolean = { + private def exists(uri: URI, children: String*): Boolean = { val base = new Path(uri) val finalPath = children.foldLeft(base) { case (parent, child) => new Path(parent, child) @@ -742,7 +798,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac identifier = TableIdentifier("external_table", Some("db1")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), + Some(Utils.createTempDir().toURI), None, None, None, false, Map.empty), schema = new StructType().add("a", "int").add("b", "string"), provider = Some(defaultProvider) @@ -790,7 +846,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithExistingDir = CatalogTablePartition( Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false) @@ -799,7 +855,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithNonExistingDir = CatalogTablePartition( Map("partCol1" -> "9", "partCol2" -> "10"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) assert(tempPath.exists()) @@ -883,7 +939,7 @@ abstract class CatalogTestUtils { def newFunc(): CatalogFunction = newFunc("funcName") - def newUriForDatabase(): String = Utils.createTempDir().toURI.toString.stripSuffix("/") + def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) def newDb(name: String): CatalogDatabase = { CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) @@ -895,7 +951,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().toURI)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a755231962be2..be8903000a0d1 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -24,42 +24,68 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias, View} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +class InMemorySessionCatalogSuite extends SessionCatalogSuite { + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" + override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" + override val defaultProvider: String = "parquet" + override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog + } +} /** - * Tests for [[SessionCatalog]] that assume that [[InMemoryCatalog]] is correctly implemented. + * Tests for [[SessionCatalog]] * * Note: many of the methods here are very similar to the ones in [[ExternalCatalogSuite]]. * This is because [[SessionCatalog]] and [[ExternalCatalog]] share many similar method * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. */ -class SessionCatalogSuite extends PlanTest { - private val utils = new CatalogTestUtils { - override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" - override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override val defaultProvider: String = "parquet" - override def newEmptyCatalog(): ExternalCatalog = new InMemoryCatalog - } +abstract class SessionCatalogSuite extends PlanTest { + protected val utils: CatalogTestUtils + + protected val isHiveExternalCatalog = false import utils._ + private def withBasicCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newBasicCatalog()) + try { + f(catalog) + } finally { + catalog.reset() + } + } + + private def withEmptyCatalog(f: SessionCatalog => Unit): Unit = { + val catalog = new SessionCatalog(newEmptyCatalog()) + catalog.createDatabase(newDb("default"), ignoreIfExists = true) + try { + f(catalog) + } finally { + catalog.reset() + } + } // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- test("basic create and list databases") { - val catalog = new SessionCatalog(newEmptyCatalog()) - catalog.createDatabase(newDb("default"), ignoreIfExists = true) - assert(catalog.databaseExists("default")) - assert(!catalog.databaseExists("testing")) - assert(!catalog.databaseExists("testing2")) - catalog.createDatabase(newDb("testing"), ignoreIfExists = false) - assert(catalog.databaseExists("testing")) - assert(catalog.listDatabases().toSet == Set("default", "testing")) - catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) - assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) - assert(catalog.databaseExists("testing2")) - assert(!catalog.databaseExists("does_not_exist")) + withEmptyCatalog { catalog => + assert(catalog.databaseExists("default")) + assert(!catalog.databaseExists("testing")) + assert(!catalog.databaseExists("testing2")) + catalog.createDatabase(newDb("testing"), ignoreIfExists = false) + assert(catalog.databaseExists("testing")) + 
assert(catalog.listDatabases().toSet == Set("default", "testing")) + catalog.createDatabase(newDb("testing2"), ignoreIfExists = false) + assert(catalog.listDatabases().toSet == Set("default", "testing", "testing2")) + assert(catalog.databaseExists("testing2")) + assert(!catalog.databaseExists("does_not_exist")) + } } def testInvalidName(func: (String) => Unit) { @@ -74,121 +100,141 @@ class SessionCatalogSuite extends PlanTest { } test("create databases using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + withEmptyCatalog { catalog => + testInvalidName( + name => catalog.createDatabase(newDb(name), ignoreIfExists = true)) + } } test("get database when a database exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - assert(db1.name == "db1") - assert(db1.description.contains("db1")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + assert(db1.name == "db1") + assert(db1.description.contains("db1")) + } } test("get database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getDatabaseMetadata("db_that_does_not_exist") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getDatabaseMetadata("db_that_does_not_exist") + } } } test("list databases without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + withBasicCatalog { catalog => + assert(catalog.listDatabases().toSet == Set("default", "db1", "db2", "db3")) + } } test("list databases with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.listDatabases("db").toSet == Set.empty) - assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) - assert(catalog.listDatabases("*1").toSet == Set("db1")) - assert(catalog.listDatabases("db2").toSet == Set("db2")) + withBasicCatalog { catalog => + assert(catalog.listDatabases("db").toSet == Set.empty) + assert(catalog.listDatabases("db*").toSet == Set("db1", "db2", "db3")) + assert(catalog.listDatabases("*1").toSet == Set("db1")) + assert(catalog.listDatabases("db2").toSet == Set("db2")) + } } test("drop database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) - assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + withBasicCatalog { catalog => + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = false) + assert(catalog.listDatabases().toSet == Set("default", "db2", "db3")) + } } test("drop database when the database is not empty") { // Throw exception if there are functions left - val externalCatalog1 = newBasicCatalog() - val sessionCatalog1 = new SessionCatalog(externalCatalog1) - externalCatalog1.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) - externalCatalog1.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) - intercept[AnalysisException] { - sessionCatalog1.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + catalog.externalCatalog.dropTable("db2", "tbl1", ignoreIfNotExists = false, purge = false) + catalog.externalCatalog.dropTable("db2", "tbl2", ignoreIfNotExists = false, purge = false) + intercept[AnalysisException] { + 
catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - - // Throw exception if there are tables left - val externalCatalog2 = newBasicCatalog() - val sessionCatalog2 = new SessionCatalog(externalCatalog2) - externalCatalog2.dropFunction("db2", "func1") - intercept[AnalysisException] { - sessionCatalog2.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // Throw exception if there are tables left + catalog.externalCatalog.dropFunction("db2", "func1") + intercept[AnalysisException] { + catalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = false) + } } - // When cascade is true, it should drop them - val externalCatalog3 = newBasicCatalog() - val sessionCatalog3 = new SessionCatalog(externalCatalog3) - externalCatalog3.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) - assert(sessionCatalog3.listDatabases().toSet == Set("default", "db1", "db3")) + withBasicCatalog { catalog => + // When cascade is true, it should drop them + catalog.externalCatalog.dropDatabase("db2", ignoreIfNotExists = false, cascade = true) + assert(catalog.listDatabases().toSet == Set("default", "db1", "db3")) + } } test("drop database when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + withBasicCatalog { catalog => + // TODO: fix this inconsistent between HiveExternalCatalog and InMemoryCatalog + if (isHiveExternalCatalog) { + val e = intercept[AnalysisException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + }.getMessage + assert(e.contains( + "org.apache.hadoop.hive.metastore.api.NoSuchObjectException: db_that_does_not_exist")) + } else { + intercept[NoSuchDatabaseException] { + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) + } + } + catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } - catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) } test("drop current database and drop default database") { - val catalog = new SessionCatalog(newBasicCatalog()) - catalog.setCurrentDatabase("db1") - assert(catalog.getCurrentDatabase == "db1") - catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) - } - catalog.setCurrentDatabase("default") - assert(catalog.getCurrentDatabase == "default") - intercept[AnalysisException] { - catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + withBasicCatalog { catalog => + catalog.setCurrentDatabase("db1") + assert(catalog.getCurrentDatabase == "db1") + catalog.dropDatabase("db1", ignoreIfNotExists = false, cascade = true) + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "db1"), ignoreIfExists = false) + } + catalog.setCurrentDatabase("default") + assert(catalog.getCurrentDatabase == "default") + intercept[AnalysisException] { + catalog.dropDatabase("default", ignoreIfNotExists = false, cascade = true) + } } } test("alter database") { - val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabaseMetadata("db1") - // Note: alter properties here because Hive does not support altering other fields - catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" 
-> "true"))) - val newDb1 = catalog.getDatabaseMetadata("db1") - assert(db1.properties.isEmpty) - assert(newDb1.properties.size == 2) - assert(newDb1.properties.get("k") == Some("v3")) - assert(newDb1.properties.get("good") == Some("true")) + withBasicCatalog { catalog => + val db1 = catalog.getDatabaseMetadata("db1") + // Note: alter properties here because Hive does not support altering other fields + catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) + val newDb1 = catalog.getDatabaseMetadata("db1") + assert(db1.properties.isEmpty) + assert(newDb1.properties.size == 2) + assert(newDb1.properties.get("k") == Some("v3")) + assert(newDb1.properties.get("good") == Some("true")) + } } test("alter database should throw exception when the database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterDatabase(newDb("unknown_db")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterDatabase(newDb("unknown_db")) + } } } test("get/set current database") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getCurrentDatabase == "default") - catalog.setCurrentDatabase("db2") - assert(catalog.getCurrentDatabase == "db2") - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalog.getCurrentDatabase == "default") + catalog.setCurrentDatabase("db2") + assert(catalog.getCurrentDatabase == "db2") + intercept[NoSuchDatabaseException] { + catalog.setCurrentDatabase("deebo") + } + catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) catalog.setCurrentDatabase("deebo") + assert(catalog.getCurrentDatabase == "deebo") } - catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) - catalog.setCurrentDatabase("deebo") - assert(catalog.getCurrentDatabase == "deebo") } // -------------------------------------------------------------------------- @@ -196,346 +242,388 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("create table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db1").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) - // Create table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db1") - sessionCatalog.createTable(newTable("tbl4"), ignoreIfExists = false) - assert(externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db1").isEmpty) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.createTable(newTable("tbl3", "db1"), ignoreIfExists = false) + catalog.createTable(newTable("tbl3", "db2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + // Create table without explicitly specifying database + catalog.setCurrentDatabase("db1") + 
catalog.createTable(newTable("tbl4"), ignoreIfExists = false) + assert(catalog.externalCatalog.listTables("db1").toSet == Set("tbl3", "tbl4")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2", "tbl3")) + } } test("create tables using invalid names") { - val catalog = new SessionCatalog(newEmptyCatalog()) - testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + withEmptyCatalog { catalog => + testInvalidName(name => catalog.createTable(newTable(name, "db1"), ignoreIfExists = false)) + } } test("create table when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Creating table in non-existent database should always fail - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) - } - intercept[NoSuchDatabaseException] { - catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) + withBasicCatalog { catalog => + // Creating table in non-existent database should always fail + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) + } + intercept[NoSuchDatabaseException] { + catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) + } + // Table already exists + intercept[TableAlreadyExistsException] { + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + } + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } - // Table already exists - intercept[TableAlreadyExistsException] { - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) - } - catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) } test("create temp table") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable1 = Range(1, 10, 1, 10) - val tempTable2 = Range(1, 20, 2, 10) - catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempView("tbl1") == Option(tempTable1)) - assert(catalog.getTempView("tbl2") == Option(tempTable2)) - assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists - intercept[TempTableAlreadyExistsException] { + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val tempTable2 = Range(1, 20, 2, 10) catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) + // Temporary table already exists + intercept[TempTableAlreadyExistsException] { + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + } + // Temporary table already exists but we override it + catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } - // Temporary table already exists but we override it - catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = 
false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // Drop table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // Drop table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.dropTable(TableIdentifier("tbl2"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").isEmpty) + } } test("drop table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - // Should always throw exception when the database does not exist - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, - purge = false) - } - intercept[NoSuchDatabaseException] { - catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, + withBasicCatalog { catalog => + // Should always throw exception when the database does not exist + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false, + purge = false) + } + intercept[NoSuchDatabaseException] { + catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true, + purge = false) + } + intercept[NoSuchTableException] { + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, + purge = false) + } + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, purge = false) } - intercept[NoSuchTableException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false, - purge = false) - } - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true, - purge = false) } test("drop temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be dropped first - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.getTempView("tbl1") == None) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If temp table does not exist, the table in the current database should be dropped - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) - // If database is specified, temp tables are never dropped - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) - sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, - purge = false) - 
assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be dropped first + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.getTempView("tbl1") == None) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If temp table does not exist, the table in the current database should be dropped + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + // If database is specified, temp tables are never dropped + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) + catalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false, + purge = false) + assert(catalog.getTempView("tbl1") == Some(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl2")) + } } test("rename table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) - // Rename table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) - assert(externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) - // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match - intercept[AnalysisException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) - } - // The new table already exists - intercept[TableAlreadyExistsException] { - sessionCatalog.renameTable( - TableIdentifier("tblone", Some("db2")), - TableIdentifier("table_two")) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier("tblone")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "tbl2")) + catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbltwo")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "tbltwo")) + // Rename table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renameTable(TableIdentifier("tbltwo"), TableIdentifier("table_two")) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tblone", "table_two")) + // Renaming "db2.tblone" to "db1.tblones" should fail because databases don't match + intercept[AnalysisException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), 
TableIdentifier("tblones", Some("db1"))) + } + // The new table already exists + intercept[TableAlreadyExistsException] { + catalog.renameTable( + TableIdentifier("tblone", Some("db2")), + TableIdentifier("table_two")) + } } } test("rename tables to an invalid name") { - val catalog = new SessionCatalog(newBasicCatalog()) - testInvalidName( - name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + withBasicCatalog { catalog => + testInvalidName( + name => catalog.renameTable(TableIdentifier("tbl1", Some("db2")), TableIdentifier(name))) + } } test("rename table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) - } - intercept[NoSuchTableException] { - catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renameTable(TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2")) + } + intercept[NoSuchTableException] { + catalog.renameTable(TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2")) + } } } test("rename temp table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10) - sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is not specified, temp table should be renamed first - sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) - assert(sessionCatalog.getTempView("tbl1").isEmpty) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) - // If database is specified, temp tables are never renamed - sessionCatalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) - assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) - assert(sessionCatalog.getTempView("tbl4").isEmpty) - assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + assert(catalog.getTempView("tbl1") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is not specified, temp table should be renamed first + catalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) + assert(catalog.getTempView("tbl1").isEmpty) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) + // If database is specified, temp tables are never renamed + catalog.renameTable(TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4")) + assert(catalog.getTempView("tbl3") == Option(tempTable)) + assert(catalog.getTempView("tbl4").isEmpty) + assert(catalog.externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) + } } test("alter table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tbl1 = externalCatalog.getTable("db2", "tbl1") - 
sessionCatalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) - val newTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(!tbl1.properties.contains("toh")) - assert(newTbl1.properties.size == tbl1.properties.size + 1) - assert(newTbl1.properties.get("toh") == Some("frem")) - // Alter table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) - val newestTbl1 = externalCatalog.getTable("db2", "tbl1") - assert(newestTbl1 == tbl1) + withBasicCatalog { catalog => + val tbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.alterTable(tbl1.copy(properties = Map("toh" -> "frem"))) + val newTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + assert(!tbl1.properties.contains("toh")) + assert(newTbl1.properties.size == tbl1.properties.size + 1) + assert(newTbl1.properties.get("toh") == Some("frem")) + // Alter table without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) + val newestTbl1 = catalog.externalCatalog.getTable("db2", "tbl1") + // For a Hive serde table, the Hive metastore will set transient_lastDdlTime in the table's properties, + // and its value will be modified, so we ignore it here when comparing the two tables. + assert(newestTbl1.copy(properties = Map.empty) == tbl1.copy(properties = Map.empty)) + } } test("alter table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterTable(newTable("tbl1", "unknown_db")) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterTable(newTable("tbl1", "unknown_db")) + } + intercept[NoSuchTableException] { + catalog.alterTable(newTable("unknown_table", "db2")) + } } - intercept[NoSuchTableException] { - catalog.alterTable(newTable("unknown_table", "db2")) + } + + test("alter table add columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), + StructType(oldTab.dataSchema.add("c3", IntegerType) ++ oldTab.partitionSchema)) + + val newTab = sessionCatalog.externalCatalog.getTable("default", "t1") + // Construct the expected table schema + val expectedTableSchema = StructType(oldTab.dataSchema.fields ++ + Seq(StructField("c3", IntegerType)) ++ oldTab.partitionSchema) + assert(newTab.schema == expectedTableSchema) + } + } + + test("alter table drop columns") { + withBasicCatalog { sessionCatalog => + sessionCatalog.createTable(newTable("t1", "default"), ignoreIfExists = false) + val oldTab = sessionCatalog.externalCatalog.getTable("default", "t1") + val e = intercept[AnalysisException] { + sessionCatalog.alterTableSchema( + TableIdentifier("t1", Some("default")), StructType(oldTab.schema.drop(1))) + }.getMessage + assert(e.contains("We don't support dropping columns yet.")) } } test("get table") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) - == externalCatalog.getTable("db2", "tbl1")) - // Get table without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) - == 
externalCatalog.getTable("db2", "tbl1")) + withBasicCatalog { catalog => + assert(catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) + == catalog.externalCatalog.getTable("db2", "tbl1")) + // Get table without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getTableMetadata(TableIdentifier("tbl1")) + == catalog.externalCatalog.getTable("db2", "tbl1")) + } } test("get table when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) + } } } test("get option of table metadata") { - val externalCatalog = newBasicCatalog() - val catalog = new SessionCatalog(externalCatalog) - assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) - == Option(externalCatalog.getTable("db2", "tbl1"))) - assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + withBasicCatalog { catalog => + assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) + == Option(catalog.externalCatalog.getTable("db2", "tbl1"))) + assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) + intercept[NoSuchDatabaseException] { + catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + } } } test("lookup table relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10) - val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If we explicitly specify the database, we'll look up the relation in that database - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) - // Otherwise, we'll first look up a temporary table with the same name - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1, None)) - // Then, if that does not exist, look up the relation in the current database - sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) - assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + withBasicCatalog { catalog => + val tempTable1 = Range(1, 10, 1, 10) + val metastoreTable1 = catalog.externalCatalog.getTable("db2", "tbl1") + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.setCurrentDatabase("db2") + // If we explicitly specify the database, we'll look up the relation in that database + assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + // Otherwise, we'll first look up a temporary table with the same name + 
assert(catalog.lookupRelation(TableIdentifier("tbl1")) + == SubqueryAlias("tbl1", tempTable1)) + // Then, if that does not exist, look up the relation in the current database + catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) + assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head + .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + } } test("look up view relation") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val metadata = externalCatalog.getTable("db3", "view1") - sessionCatalog.setCurrentDatabase("default") - // Look up a view. - assert(metadata.viewText.isDefined) - val view = View(desc = metadata, output = metadata.schema.toAttributes, - child = CatalystSqlParser.parsePlan(metadata.viewText.get)) - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) - // Look up a view using current database of the session catalog. - sessionCatalog.setCurrentDatabase("db3") - comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) + withBasicCatalog { catalog => + val metadata = catalog.externalCatalog.getTable("db3", "view1") + catalog.setCurrentDatabase("default") + // Look up a view. + assert(metadata.viewText.isDefined) + val view = View(desc = metadata, output = metadata.schema.toAttributes, + child = CatalystSqlParser.parsePlan(metadata.viewText.get)) + comparePlans(catalog.lookupRelation(TableIdentifier("view1", Some("db3"))), + SubqueryAlias("view1", view)) + // Look up a view using current database of the session catalog. + catalog.setCurrentDatabase("db3") + comparePlans(catalog.lookupRelation(TableIdentifier("view1")), + SubqueryAlias("view1", view)) + } } test("table exists") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) - assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) - assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) - // If database is explicitly specified, do not check temporary tables - val tempTable = Range(1, 10, 1, 10) - assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) - // If database is not explicitly specified, check the current database - catalog.setCurrentDatabase("db2") - assert(catalog.tableExists(TableIdentifier("tbl1"))) - assert(catalog.tableExists(TableIdentifier("tbl2"))) - - catalog.createTempView("tbl3", tempTable, overrideIfExists = false) - // tableExists should not check temp view. 
- assert(!catalog.tableExists(TableIdentifier("tbl3"))) + withBasicCatalog { catalog => + assert(catalog.tableExists(TableIdentifier("tbl1", Some("db2")))) + assert(catalog.tableExists(TableIdentifier("tbl2", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) + assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) + // If database is explicitly specified, do not check temporary tables + val tempTable = Range(1, 10, 1, 10) + assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) + // If database is not explicitly specified, check the current database + catalog.setCurrentDatabase("db2") + assert(catalog.tableExists(TableIdentifier("tbl1"))) + assert(catalog.tableExists(TableIdentifier("tbl2"))) + + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. + assert(!catalog.tableExists(TableIdentifier("tbl3"))) + } } test("getTempViewOrPermanentTableMetadata on temporary views") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) - }.getMessage + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) + }.getMessage - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage - catalog.createTempView("view1", tempTable, overrideIfExists = false) - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).identifier.table == "view1") - assert(catalog.getTempViewOrPermanentTableMetadata( - TableIdentifier("view1")).schema(0).name == "id") + catalog.createTempView("view1", tempTable, overrideIfExists = false) + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") - intercept[NoSuchTableException] { - catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) - }.getMessage + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage + } } test("list tables without pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1").toSet == - Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) - assert(catalog.listTables("db2").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1").toSet == + Set(TableIdentifier("tbl1"), 
TableIdentifier("tbl4"))) + assert(catalog.listTables("db2").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db") + } } } test("list tables with pattern") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10) - catalog.createTempView("tbl1", tempTable, overrideIfExists = false) - catalog.createTempView("tbl4", tempTable, overrideIfExists = false) - assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) - assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) - assert(catalog.listTables("db2", "tbl*").toSet == - Set(TableIdentifier("tbl1"), - TableIdentifier("tbl4"), - TableIdentifier("tbl1", Some("db2")), - TableIdentifier("tbl2", Some("db2")))) - assert(catalog.listTables("db2", "*1").toSet == - Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) - intercept[NoSuchDatabaseException] { - catalog.listTables("unknown_db", "*") + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) + assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) + assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) + assert(catalog.listTables("db2", "tbl*").toSet == + Set(TableIdentifier("tbl1"), + TableIdentifier("tbl4"), + TableIdentifier("tbl1", Some("db2")), + TableIdentifier("tbl2", Some("db2")))) + assert(catalog.listTables("db2", "*1").toSet == + Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) + intercept[NoSuchDatabaseException] { + catalog.listTables("unknown_db", "*") + } } } @@ -544,451 +632,477 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list partitions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) - sessionCatalog.createPartitions( - TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) - // Create partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createPartitions( - TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) - assert(catalogPartitionsEqual( - externalCatalog.listPartitions("mydb", "tbl"), part1, part2, partWithMixedOrder)) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable(newTable("tbl", "mydb"), ignoreIfExists = false) + catalog.createPartitions( + TableIdentifier("tbl", Some("mydb")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, part2)) + // Create partitions without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createPartitions( + TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("mydb", "tbl"), part1, 
part2, partWithMixedOrder)) + } } test("create partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createPartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) - } - intercept[NoSuchTableException] { - catalog.createPartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createPartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) + } + intercept[NoSuchTableException] { + catalog.createPartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) + } } } test("create partitions that already exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + withBasicCatalog { catalog => + intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + } catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = false) + TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } test("create partitions with invalid part spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithLessColumns), ignoreIfExists = false) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1, partWithMoreColumns), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns, part1), ignoreIfExists = true) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue, part1), ignoreIfExists = true) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithLessColumns), ignoreIfExists = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithMoreColumns), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("drop partitions") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part2)) - // Drop partitions without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), - Seq(part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - // Drop multiple partitions at once - sessionCatalog.createPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) - sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part1.spec, part2.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) - assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) - } - - test("drop partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( - TableIdentifier("tbl1", Some("unknown_db")), - Seq(), + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - intercept[NoSuchTableException] { + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part2)) + // Drop partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), - Seq(), + TableIdentifier("tbl2"), + Seq(part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) - } - } - - test("drop partitions that do not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) + // Drop multiple partitions at once + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) + assert(catalogPartitionsEqual( + catalog.externalCatalog.listPartitions("db2", "tbl2"), part1, part2)) catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), + Seq(part1.spec, part2.spec), ignoreIfNotExists = false, purge = false, retainData = false) + assert(catalog.externalCatalog.listPartitions("db2", "tbl2").isEmpty) } - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(part3.spec), - ignoreIfNotExists = 
true, - purge = false, - retainData = false) } - test("drop partitions with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithMoreColumns.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + test("drop partitions when database/table does not exist") { + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.dropPartitions( + TableIdentifier("tbl1", Some("unknown_db")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + intercept[NoSuchTableException] { + catalog.dropPartitions( + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, b, c) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { + } + + test("drop partitions that do not exist") { + withBasicCatalog { catalog => + intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } catalog.dropPartitions( TableIdentifier("tbl2", Some("db2")), - Seq(partWithUnknownColumns.spec), - ignoreIfNotExists = false, + Seq(part3.spec), + ignoreIfNotExists = true, purge = false, retainData = false) } - assert(e.getMessage.contains( - "Partition spec is invalid. The spec (a, unknown) must be contained within " + - "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), - Seq(partWithEmptyValue.spec, part1.spec), - ignoreIfNotExists = false, - purge = false, - retainData = false) + } + + test("drop partitions with invalid partition spec") { + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithMoreColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, b, c) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, unknown) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithEmptyValue.spec, part1.spec), + ignoreIfNotExists = false, + purge = false, + retainData = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("get partition") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) - // Get partition without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) - // Get non-existent partition - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + withBasicCatalog { catalog => + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part1.spec).spec == part1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), part2.spec).spec == part2.spec) + // Get partition without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec == part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec == part2.spec) + // Get non-existent partition + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), part3.spec) + } } } test("get partition when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) - } - intercept[NoSuchTableException] { - catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) + } + intercept[NoSuchTableException] { + catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) + } } } test("get partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithEmptyValue.spec) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("rename partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) - val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) - val newSpecs = Seq(newPart1.spec, newPart2.spec) - catalog.renamePartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) - assert(catalog.getPartition( - TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - } - // Rename partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) - assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) - assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) - } - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + withBasicCatalog { catalog => + val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) + val newPart2 = part2.copy(spec = Map("a" -> "200", "b" -> "201")) + val newSpecs = Seq(newPart1.spec, newPart2.spec) + catalog.renamePartitions( + TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), newSpecs) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart1.spec).spec === newPart1.spec) + assert(catalog.getPartition( + TableIdentifier("tbl2", Some("db2")), newPart2.spec).spec === newPart2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + } + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + } + // Rename partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.renamePartitions(TableIdentifier("tbl2"), newSpecs, Seq(part1.spec, part2.spec)) + assert(catalog.getPartition(TableIdentifier("tbl2"), part1.spec).spec === part1.spec) + assert(catalog.getPartition(TableIdentifier("tbl2"), part2.spec).spec === part2.spec) + intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart1.spec) + } + 
intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl2"), newPart2.spec) + } } } test("rename partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) - } - intercept[NoSuchTableException] { - catalog.renamePartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) + } + intercept[NoSuchTableException] { + catalog.renamePartitions( + TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) + } } } test("rename partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithLessColumns.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithLessColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.renamePartitions( - TableIdentifier("tbl1", Some("db2")), - Seq(part1.spec), Seq(partWithEmptyValue.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("alter partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val newLocation = newUriForDatabase() - // Alter but keep spec the same - val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( - oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), - oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) - val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) - val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) - assert(newPart1.storage.locationUri == Some(newLocation)) - assert(newPart2.storage.locationUri == Some(newLocation)) - assert(oldPart1.storage.locationUri != Some(newLocation)) - assert(oldPart2.storage.locationUri != Some(newLocation)) - // Alter partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) - val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) - val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) - assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) - assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) - // Alter but change spec, should fail because new partition specs do not exist yet - val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) - val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) - intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + withBasicCatalog { catalog => + val newLocation = newUriForDatabase() + // Alter but keep spec the same + val oldPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val oldPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq( + oldPart1.copy(storage = storageFormat.copy(locationUri = Some(newLocation))), + oldPart2.copy(storage = storageFormat.copy(locationUri = Some(newLocation))))) + val newPart1 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part1.spec) + val newPart2 = catalog.getPartition(TableIdentifier("tbl2", Some("db2")), part2.spec) + assert(newPart1.storage.locationUri == Some(newLocation)) + assert(newPart2.storage.locationUri == Some(newLocation)) + assert(oldPart1.storage.locationUri != Some(newLocation)) + assert(oldPart2.storage.locationUri != Some(newLocation)) + // Alter partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.alterPartitions(TableIdentifier("tbl2"), Seq(oldPart1, oldPart2)) + val newerPart1 = catalog.getPartition(TableIdentifier("tbl2"), part1.spec) + val newerPart2 = catalog.getPartition(TableIdentifier("tbl2"), part2.spec) + assert(oldPart1.storage.locationUri == newerPart1.storage.locationUri) + assert(oldPart2.storage.locationUri == newerPart2.storage.locationUri) + // Alter but change spec, should fail because new partition specs do not exist yet + val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) + val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) + intercept[AnalysisException] { + 
catalog.alterPartitions(TableIdentifier("tbl2", Some("db2")), Seq(badPart1, badPart2)) + } } } test("alter partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) - } - intercept[NoSuchTableException] { - catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) + } + intercept[NoSuchTableException] { + catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) + } } } test("alter partition with invalid partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + - "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) - e = intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithEmptyValue)) - } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partition names") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") - assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == - expectedPartitionNames) - // List partition names without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + withBasicCatalog { catalog => + val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") + assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == + expectedPartitionNames) + // List partition names without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + } } test("list partition names with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert( - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == - Seq("a=1/b=2")) + withBasicCatalog { catalog => + assert( + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == + Seq("a=1/b=2")) + } } test("list partition names with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), - Some(partWithEmptyValue.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) - // List partitions without explicitly specifying database - catalog.setCurrentDatabase("db2") - assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))), part1, part2)) + // List partitions without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) + } } test("list partitions with partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalogPartitionsEqual( - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + withBasicCatalog { catalog => + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + } } test("list partitions with invalid partial partition spec") { - val catalog = new SessionCatalog(newBasicCatalog()) - var e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), - Some(partWithUnknownColumns.spec)) - } - assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + - "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) - e = intercept[AnalysisException] { - catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + withBasicCatalog { catalog => + var e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), + Some(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must be " + + "contained within the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(partWithEmptyValue.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec ([a=3, b=]) contains an " + + "empty partition column value")) } - assert(e.getMessage.contains("Partition spec is invalid. 
The spec ([a=3, b=]) contains an " + - "empty partition column value")) } test("list partitions when database/table does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) - } - intercept[NoSuchTableException] { - catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + } } } @@ -997,8 +1111,17 @@ class SessionCatalogSuite extends PlanTest { expectedParts: CatalogTablePartition*): Boolean = { // ExternalCatalog may set a default location for partitions, here we ignore the partition // location when comparing them. - actualParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet == - expectedParts.map(p => p.copy(storage = p.storage.copy(locationUri = None))).toSet + // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) + // in table's parameters and storage's properties, here we also ignore them. + val actualPartsNormalize = actualParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + val expectedPartsNormalize = expectedParts.map(p => + p.copy(parameters = Map.empty, storage = p.storage.copy( + properties = Map.empty, locationUri = None, serde = None))).toSet + + actualPartsNormalize == expectedPartsNormalize } // -------------------------------------------------------------------------- @@ -1006,194 +1129,280 @@ class SessionCatalogSuite extends PlanTest { // -------------------------------------------------------------------------- test("basic create and list functions") { - val externalCatalog = newEmptyCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) - // Create function without explicitly specifying database - sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) - assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) + // Create function without explicitly specifying database + catalog.setCurrentDatabase("mydb") + catalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) + } } test("create function when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.createFunction( - newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) + } } } test("create function that already exists") { - val catalog = 
new SessionCatalog(newBasicCatalog()) - intercept[FunctionAlreadyExistsException] { - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + withBasicCatalog { catalog => + intercept[FunctionAlreadyExistsException] { + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) + } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } - catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } test("create temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - val info1 = new ExpressionInfo("tempFunc1", "temp1") - val info2 = new ExpressionInfo("tempFunc2", "temp2") - catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) - assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) - // Temporary function does not exist. - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) - } - val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) - val info3 = new ExpressionInfo("tempFunc3", "temp1") - // Temporary function already exists - intercept[TempFunctionAlreadyExistsException] { - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) - } - // Temporary function is overridden - catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) - assert( - catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + withBasicCatalog { catalog => + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction( + newFunc("temp2", None), ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) + assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) + // Temporary function does not exist. 
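+    // The rewritten tests above and below all go through loan-pattern fixtures instead of
+    // building a SessionCatalog inline. The withBasicCatalog/withEmptyCatalog helpers themselves
+    // are not part of this hunk; a minimal sketch of their assumed shape, based on the removed
+    // constructor calls and the catalog.reset() cleanup visible later in this suite. The payoff
+    // is that every test gets a fresh catalog and is cleaned up even when an assertion fails.
+    //
+    //   private def withBasicCatalog(f: SessionCatalog => Unit): Unit = {
+    //     val catalog = new SessionCatalog(newBasicCatalog())
+    //     try f(catalog) finally catalog.reset()
+    //   }
+    //
+    //   private def withEmptyCatalog(f: SessionCatalog => Unit): Unit = {
+    //     val catalog = new SessionCatalog(newEmptyCatalog())
+    //     try f(catalog) finally catalog.reset()
+    //   }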
+ intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) + } + val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + // Temporary function already exists + val e = intercept[AnalysisException] { + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc3)) + }.getMessage + assert(e.contains("Function temp1 already exists")) + // Temporary function is overridden + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = true, functionBuilder = Some(tempFunc3)) + assert( + catalog.lookupFunction( + FunctionIdentifier("temp1"), arguments) === Literal(arguments.length)) + } } test("isTemporaryFunction") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) + withBasicCatalog { catalog => + // Returns false when the function does not exist + assert(!catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - // Returns false when the function does not exist - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction( + newFunc("temp1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) - val tempFunc1 = (e: Seq[Expression]) => e.head - val info1 = new ExpressionInfo("tempFunc1", "temp1") - sessionCatalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + // Returns true when the function is temporary + assert(catalog.isTemporaryFunction(FunctionIdentifier("temp1"))) - // Returns true when the function is temporary - assert(sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1"))) + // Returns false when the function is permanent + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) + catalog.setCurrentDatabase("db2") + assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1"))) - // Returns false when the function is permanent - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2")))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("db2.func1"))) - sessionCatalog.setCurrentDatabase("db2") - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1"))) - - // Returns false when the function is built-in or hive - assert(FunctionRegistry.builtin.functionExists("sum")) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("sum"))) - assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + // Returns false when the function is built-in or hive + assert(FunctionRegistry.builtin.functionExists("sum")) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum"))) + assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))) + } } test("drop function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.dropFunction( - FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) - // Drop function without explicitly specifying database - sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.createFunction(newFunc("func2", 
Some("db2")), ignoreIfExists = false) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) - sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) - assert(externalCatalog.listFunctions("db2", "*").isEmpty) + withBasicCatalog { catalog => + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) + catalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + // Drop function without explicitly specifying database + catalog.setCurrentDatabase("db2") + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) + catalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) + assert(catalog.externalCatalog.listFunctions("db2", "*").isEmpty) + } } test("drop function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.dropFunction( - FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.dropFunction( + FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) + } + intercept[NoSuchFunctionException] { + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) + } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } - intercept[NoSuchFunctionException] { - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) - } - catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } test("drop temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info = new ExpressionInfo("tempFunc", "func1") - val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) - val arguments = Seq(Literal(1), Literal(2), Literal(3)) - assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), arguments) - } - intercept[NoSuchTempFunctionException] { + withBasicCatalog { catalog => + val tempFunc = (e: Seq[Expression]) => e.head + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc)) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), arguments) + } + intercept[NoSuchTempFunctionException] { + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + } + catalog.dropTempFunction("func1", ignoreIfNotExists = true) } - catalog.dropTempFunction("func1", ignoreIfNotExists = true) } test("get function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val expected = - CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, - Seq.empty[FunctionResource]) - assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) - // Get function without explicitly specifying database - 
catalog.setCurrentDatabase("db2") - assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + withBasicCatalog { catalog => + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[FunctionResource]) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) + // Get function without explicitly specifying database + catalog.setCurrentDatabase("db2") + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) + } } test("get function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) - } - intercept[NoSuchFunctionException] { - catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) + } + intercept[NoSuchFunctionException] { + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) + } } } test("lookup temp function") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - assert(catalog.lookupFunction( - FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) - catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[NoSuchFunctionException] { - catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + withBasicCatalog { catalog => + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction( + newFunc("func1", None), ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + assert(catalog.lookupFunction( + FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[NoSuchFunctionException] { + catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) + } } } test("list functions") { - val catalog = new SessionCatalog(newBasicCatalog()) - val info1 = new ExpressionInfo("tempFunc1", "func1") - val info2 = new ExpressionInfo("tempFunc2", "yes_me") - val tempFunc1 = (e: Seq[Expression]) => e.head - val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) - catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) - catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) - assert(catalog.listFunctions("db1", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"))) - assert(catalog.listFunctions("db2", "*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("yes_me"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")), - FunctionIdentifier("not_me", Some("db2")))) - assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == - Set(FunctionIdentifier("func1"), - FunctionIdentifier("func1", Some("db2")), - FunctionIdentifier("func2", Some("db2")))) + withBasicCatalog { catalog => + val funcMeta1 = newFunc("func1", None) + val 
funcMeta2 = newFunc("yes_me", None) + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) + catalog.registerFunction(funcMeta1, ignoreIfExists = false, functionBuilder = Some(tempFunc1)) + catalog.registerFunction(funcMeta2, ignoreIfExists = false, functionBuilder = Some(tempFunc2)) + assert(catalog.listFunctions("db1", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"))) + assert(catalog.listFunctions("db2", "*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("yes_me"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")), + FunctionIdentifier("not_me", Some("db2")))) + assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == + Set(FunctionIdentifier("func1"), + FunctionIdentifier("func1", Some("db2")), + FunctionIdentifier("func2", Some("db2")))) + } } test("list functions when database does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[NoSuchDatabaseException] { - catalog.listFunctions("unknown_db", "func*") + withBasicCatalog { catalog => + intercept[NoSuchDatabaseException] { + catalog.listFunctions("unknown_db", "func*") + } + } + } + + test("copy SessionCatalog state - temp views") { + withEmptyCatalog { original => + val tempTable1 = Range(1, 10, 1, 10) + original.createTempView("copytest1", tempTable1, overrideIfExists = false) + + // check if tables copied over + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + + assert(original ne clone) + assert(clone.getTempView("copytest1") == Some(tempTable1)) + + // check if clone and original independent + clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false) + assert(original.getTempView("copytest1") == Some(tempTable1)) + + val tempTable2 = Range(1, 20, 2, 10) + original.createTempView("copytest2", tempTable2, overrideIfExists = false) + assert(clone.getTempView("copytest2").isEmpty) + } + } + + test("copy SessionCatalog state - current db") { + withEmptyCatalog { original => + val db1 = "db1" + val db2 = "db2" + val db3 = "db3" + + original.externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true) + original.externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true) + + original.setCurrentDatabase(db1) + + // check if current db copied over + val clone = new SessionCatalog(original.externalCatalog) + original.copyStateTo(clone) + + assert(original ne clone) + assert(clone.getCurrentDatabase == db1) + + // check if clone and original independent + clone.setCurrentDatabase(db2) + assert(original.getCurrentDatabase == db1) + original.setCurrentDatabase(db3) + assert(clone.getCurrentDatabase == db2) + } + } + + test("SPARK-19737: detect undefined functions without triggering relation resolution") { + import org.apache.spark.sql.catalyst.dsl.plans._ + + Seq(true, false) foreach { caseSensitive => + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) + try { + val analyzer = new Analyzer(catalog, conf) + + // The analyzer should report the undefined function rather than the undefined table first. 
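+    // The two "copy SessionCatalog state" tests above pin down the contract of copyStateTo:
+    // the clone shares the external catalog but keeps an independent current database and an
+    // independent set of temporary views. A hedged sketch of that contract; `tempViews` is a
+    // hypothetical name for the internal view map, and only copyStateTo, getCurrentDatabase,
+    // setCurrentDatabase and createTempView are taken from this hunk.
+    //
+    //   def copyStateTo(target: SessionCatalog): Unit = {
+    //     target.setCurrentDatabase(getCurrentDatabase)
+    //     tempViews.foreach { case (name, plan) =>
+    //       target.createTempView(name, plan, overrideIfExists = true)
+    //     }
+    //   }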
+ val cause = intercept[AnalysisException] { + analyzer.execute( + UnresolvedRelation(TableIdentifier("undefined_table")).select( + UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + ) + ) + } + + assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + } finally { + catalog.reset() + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 802397d50e85c..630e8a7990e7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,12 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class PrimitiveArrayClass(arr: Array[Long]) + +case class ArrayClass(arr: Seq[StringIntClass]) + +case class NestedArrayClass(nestedArr: Array[ArrayClass]) + class EncoderResolutionSuite extends PlanTest { private val str = UTF8String.fromString("hello") @@ -62,6 +68,75 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(IntegerType)) + val array = new GenericArrayData(Array(1, 2, 3)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("the real type is not compatible with encoder schema: primitive array") { + val encoder = ExpressionEncoder[PrimitiveArrayClass] + val attrs = Seq('arr.array(StringType)) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + s""" + |Cannot up cast array element from string to bigint as it may truncate + |The type path of the target object is: + |- array element class: "scala.Long" + |- field (class: "scala.Array", name: "arr") + |- root class: "org.apache.spark.sql.catalyst.encoders.PrimitiveArrayClass" + |You can either add an explicit cast to the input data or choose a higher precision type + """.stripMargin.trim + " of the field in the target object") + } + + test("real type doesn't match encoder schema but they are compatible: array") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("real type doesn't match encoder schema but they are compatible: nested array") { + val encoder = ExpressionEncoder[NestedArrayClass] + val et = new StructType().add("arr", ArrayType( + new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val attrs = Seq('nestedArr.array(et)) + val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) + val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr)) + } + + test("the real type is not compatible with encoder schema: non-array field") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.int) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + test("the real type is not compatible with encoder schema: array element type") { + val encoder = 
ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("c", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + + test("the real type is not compatible with encoder schema: nested array element type") { + val encoder = ExpressionEncoder[NestedArrayClass] + + withClue("inner element is not array") { + val attrs = Seq('nestedArr.array(new StructType().add("arr", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + withClue("nested array element type is not compatible") { + val attrs = Seq('nestedArr.array(new StructType() + .add("arr", ArrayType(new StructType().add("c", "int"))))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + } + test("nullability of array type element should not fail analysis") { val encoder = ExpressionEncoder[Seq[Int]] val attrs = 'a.array(IntegerType) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 8eccadbdd8afb..a7ffa884d2286 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) assert(cast(1.0, DateType).checkInputDataTypes().isFailure) } + + test("SPARK-20302 cast with same structure") { + val from = new StructType() + .add("a", IntegerType) + .add("b", new StructType().add("b1", LongType)) + + val to = new StructType() + .add("a1", IntegerType) + .add("b1", new StructType().add("b11", LongType)) + + val input = Row(10, Row(12L)) + + checkEvaluation(cast(Literal.create(input, from), to), input) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index abe1d2b2c99e1..5f8a8f44d48e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -251,6 +251,9 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("StringToMap") { + val expectedDataType = MapType(StringType, StringType, valueContainsNull = true) + assert(new StringToMap("").dataType === expectedDataType) + val s0 = Literal("a:1,b:2,c:3") val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s0), m0) @@ -271,6 +274,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3") checkEvaluation(new StringToMap(s4, Literal("_")), m4) + val s5 = Literal("a") + val m5 = Map("a" -> null) + checkEvaluation(new StringToMap(s5), m5) + // arguments checking assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 
9978f35a03810..ca89bf7db0b4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Seconds") { assert(Second(Literal.create(null, DateType), gmtId).resolved === false) - assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true) checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) checkEvaluation(Second(Literal(ts), gmtId), 15) @@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Hour") { assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) - assert(Hour(Literal(ts), None).resolved === true) + assert(Hour(Literal(ts), gmtId).resolved === true) checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) checkEvaluation(Hour(Literal(ts), gmtId), 13) @@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Minute") { assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) - assert(Minute(Literal(ts), None).resolved === true) + assert(Minute(Literal(ts), gmtId).resolved === true) checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation( Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1ba6dd1c5e8ca..b6399edb68dd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val serializer = new JavaSerializer(new SparkConf()).newInstance - val expr: Expression = serializer.deserialize(serializer.serialize(expression)) + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) val catalystValue = 
CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 0cb3a79eee67d..59fc8eaf73d61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -18,18 +18,20 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import java.util.TimeZone import scala.collection.mutable.ArrayBuffer import org.apache.commons.codec.digest.DigestUtils +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types.{ArrayType, StructType, _} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val random = new scala.util.Random @@ -75,7 +77,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { // Note : All expected hashes need to be computed using Hive 1.2.1 val actual = HiveHashFunction.hash(input, dataType, seed = 0) @@ -169,6 +170,208 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on nonascii } + test("hive-hash for date type") { + def checkHiveHashForDateType(dateString: String, expected: Long): Unit = { + checkHiveHash( + DateTimeUtils.stringToDate(UTF8String.fromString(dateString)).get, + DateType, + expected) + } + + // basic case + checkHiveHashForDateType("2017-01-01", 17167) + + // boundary cases + checkHiveHashForDateType("0000-01-01", -719530) + checkHiveHashForDateType("9999-12-31", 2932896) + + // epoch + checkHiveHashForDateType("1970-01-01", 0) + + // before epoch + checkHiveHashForDateType("1800-01-01", -62091) + + // Invalid input: bad date string. Hive returns 0 for such cases + intercept[NoSuchElementException](checkHiveHashForDateType("0-0-0", 0)) + intercept[NoSuchElementException](checkHiveHashForDateType("-1212-01-01", 0)) + intercept[NoSuchElementException](checkHiveHashForDateType("2016-99-99", 0)) + + // Invalid input: Empty string. Hive returns 0 for this case + intercept[NoSuchElementException](checkHiveHashForDateType("", 0)) + + // Invalid input: February 30th for a leap year. 
Hive supports this but Spark doesn't + intercept[NoSuchElementException](checkHiveHashForDateType("2016-02-30", 16861)) + } + + test("hive-hash for timestamp type") { + def checkHiveHashForTimestampType( + timestamp: String, + expected: Long, + timeZone: TimeZone = TimeZone.getTimeZone("UTC")): Unit = { + checkHiveHash( + DateTimeUtils.stringToTimestamp(UTF8String.fromString(timestamp), timeZone).get, + TimestampType, + expected) + } + + // basic case + checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445725271) + + // with higher precision + checkHiveHashForTimestampType("2017-02-24 10:56:29.111111", 1353936655) + + // with different timezone + checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445732471, + TimeZone.getTimeZone("US/Pacific")) + + // boundary cases + checkHiveHashForTimestampType("0001-01-01 00:00:00", 1645926784) + checkHiveHashForTimestampType("9999-01-01 00:00:00", -1081818240) + + // epoch + checkHiveHashForTimestampType("1970-01-01 00:00:00", 0) + + // before epoch + checkHiveHashForTimestampType("1800-01-01 03:12:45", -267420885) + + // Invalid input: bad timestamp string. Hive returns 0 for such cases + intercept[NoSuchElementException](checkHiveHashForTimestampType("0-0-0 0:0:0", 0)) + intercept[NoSuchElementException](checkHiveHashForTimestampType("-99-99-99 99:99:45", 0)) + intercept[NoSuchElementException](checkHiveHashForTimestampType("555555-55555-5555", 0)) + + // Invalid input: Empty string. Hive returns 0 for this case + intercept[NoSuchElementException](checkHiveHashForTimestampType("", 0)) + + // Invalid input: February 30th is a leap year. Hive supports this but Spark doesn't + intercept[NoSuchElementException](checkHiveHashForTimestampType("2016-02-30 00:00:00", 0)) + + // Invalid input: Hive accepts upto 9 decimal place precision but Spark uses upto 6 + intercept[TestFailedException](checkHiveHashForTimestampType("2017-02-24 10:56:29.11111111", 0)) + } + + test("hive-hash for CalendarInterval type") { + def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = { + checkHiveHash(CalendarInterval.fromString(interval), CalendarIntervalType, expected) + } + + // ----- MICROSEC ----- + + // basic case + checkHiveHashForIntervalType("interval 1 microsecond", 24273) + + // negative + checkHiveHashForIntervalType("interval -1 microsecond", 22273) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 microsecond", 23273) + checkHiveHashForIntervalType("interval 999 microsecond", 1022273) + checkHiveHashForIntervalType("interval -999 microsecond", -975727) + + // ----- MILLISEC ----- + + // basic case + checkHiveHashForIntervalType("interval 1 millisecond", 1023273) + + // negative + checkHiveHashForIntervalType("interval -1 millisecond", -976727) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 millisecond", 23273) + checkHiveHashForIntervalType("interval 999 millisecond", 999023273) + checkHiveHashForIntervalType("interval -999 millisecond", -998976727) + + // ----- SECOND ----- + + // basic case + checkHiveHashForIntervalType("interval 1 second", 23310) + + // negative + checkHiveHashForIntervalType("interval -1 second", 23273) + + // edge / boundary cases + checkHiveHashForIntervalType("interval 0 second", 23273) + checkHiveHashForIntervalType("interval 2147483647 second", -2147460412) + checkHiveHashForIntervalType("interval -2147483648 second", -2147460412) + + // Out of range for both Hive and Spark + // Hive throws an exception. 
Spark overflows and returns wrong output
+    // checkHiveHashForIntervalType("interval 9999999999 second", 0)
+
+    // ----- MINUTE -----
+
+    // basic cases
+    checkHiveHashForIntervalType("interval 1 minute", 25493)
+
+    // negative
+    checkHiveHashForIntervalType("interval -1 minute", 25456)
+
+    // edge / boundary cases
+    checkHiveHashForIntervalType("interval 0 minute", 23273)
+    checkHiveHashForIntervalType("interval 2147483647 minute", 21830)
+    checkHiveHashForIntervalType("interval -2147483648 minute", 22163)
+
+    // Out of range for both Hive and Spark
+    // Hive throws an exception. Spark overflows and returns wrong output
+    // checkHiveHashForIntervalType("interval 9999999999 minute", 0)
+
+    // ----- HOUR -----
+
+    // basic case
+    checkHiveHashForIntervalType("interval 1 hour", 156473)
+
+    // negative
+    checkHiveHashForIntervalType("interval -1 hour", 156436)
+
+    // edge / boundary cases
+    checkHiveHashForIntervalType("interval 0 hour", 23273)
+    checkHiveHashForIntervalType("interval 2147483647 hour", -62308)
+    checkHiveHashForIntervalType("interval -2147483648 hour", -43327)
+
+    // Out of range for both Hive and Spark
+    // Hive throws an exception. Spark overflows and returns wrong output
+    // checkHiveHashForIntervalType("interval 9999999999 hour", 0)
+
+    // ----- DAY -----
+
+    // basic cases
+    checkHiveHashForIntervalType("interval 1 day", 3220073)
+
+    // negative
+    checkHiveHashForIntervalType("interval -1 day", 3220036)
+
+    // edge / boundary cases
+    checkHiveHashForIntervalType("interval 0 day", 23273)
+    checkHiveHashForIntervalType("interval 106751991 day", -451506760)
+    checkHiveHashForIntervalType("interval -106751991 day", -451514123)
+
+    // Hive supports `day` for a longer range but Spark's range is smaller
+    // The check for range is done at the parser level so this does not fail in Spark
+    // checkHiveHashForIntervalType("interval -2147483648 day", -1575127)
+    // checkHiveHashForIntervalType("interval 2147483647 day", -4767228)
+
+    // Out of range for both Hive and Spark
+    // Hive throws an exception.
Spark overflows and returns wrong output + // checkHiveHashForIntervalType("interval 9999999999 day", 0) + + // ----- MIX ----- + + checkHiveHashForIntervalType("interval 0 day 0 hour", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second", 23273) + checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second 0 millisecond", 23273) + checkHiveHashForIntervalType( + "interval 0 day 0 hour 0 minute 0 second 0 millisecond 0 microsecond", 23273) + + checkHiveHashForIntervalType("interval 6 day 15 hour", 21202073) + checkHiveHashForIntervalType("interval 5 day 4 hour 8 minute", 16557833) + checkHiveHashForIntervalType("interval -23 day 56 hour -1111113 minute 9898989 second", + -2128468593) + checkHiveHashForIntervalType("interval 66 day 12 hour 39 minute 23 second 987 millisecond", + 1199697904) + checkHiveHashForIntervalType( + "interval 66 day 12 hour 39 minute 23 second 987 millisecond 123 microsecond", 1199820904) + } + test("hive-hash for array") { // empty array checkHiveHash( @@ -371,6 +574,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) + test("hive-hash for decimal") { + def checkHiveHashForDecimal( + input: String, + precision: Int, + scale: Int, + expected: Long): Unit = { + val decimalType = DataTypes.createDecimalType(precision, scale) + val decimal = { + val value = Decimal.apply(new java.math.BigDecimal(input)) + if (value.changePrecision(precision, scale)) value else null + } + + checkHiveHash(decimal, decimalType, expected) + } + + checkHiveHashForDecimal("18", 38, 0, 558) + checkHiveHashForDecimal("-18", 38, 0, -558) + checkHiveHashForDecimal("-18", 38, 12, -558) + checkHiveHashForDecimal("18446744073709001000", 38, 19, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057) + checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057) + checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656) + checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656) + checkHiveHashForDecimal("00000.00000000000", 38, 34, 0) + checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0) + checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666) + 
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666) + } + test("SPARK-18207: Compute hash for a lot of expressions") { val N = 1000 val wideRow = new GenericInternalRow( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index e3584909ddc4a..f892e80204603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,7 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes} +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +39,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { |"fb:testid":"1234"} |""".stripMargin + /* invalid json with leading nulls would trigger java.io.CharConversionException + in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ + val badJson = "\u0000\u0000\u0000A\u0001AAA" + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -224,6 +228,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) } + test("SPARK-16548: character conversion") { + checkEvaluation( + GetJsonObject(Literal(badJson), Literal("$.a")), + null + ) + } + test("non foldable literal") { checkEvaluation( GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), @@ -340,6 +351,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(null, null, null, null, null)) } + test("SPARK-16548: json_tuple - invalid json with leading nulls") { + checkJsonTuple( + JsonTuple(Literal(badJson) :: jsonTupleQuery), + InternalRow(null, null, null, null, null)) + } + test("json_tuple - preserve newlines") { checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), @@ -352,7 +369,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -361,13 +378,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. 
checkEvaluation( - JsonToStruct(schema, Map("mode" -> ParseModes.PERMISSIVE_MODE), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } @@ -376,66 +393,73 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } + test("SPARK-20549: from_json bad UTF-8") { + val schema = StructType(StructField("a", 
IntegerType) :: Nil) + checkEvaluation( + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + null) + } + test("from_json with timestamp") { val schema = StructType(StructField("t", TimestampType) :: Nil) @@ -444,14 +468,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStruct(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -461,7 +485,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), @@ -469,9 +493,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( - JsonToStruct( + JsonToStructs( schema, - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> tz.getID), + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), gmtId), InternalRow(c.getTimeInMillis * 1000L) @@ -482,25 +507,52 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStruct(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } - test("to_json") { + test("to_json - struct") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(create_row(1), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"a":1}""" ) } + test("to_json - array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(1) :: InternalRow(2) :: Nil) + val output = """[{"a":1},{"a":2}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - array with single empty row") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(InternalRow(null) :: Nil) + val output = """[{}]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + + test("to_json - empty array") { + val inputSchema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val input = new GenericArrayData(Nil) + val output = """[]""" + checkEvaluation( + StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId), + output) + } + test("to_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(null, schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), null ) } @@ -513,24 
+565,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create(create_row(c.getTimeInMillis * 1000L), schema) checkEvaluation( - StructToJson(Map.empty, struct, gmtId), + StructsToJson(Map.empty, struct, gmtId), """{"t":"2016-01-01T00:00:00.000Z"}""" ) checkEvaluation( - StructToJson(Map.empty, struct, Option("PST")), + StructsToJson(Map.empty, struct, Option("PST")), """{"t":"2015-12-31T16:00:00.000-08:00"}""" ) checkEvaluation( - StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> gmtId.get), + StructsToJson( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> gmtId.get), struct, gmtId), """{"t":"2016-01-01T00:00:00"}""" ) checkEvaluation( - StructToJson( - Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", "timeZone" -> "PST"), + StructsToJson( + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> "PST"), struct, gmtId), """{"t":"2015-12-31T16:00:00"}""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 5299549e7b4da..1ce150e091981 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -18,16 +18,38 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{IntegerType, StringType} /** * Unit tests for regular expression (regexp) related SQL expressions. */ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("LIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType).like("a"), null) + /** + * Check if a given expression evaluates to an expected output, in case the input is + * a literal and in case the input is in the form of a row. 
+ * @tparam A type of input + * @param mkExpr the expression to test for a given input + * @param input value that will be used to create the expression, as literal and in the form + * of a row + * @param expected the expected output of the expression + * @param inputToExpression an implicit conversion from the input type to its corresponding + * sql expression + */ + def checkLiteralRow[A](mkExpr: Expression => Expression, input: A, expected: Any) + (implicit inputToExpression: A => Expression): Unit = { + checkEvaluation(mkExpr(input), expected) // check literal input + + val regex = 'a.string.at(0) + checkEvaluation(mkExpr(regex), expected, create_row(input)) // check row input + } + + test("LIKE Pattern") { + + // null handling + checkLiteralRow(Literal.create(null, StringType).like(_), "a", null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) checkEvaluation( @@ -39,45 +61,64 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } + // simple patterns + checkLiteralRow("abdef" like _, "abdef", true) + checkLiteralRow("a_%b" like _, "a\\__b", true) + checkLiteralRow("addb" like _, "a_%b", true) + checkLiteralRow("addb" like _, "a\\__b", false) + checkLiteralRow("addb" like _, "a%\\%b", false) + checkLiteralRow("a_%b" like _, "a%\\%b", true) + checkLiteralRow("addb" like _, "a%", true) + checkLiteralRow("addb" like _, "**", false) + checkLiteralRow("abc" like _, "a%", true) + checkLiteralRow("abc" like _, "b%", false) + checkLiteralRow("abc" like _, "bc%", false) + checkLiteralRow("a\nb" like _, "a_b", true) + checkLiteralRow("ab" like _, "a%b", true) + checkLiteralRow("a\nb" like _, "a%b", true) + + // empty input + checkLiteralRow("" like _, "", true) + checkLiteralRow("a" like _, "", false) + checkLiteralRow("" like _, "a", false) + + // SI-17647 double-escaping backslash + checkLiteralRow("""\\\\""" like _, """%\\%""", true) + checkLiteralRow("""%%""" like _, """%%""", true) + checkLiteralRow("""\__""" like _, """\\\__""", true) + checkLiteralRow("""\\\__""" like _, """%\\%\%""", false) + checkLiteralRow("""_\\\%""" like _, """%\\""", false) + + // unicode + // scalastyle:off nonascii + checkLiteralRow("a\u20ACa" like _, "_\u20AC_", true) + checkLiteralRow("a€a" like _, "_€_", true) + checkLiteralRow("a€a" like _, "_\u20AC_", true) + checkLiteralRow("a\u20ACa" like _, "_€_", true) + // scalastyle:on nonascii + + // invalid escaping + val invalidEscape = intercept[AnalysisException] { + evaluate("""a""" like """\a""") + } + assert(invalidEscape.getMessage.contains("pattern")) + + val endEscape = intercept[AnalysisException] { + evaluate("""a""" like """a\""") + } + 
assert(endEscape.getMessage.contains("pattern")) + + // case + checkLiteralRow("A" like _, "a%", false) + checkLiteralRow("a" like _, "A%", false) + checkLiteralRow("AaA" like _, "_a_", true) - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + // example + checkLiteralRow("""%SystemDrive%\Users\John""" like _, """\%SystemDrive\%\\Users%""", true) } - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + test("RLIKE Regular Expression") { + checkLiteralRow(Literal.create(null, StringType) rlike _, "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) @@ -87,42 +128,32 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) + checkLiteralRow("abdef" rlike _, "abdef", true) + checkLiteralRow("abbbbc" rlike _, "a.*c", true) - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - checkEvaluation("pip" rlike "^(pi)*$", false) + checkLiteralRow("fofo" rlike _, "^fo", true) + checkLiteralRow("fo\no" rlike _, "^fo\no$", true) + checkLiteralRow("Bn" rlike _, "^Ba*n", true) + checkLiteralRow("afofo" rlike _, "fo", true) + checkLiteralRow("afofo" rlike _, "^fo", false) + checkLiteralRow("Baan" rlike _, "^Ba?n", false) + checkLiteralRow("axe" rlike _, "pi|apa", false) + checkLiteralRow("pip" rlike _, "^(pi)*$", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) + checkLiteralRow("abc" rlike _, "^ab", true) + checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { evaluate("abbbbc" rlike "**") } - } - - test("RLIKE Non-literal Regular Expression") { - 
val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) + val regex = 'a.string.at(0) + evaluate("abbbbc" rlike regex, create_row("**")) } } - test("RegexReplace") { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 7e45028653e36..13bd363c8b692 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Locale + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.types.{IntegerType, StringType} @@ -32,7 +34,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("better error message for NPE") { val udf = ScalaUDF( - (s: String) => s.toLowerCase, + (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index b45bd977cbba1..e6132ab2e4d17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} class AggregateOptimizeSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a0d489681fd9f..b29e1cbd14943 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,11 +33,11 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val nullableRelation = LocalRelation('a.int.withNullability(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1b9db06014921..c275f997ba6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,6 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.Row class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -34,14 +35,24 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelationWithData = LocalRelation.fromExternalRows( + testRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { + val plan = testRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -138,7 +149,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } - private val caseInsensitiveConf = new SimpleCatalystConf(false) + private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) private val caseInsensitiveAnalyzer = new Analyzer( new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), caseInsensitiveConf) @@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.where('a > 2 
|| ('b > 3 && 'b < 5))) comparePlans(actual, expected) } + + test("Complementation Laws") { + checkCondition('a && !'a, testRelation) + checkCondition(!'a && 'a, testRelation) + + checkCondition('a || !'a, testRelationWithData) + checkCondition(!'a || 'a, testRelationWithData) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 8952c72fe42fe..8cc8decd65de1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -32,47 +32,167 @@ class CollapseRepartitionSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two adjacent coalesces into one") { + // Always respects the top coalesces amd removes useless coalesce below coalesce + val query1 = testRelation + .coalesce(10) + .coalesce(20) + val query2 = testRelation + .coalesce(30) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.coalesce(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + test("collapse two adjacent repartitions into one") { - val query = testRelation + // Always respects the top repartition amd removes useless repartition below repartition + val query1 = testRelation .repartition(10) .repartition(20) + val query2 = testRelation + .repartition(30) + .repartition(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.repartition(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + + test("coalesce above repartition") { + // Remove useless coalesce above repartition + val query1 = testRelation + .repartition(10) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = testRelation.repartition(10).analyze + + comparePlans(optimized1, correctAnswer1) + + // No change in this case + val query2 = testRelation + .repartition(30) + .coalesce(20) + + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) + } + + test("repartition above coalesce") { + // Always respects the top repartition amd removes useless coalesce below repartition + val query1 = testRelation + .coalesce(10) + .repartition(20) + val query2 = testRelation + .coalesce(30) + .repartition(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) val correctAnswer = testRelation.repartition(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } - test("collapse repartition and repartitionBy into one") { - val query = testRelation + test("distribute above repartition") { + // Always respects the top distribute and removes useless repartition + val query1 = testRelation .repartition(10) .distribute('a)(20) + val query2 = testRelation + .repartition(30) + .distribute('a)(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = 
Optimize.execute(query2.analyze) val correctAnswer = testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } - test("collapse repartitionBy and repartition into one") { - val query = testRelation + test("distribute above coalesce") { + // Always respects the top distribute and removes useless coalesce below repartition + val query1 = testRelation + .coalesce(10) + .distribute('a)(20) + val query2 = testRelation + .coalesce(30) .distribute('a)(20) - .repartition(10) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation.distribute('a)(10).analyze + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } - test("collapse two adjacent repartitionBys into one") { - val query = testRelation + test("repartition above distribute") { + // Always respects the top repartition and removes useless distribute below repartition + val query1 = testRelation + .distribute('a)(10) + .repartition(20) + val query2 = testRelation + .distribute('a)(30) + .repartition(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.repartition(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + + test("coalesce above distribute") { + // Remove useless coalesce above distribute + val query1 = testRelation + .distribute('a)(10) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = testRelation.distribute('a)(10).analyze + + comparePlans(optimized1, correctAnswer1) + + // No change in this case + val query2 = testRelation + .distribute('a)(30) + .coalesce(20) + + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) + } + + test("collapse two adjacent distributes into one") { + // Always respects the top distribute + val query1 = testRelation .distribute('b)(10) .distribute('a)(20) + val query2 = testRelation + .distribute('b)(30) + .distribute('a)(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) val correctAnswer = testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala index 3f7d1d9fd99af..52054c2f8bd8d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -78,4 +78,15 @@ class CollapseWindowSuite extends PlanTest { comparePlans(optimized2, correctAnswer2) } + + test("Don't collapse adjacent windows with dependent columns") { + val query = testRelation + .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1) + .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1) + .analyze + + val expected = query.analyze + 
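+ // max('sum_a) consumes the output of the lower window, so CollapseWindow should not
+ // merge the two Window operators and the optimized plan is expected to stay unchanged.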
val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 5bd1bc80c3b8a..589607e3ad5cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -320,16 +320,16 @@ class ColumnPruningSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 276b8055b08d0..ac71887c16f96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +32,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d9655bbcc2ce1..25c592b9c1dde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -34,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", Once, - OptimizeIn(SimpleCatalystConf(true)), + OptimizeIn(conf), ConstantFolding, BooleanSimplification) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index a491f4433370d..cc4fb3a244a98 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -30,7 +29,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + DecimalAggregates(conf)) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 0000000000000..d4f37e2a5e877 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + SimplifyCasts, + EliminateMapObjects) :: Nil + } + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index c5f9cc1852752..e318f36d78270 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -26,9 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class EliminateSortsSuite extends PlanTest { - override val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index a8aeedbd62759..9b6d68aee803a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -46,13 +46,13 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate top level subquery") { val input = LocalRelation('a.int, 'b.int) - val query = SubqueryAlias("a", input, None) + val query = SubqueryAlias("a", input) comparePlans(afterOptimization(query), input) } test("eliminate mid-tree subquery") { val input = LocalRelation('a.int, 'b.int) - val query = Filter(TrueLiteral, SubqueryAlias("a", input, None)) + val query = Filter(TrueLiteral, SubqueryAlias("a", input)) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) @@ -61,7 +61,7 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate multiple subqueries") { val input = LocalRelation('a.int, 'b.int) val query = Filter(TrueLiteral, - SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None)) + SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 6feea4060f46a..950aa2379517e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -134,15 +134,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("nondeterministic: can't push down filter with nondeterministic condition through project") { + test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('rand > 5 || 'a > 5) + .select('a) + .where(Rand(10) > 5 || 'a > 5) .analyze val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, originalQuery) + val correctAnswer = testRelation + .where(Rand(10) > 5 || 'a > 5) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) } test("nondeterministic: can't push down filter through project with nondeterministic field") { @@ -156,6 +161,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { + val originalQuery = testRelation + .groupBy('a)('a, Rand(10).as('rand)) + .where('a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through aggregate with deterministic field") { + val originalQuery = testRelation + .groupBy('a)('a) + .where('a > 5 && Rand(10) > 5) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .where('a > 5) + .groupBy('a)('a) + .where(Rand(10) > 5) + .analyze + + 
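+ // The predicate itself is nondeterministic, but the Project below it only contains the
+ // deterministic column 'a, so pushing the Filter beneath the Project is still safe.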
comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) @@ -208,6 +241,16 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: do not push down non-deterministic filters into join condition") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = x.join(y).where(Rand(10) > 5.0).analyze + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + test("joins: push to one side after transformCondition") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) @@ -836,6 +879,26 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, answer) } + test("SPARK-20094: don't push predicate with IN subquery into join condition") { + val x = testRelation.subquery('x) + val z = testRelation.subquery('z) + val w = testRelation1.subquery('w) + + val queryPlan = x + .join(z) + .where(("x.b".attr === "z.b".attr) && + ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) + .analyze + + val expectedPlan = x + .join(z, Inner, Some("x.b".attr === "z.b".attr)) + .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) + .analyze + + val optimized = Optimize.execute(queryPlan) + comparePlans(optimized, expectedPlan) + } + test("Window: predicate push down -- basic") { val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9f57f66a2ea20..c8fe37462726a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class InferFiltersFromConstraintsSuite extends PlanTest { @@ -31,7 +32,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { Batch("InferAndPushDownFilters", FixedPoint(100), PushPredicateThroughJoin, PushDownPredicate, - InferFiltersFromConstraints, + InferFiltersFromConstraints(conf), + CombineFilters) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + InferFiltersFromConstraints(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), CombineFilters) :: Nil } @@ -201,4 +211,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("No inferred filter when constraint propagation is disabled") { + val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery) + comparePlans(optimized, originalQuery) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 65dd6225cea07..a43d78c7bd447 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor - class JoinOptimizationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -38,7 +37,7 @@ class JoinOptimizationSuite extends PlanTest { CombineFilters, PushDownPredicate, BooleanSimplification, - ReorderJoin, + ReorderJoin(conf), PushPredicateThroughJoin, ColumnPruning, CollapseProject) :: Nil @@ -129,15 +128,15 @@ class JoinOptimizationSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala new file mode 100644 index 0000000000000..71db4e2e0ec4d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED} + + +class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + /** Set up tables and columns for testing */ + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + // Table t1/t4: big table with two columns + private val t1 = StatsTestPlan( + outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), + rowCount = 1000, + // size = rows * (overhead + column length) + size = Some(1000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr), + rowCount = 2000, + size = Some(2000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo))) + + // Table t2/t3: small table with only one column + private val t2 = StatsTestPlan( + outputList = Seq("t2.k-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3.v-1-100").map(nameToAttr), + rowCount = 100, + size = Some(100 * (8 + 4)), + attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + + // Table t5: small table with two columns + private val 
t5 = StatsTestPlan( + outputList = Seq("t5.k-1-5", "t5.v-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + attributeStats = AttributeMap(Seq("t5.k-1-5", "t5.v-1-5").map(nameToColInfo))) + + test("reorder 3 tables") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost = cost(t1 J t2) = 1000 * 20 / 5 = 4000 + // In contrast, the cost of the best plan: + // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000 + // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than + // the original order (t1 J t2) J t3. + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("put unjoinable item at the end and reorder 3 joinable tables") { + // The ReorderJoin rule puts the unjoinable item at the end, and then CostBasedJoinReorder + // reorders other joinable items. + val originalPlan = + t1.join(t2).join(t4).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables with pure-attribute project") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables - one of the leaf items is a project") { + val originalPlan = + t1.join(t5).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + // Items: t1, t3, project(t5.k-1-5, t5) + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t5.select(nameToAttr("t5.k-1-5")), Inner, + Some(nameToAttr("t1.k-1-2") === nameToAttr("t5.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("don't reorder if project contains non-attribute") { + val originalPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10")) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select("key".attr) + + assertEqualPlans(originalPlan, originalPlan) + } + + test("reorder 4 tables (bushy tree)") { + val originalPlan = + t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use 
only cardinality to simplify explanation): + // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000, + // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000 + // In contrast, the cost of the best plan (a bushy tree): + // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000, + // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000. + val bestPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("keep the order of attributes in the final output") { + val outputLists = Seq("t1.k-1-2", "t1.v-1-10", "t3.v-1-100").permutations + while (outputLists.hasNext) { + val expectedOrder = outputLists.next().map(nameToAttr) + val expectedPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(expectedOrder: _*) + // The plan should not change after optimization + assertEqualPlans(expectedPlan, expectedPlan) + } + } + + test("reorder recursively") { + // Original order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t3 + // / \ + // t1 t2 + val bottomJoins = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val originalPlan = bottomJoins + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + // Should be able to reorder the bottom part. 
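+ // Only the three-way join below the Union can be reordered; the Union and the
+ // top-level join with t5 keep their positions in the best plan.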
+ // Best order: + // Join + // / \ + // Union t5 + // / \ + // Join t4 + // / \ + // Join t2 + // / \ + // t1 t3 + val bestBottomPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = bestBottomPlan + .union(t4.select(nameToAttr("t4.v-1-10"))) + .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + private def assertEqualPlans( + originalPlan: LogicalPlan, + groundTruthBestPlan: LogicalPlan): Unit = { + val optimized = Optimize.execute(originalPlan.analyze) + val expected = groundTruthBestPlan.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 0f3ba6c895566..2885fd6841e9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index 4385b0e019f25..f3b65cc797ec4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._ class OptimizeCodegenSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 9daede1a5f957..d8937321ecb98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -25,6 +24,7 @@ 
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD import org.apache.spark.sql.types._ class OptimizeInSuite extends PlanTest { @@ -34,10 +34,10 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), + NullPropagation(conf), ConstantFolding, BooleanSimplification, - OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil + OptimizeIn(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -159,12 +159,11 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), Literal(3)))) .analyze - val notOptimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))(plan) + val notOptimizedPlan = OptimizeIn(conf)(plan) comparePlans(notOptimizedPlan, plan) // Reduce the threshold to turning into InSet. - val optimizedPlan = OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true, - optimizerInSetConversionThreshold = 2))(plan) + val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) optimizedPlan match { case Filter(cond, _) if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index c168a55e40c54..b7136703b7541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class OuterJoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -31,7 +32,16 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - EliminateOuterJoin, + EliminateOuterJoin(conf), + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Outer Join Elimination", Once, + EliminateOuterJoin(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushPredicateThroughJoin) :: Nil } @@ -231,4 +241,21 @@ class OuterJoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("no outer join elimination if constraint propagation is disabled") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + // The predicate "x.b + y.d >= 3" will be inferred constraints like: + // "x.b != null" and "y.d != null", if constraint propagation is enabled. + // When we disable it, the predicate can't be evaluated on left or right plan and used to + // filter out nulls. 
So the Outer Join will not be eliminated. + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where("x.b".attr + "y.d".attr >= 3) + + val optimized = OptimizeWithConstraintPropagationDisabled.execute(originalQuery.analyze) + + comparePlans(optimized, originalQuery.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 908dde7a66988..c261a6091d476 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -33,7 +33,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters, + PruneFilters(conf), PropagateEmptyRelation) :: Nil } @@ -45,7 +45,7 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index d8cfec5391497..741dd0cf428d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED class PruneFiltersSuite extends PlanTest { @@ -33,7 +34,18 @@ class PruneFiltersSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown and Pruning", Once, CombineFilters, - PruneFilters, + PruneFilters(conf), + PushDownPredicate, + PushPredicateThroughJoin) :: Nil + } + + object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueryAliases) :: + Batch("Filter Pushdown and Pruning", Once, + CombineFilters, + PruneFilters(conf.copy(CONSTRAINT_PROPAGATION_ENABLED -> false)), PushDownPredicate, PushPredicateThroughJoin) :: Nil } @@ -133,4 +145,29 @@ class PruneFiltersSuite extends PlanTest { val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze comparePlans(optimized, correctAnswer) } + + test("No pruning when constraint propagation is disabled") { + val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) + val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + + val query = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + + val queryWithUselessFilter = + query.where( + ("tr1.a".attr > 10 || "tr1.c".attr < 10) && + 'd.attr < 100) + + val optimized = + OptimizeWithConstraintPropagationDisabled.execute(queryWithUselessFilter.analyze) + // When constraint propagation is disabled, the useless filter won't be pruned. + // It gets pushed down. 
Because the rule `CombineFilters` runs only once, there are redundant + // and duplicate filters. + val correctAnswer = tr1 + .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) + .join(tr2.where('d.attr < 100).where('d.attr < 100), + Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index c01ea01ec6808..1973b5abb462d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -116,4 +116,12 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper val expected = relation.window(Seq('b), Seq('a), Seq()).analyze comparePlans(optimized, expected) } + + test("do not remove output attributes from a subquery") { + val relation = LocalRelation('a.int, 'b.int) + val query = Subquery(relation.select('a as "a", 'b as "b").where('b < 10).select('a).analyze) + val optimized = Optimize.execute(query) + val expected = Subquery(relation.select('a as "a", 'b).where('b < 10).select('a).analyze) + comparePlans(optimized, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 350a1c26fd1ef..8cb939e010c68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -16,19 +16,20 @@ */ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL} import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { - override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 
21b7f49e14bd5..756e0f35b2178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -34,7 +34,7 @@ class SetOperationSuite extends PlanTest { CombineUnions, PushProjectionThroughUnion, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters(conf)) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala new file mode 100644 index 0000000000000..a23d6266b2840 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ + + +class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy( + CBO_ENABLED -> true, + JOIN_REORDER_ENABLED -> true, + JOIN_REORDER_DP_STAR_FILTER -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 (fact table) + attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + 
nullCount = 0, avgLen = 4, maxLen = 4), + + // D1 (dimension) + attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D2 (dimension) + attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + + // D3 (dimension) + attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + + // T1 (regular table i.e. outside star) + attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T2 (regular table) + attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T3 (regular table) + attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T4 (regular table) + attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T5 (regular table) + attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + + // T6 (regular table) + attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 1, avgLen = 
4, maxLen = 4) + + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2").map(nameToAttr), + rowCount = 1000, + size = Some(1000 * (8 + 4 * 5)), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c1", "f1_c2") + .map(nameToColInfo))) + + // To control the layout of the join plans, keep the size for the non-fact tables constant + // and vary the rowcount and the number of distinct values of the join columns. + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk", "d1_c2", "d1_c3").map(nameToAttr), + rowCount = 100, + size = Some(3000), + attributeStats = AttributeMap(Seq("d1_pk", "d1_c2", "d1_c3").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_pk", "d2_c2", "d2_c3").map(nameToAttr), + rowCount = 20, + size = Some(3000), + attributeStats = AttributeMap(Seq("d2_pk", "d2_c2", "d2_c3").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_pk", "d3_c2", "d3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("d3_pk", "d3_c2", "d3_c3").map(nameToColInfo))) + + private val t1 = StatsTestPlan( + outputList = Seq("t1_c1", "t1_c2", "t1_c3").map(nameToAttr), + rowCount = 50, + size = Some(3000), + attributeStats = AttributeMap(Seq("t1_c1", "t1_c2", "t1_c3").map(nameToColInfo))) + + private val t2 = StatsTestPlan( + outputList = Seq("t2_c1", "t2_c2", "t2_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t2_c1", "t2_c2", "t2_c3").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3_c1", "t3_c2", "t3_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t3_c1", "t3_c2", "t3_c3").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4_c1", "t4_c2", "t4_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t4_c1", "t4_c2", "t4_c3").map(nameToColInfo))) + + private val t5 = StatsTestPlan( + outputList = Seq("t5_c1", "t5_c2", "t5_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t5_c1", "t5_c2", "t5_c3").map(nameToColInfo))) + + private val t6 = StatsTestPlan( + outputList = Seq("t6_c1", "t6_c2", "t6_c3").map(nameToAttr), + rowCount = 10, + size = Some(3000), + attributeStats = AttributeMap(Seq("t6_c1", "t6_c2", "t6_c3").map(nameToColInfo))) + + test("Test 1: Star query with two dimensions and two regular tables") { + + // d1 t1 + // \ / + // f1 + // / \ + // d2 t2 + // + // star: {f1, d1, d2} + // non-star: {t1, t2} + // + // level 0: (t2 ), (d2 ), (f1 ), (d1 ), (t1 ) + // level 1: {f1 d1 }, {d2 f1 } + // level 2: {d2 f1 d1 } + // level 3: {t2 d1 d2 f1 }, {t1 d1 d2 f1 } + // level 4: {f1 t1 t2 d1 d2 } + // + // Number of generated plans: 11 (vs. 
20 w/o filter) + val query = + f1.join(t1).join(t2).join(d1).join(d2) + .where((nameToAttr("f1_c1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star with a linear branch") { + // + // t1 d1 - t2 - t3 + // \ / + // f1 + // | + // d2 + // + // star: {d1, f1, d2} + // non-star: {t2, t1, t3} + // + // level 0: (f1 ), (d2 ), (t3 ), (d1 ), (t1 ), (t2 ) + // level 1: {t3 t2 }, {f1 d2 }, {f1 d1 } + // level 2: {d2 f1 d1 } + // level 3: {t1 d1 f1 d2 }, {t2 d1 f1 d2 } + // level 4: {d1 t2 f1 t1 d2 }, {d1 t3 t2 f1 d2 } + // level 5: {d1 t3 t2 f1 t1 d2 } + // + // Number of generated plans: 15 (vs 24) + val query = + d1.join(t1).join(t2).join(f1).join(d2).join(t3) + .where((nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("t1_c1") === nameToAttr("f1_c1")) && + (nameToAttr("d2_pk") === nameToAttr("f1_fk2")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d1_c2") === nameToAttr("t2_c1")) && + (nameToAttr("t2_c2") === nameToAttr("t3_c1"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star with derived branches") { + // t3 t2 + // | | + // d1 - t4 - t1 + // | + // f1 + // | + // d2 + // + // star: (d1 f1 d2 ) + // non-star: (t4 t1 t2 t3 ) + // + // level 0: (t1 ), (t3 ), (f1 ), (d1 ), (t2 ), (d2 ), (t4 ) + // level 1: {f1 d2 }, {t1 t4 }, {t1 t2 }, {f1 d1 }, {t3 t4 } + // level 2: {d1 f1 d2 }, {t2 t1 t4 }, {t1 t3 t4 } + // level 3: {t4 d1 f1 d2 }, {t3 t4 t1 t2 } + // level 4: {d1 f1 t4 d2 t3 }, {d1 f1 t4 d2 t1 } + // level 5: {d1 f1 t4 d2 t1 t2 }, {d1 f1 t4 d2 t1 t3 } + // level 6: {d1 f1 t4 d2 t1 t2 t3 } + // + // Number of generated plans: 22 (vs. 
34) + val query = + d1.join(t1).join(t2).join(t3).join(t4).join(f1).join(d2) + .where((nameToAttr("t1_c1") === nameToAttr("t2_c1")) && + (nameToAttr("t3_c1") === nameToAttr("t4_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_c2") === nameToAttr("t4_c3")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + + val expected = + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner, + Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) + .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, + Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) + + assertEqualPlans(query, expected) + } + + test("Test 4: Star with several branches") { + // + // d1 - t3 - t4 + // | + // f1 - d3 - t1 - t2 + // | + // d2 - t5 - t6 + // + // star: {d1 f1 d2 d3 } + // non-star: {t5 t3 t6 t2 t4 t1} + // + // level 0: (t4 ), (d2 ), (t5 ), (d3 ), (d1 ), (f1 ), (t2 ), (t6 ), (t1 ), (t3 ) + // level 1: {t5 t6 }, {t4 t3 }, {d3 f1 }, {t2 t1 }, {d2 f1 }, {d1 f1 } + // level 2: {d2 d1 f1 }, {d2 d3 f1 }, {d3 d1 f1 } + // level 3: {d2 d1 d3 f1 } + // level 4: {d1 t3 d3 f1 d2 }, {d1 d3 f1 t1 d2 }, {d1 t5 d3 f1 d2 } + // level 5: {d1 t5 d3 f1 t1 d2 }, {d1 t3 t4 d3 f1 d2 }, {d1 t5 t6 d3 f1 d2 }, + // {d1 t5 t3 d3 f1 d2 }, {d1 t3 d3 f1 t1 d2 }, {d1 t2 d3 f1 t1 d2 } + // level 6: {d1 t5 t3 t4 d3 f1 d2 }, {d1 t3 t2 d3 f1 t1 d2 }, {d1 t5 t6 d3 f1 t1 d2 }, + // {d1 t5 t3 d3 f1 t1 d2 }, {d1 t5 t2 d3 f1 t1 d2 }, ... + // ... + // level 9: {d1 t5 t3 t6 t2 t4 d3 f1 t1 d2 } + // + // Number of generated plans: 46 (vs. 82) + val query = + d1.join(t3).join(t4).join(f1).join(d2).join(t5).join(t6).join(d3).join(t1).join(t2) + .where((nameToAttr("d1_c2") === nameToAttr("t3_c1")) && + (nameToAttr("t3_c2") === nameToAttr("t4_c2")) && + (nameToAttr("d1_pk") === nameToAttr("f1_fk1")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("d2_c2") === nameToAttr("t5_c1")) && + (nameToAttr("t5_c2") === nameToAttr("t6_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk")) && + (nameToAttr("d3_c2") === nameToAttr("t1_c1")) && + (nameToAttr("t1_c2") === nameToAttr("t2_c2"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t3_c2") === nameToAttr("t4_c2"))), Inner, + Some(nameToAttr("d1_c2") === nameToAttr("t3_c1"))) + .join(t2.join(t1, Inner, Some(nameToAttr("t1_c2") === nameToAttr("t2_c2"))), Inner, + Some(nameToAttr("d3_c2") === nameToAttr("t1_c1"))) + .join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner, + Some(nameToAttr("d2_c2") === nameToAttr("t5_c1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: RI star only") { + // d1 + // | + // f1 + // / \ + // d2 d3 + // + // star: {f1, d1, d2, d3} + // non-star: {} + // level 0: (d1), (f1), (d2), (d3) + // level 1: {f1 d3 }, {f1 d2 }, {d1 f1 } + // level 2: {d1 f1 d2 }, {d2 f1 d3 }, {d1 f1 d3 } + // level 3: {d1 d2 f1 d3 } + // Number of generated plans: 11 (= 11) + val query = + d1.join(d2).join(f1).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk")) && + (nameToAttr("f1_fk2") === nameToAttr("d2_pk")) && + (nameToAttr("f1_fk3") 
=== nameToAttr("d3_pk"))) + + val expected = + f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + + assertEqualPlans(query, expected) + } + + test("Test 6: No RI star") { + // + // f1 - t1 - t2 - t3 + // + // star: {} + // non-star: {f1, t1, t2, t3} + // level 0: (t1), (f1), (t2), (t3) + // level 1: {f1 t3 }, {f1 t2 }, {t1 f1 } + // level 2: {t1 f1 t2 }, {t2 f1 t3 }, {t1 f1 t3 } + // level 3: {t1 t2 f1 t3 } + // Number of generated plans: 11 (= 11) + val query = + t1.join(f1).join(t2).join(t3) + .where((nameToAttr("f1_fk1") === nameToAttr("t1_c1")) && + (nameToAttr("f1_fk2") === nameToAttr("t2_c1")) && + (nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + + val expected = + f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1"))) + .join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1"))) + .join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala new file mode 100644 index 0000000000000..605c01b7220d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -0,0 +1,579 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, STARSCHEMA_DETECTION} + +class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, STARSCHEMA_DETECTION -> true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + ReorderJoin(conf), + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: Nil + } + + // Table setup using star schema relationships: + // + // d1 - f1 - d2 + // | + // d3 - s3 + // + // Table f1 is the fact table. Tables d1, d2, and d3 are the dimension tables. + // Dimension d3 is further joined/normalized into table s3. + // Tables' cardinality: f1 > d3 > d1 > d2 > s3 + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + // F1 + attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D1 + attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // D2 + attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 1, avgLen = 4, maxLen = 4), + attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // D3 + attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + // S3 + attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, 
avgLen = 4, maxLen = 4), + attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + // F11 + attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + private val f1 = StatsTestPlan( + outputList = Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f1_fk1", "f1_fk2", "f1_fk3", "f1_c4").map(nameToColInfo))) + + private val d1 = StatsTestPlan( + outputList = Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToAttr), + rowCount = 4, + size = Some(32), + attributeStats = AttributeMap(Seq("d1_pk1", "d1_c2", "d1_c3", "d1_c4").map(nameToColInfo))) + + private val d2 = StatsTestPlan( + outputList = Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToAttr), + rowCount = 3, + size = Some(24), + attributeStats = AttributeMap(Seq("d2_c2", "d2_pk1", "d2_c3", "d2_c4").map(nameToColInfo))) + + private val d3 = StatsTestPlan( + outputList = Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToAttr), + rowCount = 5, + size = Some(40), + attributeStats = AttributeMap(Seq("d3_fk1", "d3_c2", "d3_pk1", "d3_c4").map(nameToColInfo))) + + private val s3 = StatsTestPlan( + outputList = Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToAttr), + rowCount = 2, + size = Some(17), + attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo))) + + private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int) + + private val f11 = StatsTestPlan( + outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr), + rowCount = 6, + size = Some(48), + attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4") + .map(nameToColInfo))) + + private val subq = d3.select(sum('d3_fk1).as('col)) + + test("Test 1: Selective star-join on all dimensions") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, d2, f1, d3, s3 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d2, d1, d3, s3 + val query = + d1.join(d2).join(f1).join(d3).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + 
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 2: Star join on a subset of dimensions due to inequality joins") { + // Star join: + // (=) (<) + // d1 - f1 - d2 + // | + // | (=) + // d3 - s3 + // (=) + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 3: Star join on a subset of dimensions since join column is not unique") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c4 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2"))) + + + assertEqualPlans(query, expected) + } + + test("Test 4: Star join on a subset of dimensions since join column is nullable") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // s3 - d3 + // + // Query: + // select f1_fk1, f1_fk3 + // from d1, f1, d2, s3, d3 + // where f1_fk2 = d2_c2 + // and f1_fk1 = d1_pk1 and d1_c2 = 2 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Default join reordering: d1, f1, d2, d3, s3 + // Star join reordering: f1, d1, d3, d2, s3 + + val query = + d1.join(f1).join(d2).join(s3).join(d3) + .where((nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("d1_c2") === 2) && + (nameToAttr("f1_fk2") === nameToAttr("d2_c2")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + f1.join(d1.where(nameToAttr("d1_c2") === 2), Inner, + Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === 
nameToAttr("d2_c2"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 5: Table stats not available for some of the joined tables") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3_ns - s3 + // + // select f1_fk1, f1_fk3 + // from d3_ns, f1, d1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d3_ns, f1, d1, d2, s3 + // Star join reordering: empty + + val query = + d3_ns.join(f1).join(d1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d3_ns.join(f1, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 6: Join with complex plans") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // (sub-query) + // + // select f1_fk1, f1_fk3 + // from (select sum(d3_fk1) as col from d3) subq, f1, d1, d2 + // where f1_fk2 = d2_pk1 and d2_c2 < 2 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = sq.col + // + // Positional join reordering: d3, f1, d1, d2 + // Star join reordering: empty + + val query = + subq.join(f1).join(d1).join(d2) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === "col".attr)) + + val expected = + d3.select('d3_fk1).select(sum('d3_fk1).as('col)) + .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr)) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 7: Comparable fact table sizes") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // f11 - s3 + // + // select f1.f1_fk1, f1.f1_fk3 + // from d1, f11, f1, d2, s3 + // where f1.f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1.f1_fk1 = d1_pk1 + // and f1.f1_fk3 = f11.f1_fk3 + // and f11.f1_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, f11, d2, s3 + // Star join reordering: empty + + val query = + d1.join(f11).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("f11_fk3")) && + (nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + val equivQuery = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(f11, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("f11_fk3"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, equivQuery) + } + + test("Test 8: No RI joins") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_c4 
and d2_c2 = 2 + // and f1_fk1 = d1_c4 + // and f1_fk3 = d3_c4 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_c4")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") === nameToAttr("d1_c4")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_c4")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_c4"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_c4"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 9: Complex join predicates") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and abs(f1_fk1) = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional/default join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(abs(nameToAttr("f1_fk1")) === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), Inner, + Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 10: Less than two dimensions") { + // Star join: + // (<) (=) + // d1 - f1 - d2 + // |(<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 and d2_c2 = 2 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("d2_c2") === 2) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2.where(nameToAttr("d2_c2") === 2), + Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 11: Expanding star join") { + // Star join: + // (<) (<) + // d1 - f1 - d2 + // | (<) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 < d2_pk1 + // and f1_fk1 < d1_pk1 + // and f1_fk3 < d3_pk1 + // and d3_fk1 < s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") < nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") < nameToAttr("d1_pk1")) && + 
(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") < nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + test("Test 12: Non selective star join") { + // Star join: + // (=) (=) + // d1 - f1 - d2 + // | (=) + // d3 - s3 + // + // select f1_fk1, f1_fk3 + // from d1, d3, f1, d2, s3 + // where f1_fk2 = d2_pk1 + // and f1_fk1 = d1_pk1 + // and f1_fk3 = d3_pk1 + // and d3_fk1 = s3_pk1 + // + // Positional join reordering: d1, f1, d3, d2, s3 + // Star join reordering: empty + + val query = + d1.join(d3).join(f1).join(d2).join(s3) + .where((nameToAttr("f1_fk2") === nameToAttr("d2_pk1")) && + (nameToAttr("f1_fk1") === nameToAttr("d1_pk1")) && + (nameToAttr("f1_fk3") === nameToAttr("d3_pk1")) && + (nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + val expected = + d1.join(f1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) + .join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1"))) + .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1"))) + .join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1"))) + + assertEqualPlans(query, expected) + } + + private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = { + val optimized = Optimize.execute(plan1.analyze) + val expected = plan2.analyze + compareJoinOrder(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 3964fa3924b24..4490523369006 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite { } } - def intercept(sql: String): Unit = + def intercept(sql: String): ParseException = intercept[ParseException](CatalystSqlParser.parseDataType(sql)) def unsupported(dataTypeString: String): Unit = { @@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("struct") + test("Do not print empty parentheses for no params") { + assert(intercept("unkwon").getMessage.contains("unkwon is not supported")) + assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported")) + } + // DataType parser accepts certain reserved keywords. 
checkDataType( "Struct", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2fecb8dc4a60e..eb68eb9851b85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -21,12 +21,14 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** - * Test basic expression parsing. If a type of expression is supported it should be tested here. + * Test basic expression parsing. + * If the type of an expression is supported it should be tested here. * * Please note that some of the expressions test don't have to be sound expressions, only their * structure needs to be valid. Unsound expressions should be caught by the Analyzer or @@ -165,6 +167,11 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a = b is not null", ('a === 'b).isNotNull) } + test("is distinct expressions") { + assertEqual("a is distinct from b", !('a <=> 'b)) + assertEqual("a is not distinct from b", 'a <=> 'b) + } + test("binary arithmetic expressions") { // Simple operations assertEqual("a * b", 'a * 'b) @@ -209,6 +216,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) } test("window function expressions") { @@ -278,6 +286,7 @@ class ExpressionParserSuite extends PlanTest { // Note that '(a)' will be interpreted as a nested expression. 
assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + assertEqual("(a as b, b as c)", CreateStruct(Seq('a as 'b, 'b as 'c))) } test("scalar sub-query") { @@ -546,4 +555,11 @@ class ExpressionParserSuite extends PlanTest { val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) } + + test("SPARK-19526 Support ignore nulls keywords for first and last") { + assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) + assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) + assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 67d5d2202b680..411777d6e85a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -79,7 +79,7 @@ class PlanParserSuite extends PlanTest { def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { val ctes = namedPlans.map { case (name, cte) => - name -> SubqueryAlias(name, cte, None) + name -> SubqueryAlias(name, cte) } With(plan, ctes) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala new file mode 100644 index 0000000000000..da1041d617086 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -0,0 +1,88 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class TableSchemaParserSuite extends SparkFunSuite { + + def parse(sql: String): StructType = CatalystSqlParser.parseTableSchema(sql) + + def checkTableSchema(tableSchemaString: String, expectedDataType: DataType): Unit = { + test(s"parse $tableSchemaString") { + assert(parse(tableSchemaString) === expectedDataType) + } + } + + def assertError(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) + + checkTableSchema("a int", new StructType().add("a", "int")) + checkTableSchema("A int", new StructType().add("A", "int")) + checkTableSchema("a INT", new StructType().add("a", "int")) + checkTableSchema("`!@#$%.^&*()` string", new StructType().add("!@#$%.^&*()", "string")) + checkTableSchema("a int, b long", new StructType().add("a", "int").add("b", "long")) + checkTableSchema("a STRUCT", + StructType( + StructField("a", StructType( + StructField("intType", IntegerType) :: + StructField("ts", TimestampType) :: Nil)) :: Nil)) + checkTableSchema( + "a int comment 'test'", + new StructType().add("a", "int", nullable = true, "test")) + + test("complex hive type") { + val tableSchemaString = + """ + |complexStructCol struct< + |struct:struct, + |MAP:Map, + |arrAy:Array, + |anotherArray:Array> + """.stripMargin.replace("\n", "") + + val builder = new MetadataBuilder + builder.putString(HIVE_TYPE_STRING, + "struct," + + "MAP:map,arrAy:array,anotherArray:array>") + + val expectedDataType = + StructType( + StructField("complexStructCol", StructType( + StructField("struct", + StructType( + StructField("deciMal", DecimalType.USER_DEFAULT) :: + StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: + StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("arrAy", ArrayType(DoubleType)) :: + StructField("anotherArray", ArrayType(StringType)) :: Nil), + nullable = true, + builder.build()) :: Nil) + + assert(parse(tableSchemaString) === expectedDataType) + } + + // Negative cases + assertError("") + assertError("a") + assertError("a INT b long") + assertError("a INT,, b long") + assertError("a INT, b long,,") + assertError("a INT, b long, c int,") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 908b370408280..4061394b862a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -397,4 +397,22 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "c"))))) } + + test("enable/disable constraint propagation") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + val filterRelation = tr.where('a.attr > 10) + + verifyConstraints( + filterRelation.analyze.getConstraints(constraintPropagationEnabled = true), + filterRelation.analyze.constraints) + + assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + + verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true), + aliasedRelation.analyze.constraints) + 
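// With propagation disabled, no constraints should be inferred for the aggregated and aliased plan either. +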
assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 3b7e5e938a8e4..f44428c3512a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ abstract class PlanTest extends SparkFunSuite with PredicateHelper { - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) + protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true) /** * Since attribute references are given globally unique ids during analysis, @@ -43,8 +43,6 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { e.copy(exprId = ExprId(0)) case l: ListQuery => l.copy(exprId = ExprId(0)) - case p: PredicateSubquery => - p.copy(exprId = ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => @@ -62,7 +60,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * - Sample the seed will replaced by 0L. * - Join conditions will be resorted by hashCode. */ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) @@ -108,4 +106,30 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { protected def compareExpressions(e1: Expression, e2: Expression): Unit = { comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) } + + /** Fails the test if the join order in the two plans do not match */ + protected def compareJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. 
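Two plans have the same join order if their join trees match after possibly swapping the left and right children at each join. 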
*/ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case (p1: Project, p2: Project) => + p1.projectList == p2.projectList && sameJoinPlan(p1.child, p2.child) + case _ => + plan1 == plan2 + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index c0b9515ca7cd0..38483a298cef0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.internal.SQLConf class AggregateEstimationSuite extends StatsEstimationTestBase { @@ -101,13 +102,13 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { val noGroupAgg = Aggregate(groupingExpressions = Nil, aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) - assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // overhead + count result size Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) val hasGroupAgg = Aggregate(groupingExpressions = attributes, aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) - assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) == + assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala new file mode 100644 index 0000000000000..b06871f96f0d8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.IntegerType + + +class BasicStatsEstimationSuite extends StatsEstimationTestBase { + val attribute = attr("key") + val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val plan = StatsTestPlan( + outputList = Seq(attribute), + attributeStats = AttributeMap(Seq(attribute -> colStat)), + rowCount = 10, + // row count * (overhead + column size) + size = Some(10 * (8 + 4))) + + test("BroadcastHint estimation") { + val filter = Filter(Literal(true), plan) + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false) + checkStats( + filter, + expectedStatsCboOn = filterStatsCboOn, + expectedStatsCboOff = filterStatsCboOff) + + val broadcastHint = BroadcastHint(filter) + checkStats( + broadcastHint, + expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), + expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + } + + test("limit estimation: limit < child's rowCount") { + val localLimit = LocalLimit(Literal(2), plan) + val globalLimit = GlobalLimit(Literal(2), plan) + // LocalLimit's stats is just its child's stats except column stats + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) + } + + test("limit estimation: limit > child's rowCount") { + val localLimit = LocalLimit(Literal(20), plan) + val globalLimit = GlobalLimit(Literal(20), plan) + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. 
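+ // Only the per-column stats are dropped; sizeInBytes and rowCount stay the same as the child's.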
+ checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + } + + test("limit estimation: limit = 0") { + val localLimit = LocalLimit(Literal(0), plan) + val globalLimit = GlobalLimit(Literal(0), plan) + val stats = Statistics(sizeInBytes = 1, rowCount = Some(0)) + checkStats(localLimit, stats) + checkStats(globalLimit, stats) + } + + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), + isBroadcastable = false) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), + isBroadcastable = false) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + + /** Check estimated stats when cbo is turned on/off. */ + private def checkStats( + plan: LogicalPlan, + expectedStatsCboOn: Statistics, + expectedStatsCboOff: Statistics): Unit = { + // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn) + + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff) + } + + /** Check estimated stats when it's the same whether cbo is turned on or off. */ + private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = + checkStats(plan, expectedStats, expectedStats) +} + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. 
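+ * computeStats returns cboStats when CBO is enabled and defaultStats otherwise.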
+ */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) extends LogicalPlan { + override def output: Seq[Attribute] = Nil + override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: SQLConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala old mode 100644 new mode 100755 index 8be74ced7bb71..a28447840ae09 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,8 +20,11 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -31,356 +34,595 @@ import org.apache.spark.sql.types._ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. - // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 + // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val arInt = AttributeReference("cint", IntegerType)() - val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val attrInt = AttributeReference("cint", IntegerType)() + val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) - // only 2 values - val arBool = AttributeReference("cbool", BooleanType)() - val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + // column cbool has only 2 distinct values + val attrBool = AttributeReference("cbool", BooleanType)() + val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) - // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. - val dMin = Date.valueOf("2017-01-01") - val dMax = Date.valueOf("2017-01-10") - val arDate = AttributeReference("cdate", DateType)() - val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + // column cdate has 10 values from 2017-01-01 through 2017-01-10. + val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) + val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) + val attrDate = AttributeReference("cdate", DateType)() + val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) - // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
- val decMin = new java.math.BigDecimal("0.200000000000000000") - val decMax = new java.math.BigDecimal("0.800000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. + val decMin = Decimal("0.200000000000000000") + val decMax = Decimal("0.800000000000000000") + val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) - // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", DoubleType)() - val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 + val attrDouble = AttributeReference("cdouble", DoubleType)() + val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) - // Sixth column cstring has 10 String values: + // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" - val arString = AttributeReference("cstring", StringType)() - val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + val attrString = AttributeReference("cstring", StringType)() + val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint < cint2 + val attrInt2 = AttributeReference("cint2", IntegerType)() + val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test "cint = cint3 without overlap at all. 
+ val attrInt3 = AttributeReference("cint3", IntegerType)() + val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4) + + // column cint4 has values in the range from 1 to 10 + // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 + // This column is created to test complete overlap + val attrInt4 = AttributeReference("cint4", IntegerType)() + val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val attributeMap = AttributeMap(Seq( + attrInt -> colStatInt, + attrBool -> colStatBool, + attrDate -> colStatDate, + attrDecimal -> colStatDecimal, + attrDouble -> colStatDouble, + attrString -> colStatString, + attrInt2 -> colStatInt2, + attrInt3 -> colStatInt3, + attrInt4 -> colStatInt4 + )) + + test("true") { + validateEstimatedStats( + Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 10) + } + + test("false") { + validateEstimatedStats( + Filter(FalseLiteral, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("null") { + validateEstimatedStats( + Filter(Literal(null, IntegerType), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(null)") { + validateEstimatedStats( + Filter(Not(Literal(null, IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(Not(null))") { + validateEstimatedStats( + Filter(Not(Not(Literal(null, IntegerType))), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 AND null") { + val condition = And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("cint < 3 OR null") { + val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 3) + } + + test("Not(cint < 3 AND null)") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + + test("Not(cint < 3 OR null)") { + val condition = Not(Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) + } + + test("Not(cint < 3 AND Not(null))") { + val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 8) + } + test("cint = 2") { validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( - arInt, - Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - 
ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint = 0") { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint < 3") { validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint < 0") { // This is a corner case since literal 0 is smaller than min. validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint <= 3") { validateEstimatedStats( - arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint > 10") { // This is a corner case since max value is 10. 
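The expected row counts in the range-predicate tests above are consistent with a linear interpolation over the column's [min, max] range, rounded up. This is only a reading of the expectations, not the estimator's actual code (the real logic lives in FilterEstimation), shown here as a quick sanity check:

    // Rough arithmetic implied by the expectations above; not part of the patch.
    def interpolatedRows(lo: Double, hi: Double, min: Double, max: Double, rows: Long): Long =
      math.ceil((hi - lo) / (max - min) * rows).toLong

    interpolatedRows(1, 3, 1, 10, 10)   // "cint < 3"               => 3
    interpolatedRows(6, 10, 1, 10, 10)  // "cint > 6"               => 5
    interpolatedRows(3, 6, 1, 10, 10)   // "cint > 3 AND cint <= 6" => 4
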
validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint >= 6") { validateEstimatedStats( - arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint IS NULL") { validateEstimatedStats( - arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint IS NOT NULL") { validateEstimatedStats( - arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 10) + Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint IS NOT NULL && null") { + // 'cint < null' will be optimized to 'cint IS NOT NULL && null'. + // More similar cases can be found in the Optimizer NullPropagation. 
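The zero-row expectation in this test follows from SQL's three-valued logic: comparing with NULL yields NULL, and a NULL filter condition keeps no rows. A minimal illustration, assuming a running SparkSession named spark (a sketch, not part of the patch):

    // `id < null` evaluates to NULL for every row, so nothing passes the filter.
    spark.range(10).where("id < null").count()   // 0
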
+ val condition = And(IsNotNull(attrInt), Literal(null, IntegerType)) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint > 3 AND cint <= 6") { - val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - 4) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) } test("cint = 3 OR cint = 6") { - val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + test("Not(cint > 3 AND cint <= 6)") { + val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 6) + } + + test("Not(cint <= 3 OR cint > 6)") { + val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 5) + } + + test("Not(cint = 3 AND cstring < 'A8')") { + val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 10) + } + + test("Not(cint = 3 OR cstring < 'A8')") { + val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 7) + Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen 
= 4)), + expectedRowCount = 7) + } + + test("cbool IN (true)") { + validateEstimatedStats( + Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( - arBool, - Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 5) + Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool > false") { - // bool comparison is not supported yet, so stats remain same. validateEstimatedStats( - arBool, - Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 10) + Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cdate = cast('2017-01-02' AS DATE)") { - val d20170102 = Date.valueOf("2017-01-02") + val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02")) validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrDate, Literal(d20170102, DateType)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { - val d20170103 = Date.valueOf("2017-01-03") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal(d20170103)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrDate, Literal(d20170103, DateType)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { - val d20170103 = Date.valueOf("2017-01-03") - val d20170104 = Date.valueOf("2017-01-04") - val d20170105 = Date.valueOf("2017-01-05") - validateEstimatedStats( - arDate, - Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) + val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) + val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) + validateEstimatedStats( + Filter(In(attrDate, Seq(Literal(d20170103, DateType), 
Literal(d20170104, DateType), + Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { - val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + val dec_0_40 = Decimal("0.400000000000000000") validateEstimatedStats( - arDecimal, - Filter(EqualTo(arDecimal, Literal(dec_0_40)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8), - 1) + Filter(EqualTo(attrDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 1) } test("cdecimal < 0.60 ") { - val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( - arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( - arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), + Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( - arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 1) + Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 1) } - // There is no min/max statistics for String type. We estimate 10 rows returned. - test("cstring < 'A2'") { + test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( - arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 10) + Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 10) } - // This is a corner test case. We want to test if we can handle the case when the number of - // valid values in IN clause is greater than the number of distinct values for a given column. - // For example, column has only 2 distinct values 1 and 6. - // The predicate is: column IN (1, 2, 3, 4, 5). test("cint IN (1, 2, 3, 4, 5)") { + // This is a corner test case. 
We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4) val cornerChildStatsTestplan = StatsTestPlan( - outputList = Seq(arInt), + outputList = Seq(attrInt), rowCount = 2L, - attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + // This is a limitation test. We should remove it after the limitation is removed. + test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { + val attrIntLargerRange = AttributeReference("c1", IntegerType)() + val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), + nullCount = 10, avgLen = 4, maxLen = 4) + val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) + val largerTable = StatsTestPlan( + outputList = Seq(attrIntLargerRange), + rowCount = 30, + attributeStats = AttributeMap(Seq(attrIntLargerRange -> colStatIntLargerRange))) + val nonLeafChild = Join(largerTable, smallerTable, LeftOuter, + Some(EqualTo(attrIntLargerRange, attrInt))) + + Seq(IsNull(attrIntLargerRange), IsNotNull(attrIntLargerRange)).foreach { predicate => + validateEstimatedStats( + Filter(predicate, nonLeafChild), + // column stats don't change + Seq(attrInt -> colStatInt, attrIntLargerRange -> colStatIntLargerRange), + expectedRowCount = 30) + } + } + + test("cint = cint2") { + // partial overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint > cint2") { + // partial overlap case + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint < cint2") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint4") { + // complete overlap case + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 
10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - 2) + attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint < cint4") { + // partial overlap case + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 4) + } + + test("cint = cint3") { + // no records qualify due to no overlap + validateEstimatedStats( + Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) + } + + test("cint < cint3") { + // all table records qualify. + validateEstimatedStats( + Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) + } + + test("cint > cint3") { + // no records qualify due to no overlap + validateEstimatedStats( + Filter(GreaterThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), + Nil, // set to empty + expectedRowCount = 0) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = tableRowCount, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arBool -> childColStatBool, - arDate -> childColStatDate, - arDecimal -> childColStatDecimal, - arDouble -> childColStatDouble, - arString -> childColStatString - )) - ) + attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a)))) } private def validateEstimatedStats( - ar: AttributeReference, filterNode: Filter, - expectedColStats: ColumnStat, - rowCount: Int): Unit = { - - val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) - - val filteredStats = filterNode.stats(conf) - assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount.get == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) - - // If the filter has a binary operator (including those nested inside - // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the - // operator, and then check again. - val rewrittenFilter = filterNode transformExpressionsDown { - case EqualTo(ar: AttributeReference, l: Literal) => - EqualTo(l, ar) - - case LessThan(ar: AttributeReference, l: Literal) => - GreaterThan(l, ar) - case LessThanOrEqual(ar: AttributeReference, l: Literal) => - GreaterThanOrEqual(l, ar) - - case GreaterThan(ar: AttributeReference, l: Literal) => - LessThan(l, ar) - case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - LessThanOrEqual(l, ar) + expectedColStats: Seq[(Attribute, ColumnStat)], + expectedRowCount: Int): Unit = { + + // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the + // sides of the attribute and the literal, reverse the operator, and then check again. 
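Concretely, the transform below turns each attribute-vs-literal comparison into its mirrored form, and the helper then checks that both shapes produce identical estimates. An illustrative pair built from the suite's own helpers (not additional test code):

    val original = Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L))
    val mirrored = Filter(GreaterThan(Literal(3), attrInt), childStatsTestPlan(Seq(attrInt), 10L))
    // Both represent "cint < 3" and must yield the same row count and column stats.
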
+ val swappedFilter = filterNode transformExpressionsDown { + case EqualTo(attr: Attribute, l: Literal) => + EqualTo(l, attr) + + case LessThan(attr: Attribute, l: Literal) => + GreaterThan(l, attr) + case LessThanOrEqual(attr: Attribute, l: Literal) => + GreaterThanOrEqual(l, attr) + + case GreaterThan(attr: Attribute, l: Literal) => + LessThan(l, attr) + case GreaterThanOrEqual(attr: Attribute, l: Literal) => + LessThanOrEqual(l, attr) } - if (rewrittenFilter != filterNode) { - validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + val testFilters = if (swappedFilter != filterNode) { + Seq(swappedFilter, filterNode) + } else { + Seq(filterNode) + } + + testFilters.foreach { filter => + val expectedAttributeMap = AttributeMap(expectedColStats) + val expectedStats = Statistics( + sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttributeMap) + + val filterStats = filter.stats(conf) + assert(filterStats.sizeInBytes == expectedStats.sizeInBytes) + assert(filterStats.rowCount == expectedStats.rowCount) + val rowCountValue = filterStats.rowCount.getOrElse(0) + // check the output column stats if the row count is > 0. + // When row count is 0, the output is set to empty. + if (rowCountValue != 0) { + // Need to check attributeStats one by one because we may have multiple output columns. + // Due to update operation, the output columns may be in different order. + assert(expectedColStats.size == filterStats.attributeStats.size) + expectedColStats.foreach { kv => + val filterColumnStat = filterStats.attributeStats.get(kv._1).get + assert(filterColumnStat == kv._2) + } + } } } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index f62df842fa50a..2d6b6e8e21f34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("test join keys of different types") { /** Columns in a table with only one row */ def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = { - val dec = new java.math.BigDecimal("1.000000000000000000") - val date = Date.valueOf("2016-05-08") - val timestamp = Timestamp.valueOf("2016-05-08 00:00:01") + val dec = Decimal("1.000000000000000000") + val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, 
avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index f408dc4153586..a5c4d22a29386 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("test row size estimation") { - val dec1 = new java.math.BigDecimal("1.000000000000000000") - val dec2 = new java.math.BigDecimal("8.000000000000000000") - val d1 = Date.valueOf("2016-05-08") - val d2 = Date.valueOf("2016-05-09") - val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + val dec1 = Decimal("1.000000000000000000") + val dec2 = Decimal("8.000000000000000000") + val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09")) + val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), 
max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala deleted file mode 100644 index 212d57a9bcf95..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.statsEstimation - -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType - - -class StatsConfSuite extends StatsEstimationTestBase { - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - // Return the statistics estimated by cbo - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) - // Invalidate statistics - plan.invalidateStatsCache() - // Return the simple statistics - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) - } -} - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index c56b41ce37636..263f4e18803d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -18,16 +18,17 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED} import org.apache.spark.sql.types.{IntegerType, StringType} -class StatsEstimationTestBase extends SparkFunSuite { +trait StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. */ - protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) + protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true) def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes @@ -48,13 +49,13 @@ class StatsEstimationTestBase extends SparkFunSuite { /** * This class is used for unit-testing. It's a logical plan whose output and stats are passed in. 
*/ -protected case class StatsTestPlan( +case class StatsTestPlan( outputList: Seq[Attribute], rowCount: BigInt, attributeStats: AttributeMap[ColumnStat], size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList - override def computeStats(conf: CatalystConf): Statistics = Statistics( + override def computeStats(conf: SQLConf): Statistics = Statistics( // If sizeInBytes is useless in testing, we just use a fake value sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala new file mode 100644 index 0000000000000..3159b541dca79 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.streaming + +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.streaming.OutputMode + +class InternalOutputModesSuite extends SparkFunSuite { + + test("supported strings") { + def testMode(outputMode: String, expected: OutputMode): Unit = { + assert(InternalOutputModes(outputMode) === expected) + } + + testMode("append", OutputMode.Append) + testMode("Append", OutputMode.Append) + testMode("complete", OutputMode.Complete) + testMode("Complete", OutputMode.Complete) + testMode("update", OutputMode.Update) + testMode("Update", OutputMode.Update) + } + + test("unsupported strings") { + def testMode(outputMode: String): Unit = { + val acceptedModes = Seq("append", "update", "complete") + val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode)) + (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + testMode("Xyz") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index af1eaa1f23746..37e3dfabd0b21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -491,7 +491,8 @@ class TreeNodeSuite extends SparkFunSuite { "lastAccessTime" -> -1, "tracksPartitionsInCatalog" -> false, "properties" -> JNull, - "unsupportedFeatures" -> List.empty[String])) + "unsupportedFeatures" -> List.empty[String], + "schemaPreservesCase" -> JBool(true))) // For unknown case class, returns JNull. 
val bigValue = new Array[Int](10000) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index 5e90970b1bb2e..df579d5ec1ddf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -55,15 +55,19 @@ class QuantileSummariesSuite extends SparkFunSuite { } private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { - val approx = summary.query(quant) - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact - val lower = math.floor((quant - summary.relativeError) * data.size) - val upper = math.ceil((quant + summary.relativeError) * data.size) - val msg = - s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" - assert(rank >= lower, msg) - assert(rank <= upper, msg) + if (data.nonEmpty) { + val approx = summary.query(quant).get + // The rank of the approximation. + val rank = data.count(_ < approx) // has to be <, not <= to be exact + val lower = math.floor((quant - summary.relativeError) * data.size) + val upper = math.ceil((quant + summary.relativeError) * data.size) + val msg = + s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx" + assert(rank >= lower, msg) + assert(rank <= upper, msg) + } else { + assert(summary.query(quant).isEmpty) + } } for { @@ -74,9 +78,9 @@ class QuantileSummariesSuite extends SparkFunSuite { test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { val s = buildSummary(data, epsi, compression) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") } @@ -100,6 +104,18 @@ class QuantileSummariesSuite extends SparkFunSuite { checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) } + + test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") { + val emptyData = Seq.empty[Double] + val s = buildSummary(emptyData, epsi, compression) + assert(s.count == 0, s"Found count=${s.count} but data size=0") + assert(s.sampled.isEmpty, s"if QuantileSummaries is empty, sampled should be empty") + checkQuantile(0.9999, emptyData, s) + checkQuantile(0.9, emptyData, s) + checkQuantile(0.5, emptyData, s) + checkQuantile(0.1, emptyData, s) + checkQuantile(0.001, emptyData, s) + } } // Tests for merging procedure @@ -118,9 +134,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data1, epsi, compression) val s2 = buildSummary(data2, epsi, compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) @@ -137,9 +153,9 @@ class QuantileSummariesSuite extends SparkFunSuite { val s1 = buildSummary(data11, epsi, compression) val s2 = buildSummary(data12, epsi, 
compression) val s = s1.merge(s2) - val min_approx = s.query(0.0) + val min_approx = s.query(0.0).get assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx") - val max_approx = s.query(1.0) + val max_approx = s.query(1.0).get assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx") checkQuantile(0.9999, data, s) checkQuantile(0.9, data, s) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index 2ffc18a8d14fb..78fee5135c3ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -24,9 +24,9 @@ class StringUtilsSuite extends SparkFunSuite { test("escapeLikeRegex") { assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") - assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E\\Q_\\E.\\Qb\\E") assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") - assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*\\Q%\\E\\Qb\\E") assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 61e1ec7c7ab35..c4635c8f126af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import com.fasterxml.jackson.core.JsonParseException + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -169,30 +171,72 @@ class DataTypeSuite extends SparkFunSuite { assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) } - def checkDataTypeJsonRepr(dataType: DataType): Unit = { - test(s"JSON - $dataType") { + def checkDataTypeFromJson(dataType: DataType): Unit = { + test(s"from Json - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) } } - checkDataTypeJsonRepr(NullType) - checkDataTypeJsonRepr(BooleanType) - checkDataTypeJsonRepr(ByteType) - checkDataTypeJsonRepr(ShortType) - checkDataTypeJsonRepr(IntegerType) - checkDataTypeJsonRepr(LongType) - checkDataTypeJsonRepr(FloatType) - checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) - checkDataTypeJsonRepr(DateType) - checkDataTypeJsonRepr(TimestampType) - checkDataTypeJsonRepr(StringType) - checkDataTypeJsonRepr(BinaryType) - checkDataTypeJsonRepr(ArrayType(DoubleType, true)) - checkDataTypeJsonRepr(ArrayType(StringType, false)) - checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) - checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + def checkDataTypeFromDDL(dataType: DataType): Unit = { + test(s"from DDL - $dataType") { + val parsed = StructType.fromDDL(s"a ${dataType.sql}") + val expected = new StructType().add("a", dataType) + assert(parsed.sameType(expected)) + } + } + + checkDataTypeFromJson(NullType) + + checkDataTypeFromJson(BooleanType) + 
checkDataTypeFromDDL(BooleanType) + + checkDataTypeFromJson(ByteType) + checkDataTypeFromDDL(ByteType) + + checkDataTypeFromJson(ShortType) + checkDataTypeFromDDL(ShortType) + + checkDataTypeFromJson(IntegerType) + checkDataTypeFromDDL(IntegerType) + + checkDataTypeFromJson(LongType) + checkDataTypeFromDDL(LongType) + + checkDataTypeFromJson(FloatType) + checkDataTypeFromDDL(FloatType) + + checkDataTypeFromJson(DoubleType) + checkDataTypeFromDDL(DoubleType) + + checkDataTypeFromJson(DecimalType(10, 5)) + checkDataTypeFromDDL(DecimalType(10, 5)) + + checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT) + checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT) + + checkDataTypeFromJson(DateType) + checkDataTypeFromDDL(DateType) + + checkDataTypeFromJson(TimestampType) + checkDataTypeFromDDL(TimestampType) + + checkDataTypeFromJson(StringType) + checkDataTypeFromDDL(StringType) + + checkDataTypeFromJson(BinaryType) + checkDataTypeFromDDL(BinaryType) + + checkDataTypeFromJson(ArrayType(DoubleType, true)) + checkDataTypeFromDDL(ArrayType(DoubleType, true)) + + checkDataTypeFromJson(ArrayType(StringType, false)) + checkDataTypeFromDDL(ArrayType(StringType, false)) + + checkDataTypeFromJson(MapType(IntegerType, StringType, true)) + checkDataTypeFromDDL(MapType(IntegerType, StringType, true)) + + checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false)) val metadata = new MetadataBuilder() .putString("name", "age") @@ -201,7 +245,34 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", IntegerType, nullable = true), StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false, metadata))) - checkDataTypeJsonRepr(structType) + checkDataTypeFromJson(structType) + checkDataTypeFromDDL(structType) + + test("fromJson throws an exception when given type string is invalid") { + var message = intercept[IllegalArgumentException] { + DataType.fromJson(""""abcd"""") + }.getMessage + assert(message.contains( + "Failed to convert the JSON string 'abcd' to a data type.")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"abcd":"a"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"abcd":"a"}' to a data type""")) + + message = intercept[IllegalArgumentException] { + DataType.fromJson("""{"fields": [{"a":123}], "type": "struct"}""") + }.getMessage + assert(message.contains( + """Failed to convert the JSON string '{"a":123}' to a field.""")) + + // Malformed JSON string + message = intercept[JsonParseException] { + DataType.fromJson("abcd") + }.getMessage + assert(message.contains("Unrecognized token 'abcd'")) + } def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { test(s"Check the default size of $dataType") { @@ -340,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite { checkCatalogString(ArrayType(createStruct(40))) checkCatalogString(MapType(IntegerType, StringType)) checkCatalogString(MapType(IntegerType, createStruct(40))) + + def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = { + val testName = s"equalsStructurally: (from: $from, to: $to)" + test(testName) { + assert(DataType.equalsStructurally(from, to) === expected) + } + } + + checkEqualsStructurally(BooleanType, BooleanType, true) + checkEqualsStructurally(IntegerType, IntegerType, true) + checkEqualsStructurally(IntegerType, LongType, false) + checkEqualsStructurally(ArrayType(IntegerType, true), 
ArrayType(IntegerType, true), true) + checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType), + new StructType().add("f2", IntegerType, false), + false) + + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + true) + checkEqualsStructurally( + new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)), + new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)), + false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 52d0692524d0f..93c231e30b49b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - test("changePrecision() on compact decimal should respect rounding mode") { + test("changePrecision/toPrecision on compact decimal should respect rounding mode") { Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => Seq("", "-").foreach { sign => @@ -202,8 +202,20 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val d = Decimal(unscaled, 8, 1) assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + + val copy = d.toPrecision(10, 0, mode).orNull + assert(copy !== null) + assert(d.ne(copy)) + assert(d === copy) + assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") } } } } + + test("SPARK-20341: support BigInt's value does not fit in long value range") { + val bigInt = scala.math.BigInt("9223372036854775808") + val decimal = Decimal.apply(bigInt) + assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 69d797b479159..e170133f0f0bf 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -103,6 +103,10 @@ jackson-databind ${fasterxml.jackson.version} + + org.apache.xbean + xbean-asm5-shaded + org.scalacheck scalacheck_${scala.binary.version} @@ -147,11 +151,6 @@ mockito-core test - - org.apache.xbean - xbean-asm5-shaded - test - target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java index d44af7ef48157..802949c0ddb60 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -22,17 +22,18 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import 
org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: * Base interface for a map function used in - * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState( - * FlatMapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}. + * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState( + * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, + * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} * @since 2.1.1 */ @Experimental @InterfaceStability.Evolving public interface FlatMapGroupsWithStateFunction extends Serializable { - Iterator call(K key, Iterator values, KeyedState state) throws Exception; + Iterator call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java index 75986d1706209..353e9886a8a57 100644 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.KeyedState; +import org.apache.spark.sql.streaming.GroupState; /** * ::Experimental:: @@ -34,5 +34,5 @@ @Experimental @InterfaceStability.Evolving public interface MapGroupsWithStateFunction extends Serializable { - R call(K key, Iterator values, KeyedState state) throws Exception; + R call(K key, Iterator values, GroupState state) throws Exception; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 5d7879f5f3df0..9851f8b62bd2f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -66,7 +66,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; -import org.apache.spark.util.LongAccumulator; import scala.Option; /** @@ -162,14 +161,16 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.totalRowCount = this.reader.getRecordCount(); // For test purpose. - // If the predefined accumulator exists, the row group number to read will be updated - // to the accumulator. So we can check if the row groups are filtered or not in test case. + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. 
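For context, the test-only hook above now keys off the class name of the last registered external accumulator rather than a `LongAccumulator` looked up by name. A minimal sketch of an `AccumulatorV2` that such a test could register (illustrative; not the suite's actual class) looks like this:

{{{
import org.apache.spark.util.AccumulatorV2

class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
  private var _sum = 0

  override def isZero: Boolean = _sum == 0
  override def copy(): AccumulatorV2[Integer, Integer] = {
    val acc = new NumRowGroupsAcc
    acc._sum = _sum
    acc
  }
  override def reset(): Unit = _sum = 0
  override def add(v: Integer): Unit = _sum += v
  override def merge(other: AccumulatorV2[Integer, Integer]): Unit = _sum += other.value
  override def value: Integer = _sum
}

// registered from a test, e.g.: spark.sparkContext.register(new NumRowGroupsAcc)
}}}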
TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics() - .lookForAccumulatorByName("numRowGroups"); - if (accu.isDefined()) { - ((LongAccumulator)accu.get()).add((long)blocks.size()); + Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(blocks.size()); } } } @@ -206,6 +207,7 @@ protected void initialize(String path, List columns) throws IOException config.set("spark.sql.parquet.binaryAsString", "false"); config.set("spark.sql.parquet.int96AsTimestamp", "false"); config.set("spark.sql.parquet.writeLegacyFormat", "false"); + config.set("spark.sql.parquet.int64AsTimestampMillis", "false"); this.file = new Path(path); long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index fd4095bf91bd7..170df87f3d2be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -17,25 +17,30 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.IOException; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; +import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator; +import java.io.IOException; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; import org.apache.parquet.column.Encoding; -import org.apache.parquet.column.page.*; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReader; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; - +import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; -import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; -import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator; -import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator; - /** * Decoder to return values from a single column. 
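The `spark.sql.parquet.int64AsTimestampMillis` flag hard-coded to `false` above comes from this patch; outside of the test helper it would normally be toggled through the session conf. A sketch, assuming an active `SparkSession` named `spark` and an illustrative output path:

{{{
spark.conf.set("spark.sql.parquet.int64AsTimestampMillis", "true")
spark.range(3)
  .selectExpr("current_timestamp() AS ts")
  .write.parquet("/tmp/ts_as_millis")
}}}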
*/ @@ -89,12 +94,14 @@ public class VectorizedColumnReader { private final PageReader pageReader; private final ColumnDescriptor descriptor; + private final Type fullType; - public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader, Type fullType) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + this.fullType = fullType; DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); if (dictionaryPage != null) { @@ -155,9 +162,13 @@ void readBatch(int total, ColumnVector column) throws IOException { // Read and decode dictionary ids. defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + + // Timestamp values encoded as INT64 can't be lazily decoded as we need to post process + // the values to add microseconds precision. if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || - descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 && + column.dataType() != DataTypes.TimestampType) || descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { @@ -244,14 +255,24 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, case INT64: if (column.dataType() == DataTypes.LongType || - column.dataType() == DataTypes.TimestampType || + (column.dataType() == DataTypes.TimestampType + && fullType.getOriginalType() == OriginalType.TIMESTAMP_MICROS) || DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { column.putLong(i, dictionary.decodeToLong(dictionaryIds.getDictId(i))); } } - } else { + } else if (column.dataType() == DataTypes.TimestampType && + fullType.getOriginalType() == OriginalType.TIMESTAMP_MILLIS) { + for (int i = rowId; i < rowId + num; ++i) { + if (!column.isNullAt(i)) { + column.putLong(i, + DateTimeUtils.fromMillis(dictionary.decodeToLong(dictionaryIds.getDictId(i)))); + } + } + } + else { throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -361,10 +382,20 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
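Spark's internal timestamp representation is microseconds since the epoch, which is why the `TIMESTAMP_MILLIS` branches above route values through `DateTimeUtils.fromMillis` instead of copying the raw long. A one-line sanity check (a sketch, not part of the patch):

{{{
import org.apache.spark.sql.catalyst.util.DateTimeUtils

assert(DateTimeUtils.fromMillis(1L) == 1000L)  // 1 millisecond == 1,000 microseconds
}}}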
if (column.dataType() == DataTypes.LongType || - column.dataType() == DataTypes.TimestampType || + (column.dataType() == DataTypes.TimestampType && + fullType.getOriginalType() == OriginalType.TIMESTAMP_MICROS)|| DecimalType.is64BitDecimalType(column.dataType())) { defColumn.readLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); + } else if (column.dataType() == DataTypes.TimestampType && + fullType.getOriginalType() == OriginalType.TIMESTAMP_MILLIS) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putLong(rowId + i, DateTimeUtils.fromMillis(dataColumn.readLong())); + } else { + column.putNull(rowId + i); + } + } } else { throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f2291..568b1dd21ef8e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -277,7 +277,7 @@ private void checkEndOfRowGroup() throws IOException { for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), - pages.getPageReader(columns.get(i))); + pages.getPageReader(columns.get(i)), requestedSchema.getType(i)); } totalCountLoadedSoFar += pages.getRowCount(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca000..ad267ab0c9c47 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -180,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); + return data.getBoolean(offset + ordinal); } @Override @@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new UnsupportedOperationException(); + return data.getShort(offset + ordinal); } @Override @@ -199,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); + return data.getFloat(offset + ordinal); } @Override @@ -801,6 +801,14 @@ public final int appendFloats(int count, float v) { return result; } + public final int appendFloats(int length, float[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putFloats(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e988c0722bd72..a7d3744d00e91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -436,28 +436,29 @@ public void loadBytes(ColumnVector.Array array) { // Split out the slow path. @Override protected void reserveInternal(int newCapacity) { + int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. 
} else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity); + Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5df..94ed32294cfae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] 
newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } } else if (resultStruct != null) { @@ -466,7 +466,7 @@ protected void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity); nulls = newNulls; capacity = newCapacity; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 38029552d13bd..b23ab1fa3514a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -84,8 +84,8 @@ class TypedColumn[-T, U]( } /** - * Gives the TypedColumn a name (alias). - * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * Gives the [[TypedColumn]] a name (alias). + * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated * to the new column. * * @group expr_ops @@ -99,16 +99,14 @@ class TypedColumn[-T, U]( /** * A column that will be computed based on the data in a `DataFrame`. * - * A new column is constructed based on the input columns present in a dataframe: + * A new column can be constructed based on the input columns present in a DataFrame: * * {{{ - * df("columnName") // On a specific DataFrame. + * df("columnName") // On a specific `df` DataFrame. * col("columnName") // A generic column no yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. - * expr("a + 1") // A column that is constructed from a parsed SQL Expression. - * lit("abc") // A column that produces a literal (constant) value. * }}} * * [[Column]] objects can be composed to form complex expressions: @@ -118,7 +116,7 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via "expr", but this method is for + * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for * debugging purposes only and can change in any future Spark releases. * * @groupname java_expr_ops Java-specific expression operators @@ -781,7 +779,7 @@ class Column(val expr: Expression) extends Logging { def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } /** - * SQL like expression. + * SQL like expression. Returns a boolean column based on a SQL LIKE match. * * @group expr_ops * @since 1.3.0 @@ -789,7 +787,8 @@ class Column(val expr: Expression) extends Logging { def like(literal: String): Column = withExpr { Like(expr, lit(literal).expr) } /** - * SQL RLIKE expression (LIKE with Regex). + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex + * match. * * @group expr_ops * @since 1.3.0 @@ -840,7 +839,7 @@ class Column(val expr: Expression) extends Logging { } /** - * Contains the other element. + * Contains the other element. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -848,7 +847,7 @@ class Column(val expr: Expression) extends Logging { def contains(other: Any): Column = withExpr { Contains(expr, lit(other).expr) } /** - * String starts with. 
+ * String starts with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -856,7 +855,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(other: Column): Column = withExpr { StartsWith(expr, lit(other).expr) } /** - * String starts with another string literal. + * String starts with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -864,7 +863,7 @@ class Column(val expr: Expression) extends Logging { def startsWith(literal: String): Column = this.startsWith(lit(literal)) /** - * String ends with. + * String ends with. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -872,7 +871,7 @@ class Column(val expr: Expression) extends Logging { def endsWith(other: Column): Column = withExpr { EndsWith(expr, lit(other).expr) } /** - * String ends with another string literal. + * String ends with another string literal. Returns a boolean column based on a string match. * * @group expr_ops * @since 1.3.0 @@ -1010,7 +1009,7 @@ class Column(val expr: Expression) extends Logging { def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) /** - * Returns an ordering used in sorting. + * Returns a sort expression based on the descending order of the column. * {{{ * // Scala * df.sort(df("age").desc) @@ -1025,7 +1024,8 @@ class Column(val expr: Expression) extends Logging { def desc: Column = withExpr { SortOrder(expr, Descending) } /** - * Returns a descending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear before non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing first. * df.sort(df("age").desc_nulls_first) @@ -1037,10 +1037,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** - * Returns a descending ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on the descending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in descending order and null values appearing last. * df.sort(df("age").desc_nulls_last) @@ -1052,10 +1053,10 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** - * Returns an ascending ordering used in sorting. + * Returns a sort expression based on ascending order of the column. * {{{ * // Scala: sort a DataFrame by age column in ascending order. * df.sort(df("age").asc) @@ -1070,7 +1071,8 @@ class Column(val expr: Expression) extends Logging { def asc: Column = withExpr { SortOrder(expr, Ascending) } /** - * Returns an ascending ordering used in sorting, where null values appear before non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. 
* df.sort(df("age").asc_nulls_last) @@ -1082,10 +1084,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** - * Returns an ordering used in sorting, where null values appear after non-null values. + * Returns a sort expression based on ascending order of the column, + * and null values appear after non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing last. * df.sort(df("age").asc_nulls_last) @@ -1097,10 +1100,10 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** - * Prints the expression to the console for debugging purpose. + * Prints the expression to the console for debugging purposes. * * @group df_ops * @since 1.3.0 @@ -1154,8 +1157,8 @@ class Column(val expr: Expression) extends Logging { * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.rangeBetween(Long.MinValue, 2)), - * avg("price").over(w.rowsBetween(0, 4)) + * sum("price").over(w.rangeBetween(Window.unboundedPreceding, 2)), + * avg("price").over(w.rowsBetween(Window.currentRow, 4)) * ) * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a6..052d85ad33bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.{lang => jl} +import java.util.Locale import scala.collection.JavaConverters._ @@ -89,7 +90,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def drop(how: String, cols: Seq[String]): DataFrame = { - how.toLowerCase match { + how.toLowerCase(Locale.ROOT) match { case "any" => drop(cols.size, cols) case "all" => drop(1, cols) case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") @@ -410,7 +411,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types case _ => df.col(quotedColName) } - coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) + coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 41470ae6aae19..c1b32917415ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -29,9 +29,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.DataSource +import 
org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} +import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema +import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -69,6 +70,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
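As a usage sketch (not part of the patch), the `timeZone` option documented above is passed like any other read option; the `spark` session and the path are illustrative:

{{{
val df = spark.read
  .option("timeZone", "UTC")
  .option("header", "true")
  .csv("path/to/data.csv")
}}}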
    + * * @since 1.4.0 */ def option(key: String, value: String): DataFrameReader = { @@ -100,6 +107,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameReader = { @@ -110,6 +123,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to parse timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
    + * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { @@ -145,7 +164,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ @scala.annotation.varargs def load(paths: String*): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } @@ -249,8 +268,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 @@ -261,7 +280,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file and returns the results as a `DataFrame`. + * Loads JSON files and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by * default. For JSON (one record per file), set the `wholeFile` option to true. @@ -301,11 +320,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
* <li>`dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
* Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
* date type.</li>
- * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
- * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
- * to be used to parse timestamps.</li>
* <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
* per file</li>
  • * @@ -359,27 +376,24 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val createParser = CreateJacksonParser.string _ val schema = userSpecifiedSchema.getOrElse { - JsonInferSchema.infer( - jsonDataset.rdd, - parsedOptions, - createParser) + TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions) } - // Check a field requirement for corrupt records here to throw an exception in a driver side - schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => - val f = schema(corruptFieldIndex) - if (f.dataType != StringType || !f.nullable) { - throw new AnalysisException( - "The field for corrupt records must be string type and nullable") - } - } + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val createParser = CreateJacksonParser.string _ val parsed = jsonDataset.rdd.mapPartitions { iter => - val parser = new JacksonParser(schema, parsedOptions) - iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) + val rawParser = new JacksonParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => rawParser.parse(input, createParser, UTF8String.fromString), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) } Dataset.ofRows( @@ -399,7 +413,59 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a CSV file and returns the result as a `DataFrame`. + * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. + * + * If the schema is not specified using `schema` function and `inferSchema` option is enabled, + * this function goes through the input once to determine the input schema. + * + * If the schema is not specified using `schema` function and `inferSchema` option is disabled, + * it determines the columns as string types and it reads only the first line to determine the + * names and the number of fields. 
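A usage sketch of the new `csv(Dataset[String])` overload being added here; the sample rows and option values are illustrative:

{{{
val spark: SparkSession = ...
import spark.implicits._

val csvLines = Seq("name,age", "alice,29", "bob,31").toDS()
val df = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv(csvLines)
}}}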
+ * + * @param csvDataset input Dataset with one CSV row per record + * @since 2.2.0 + */ + def csv(csvDataset: Dataset[String]): DataFrame = { + val parsedOptions: CSVOptions = new CSVOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone) + val filteredLines: Dataset[String] = + CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) + val maybeFirstLine: Option[String] = filteredLines.take(1).headOption + + val schema = userSpecifiedSchema.getOrElse { + TextInputCSVDataSource.inferFromDataset( + sparkSession, + csvDataset, + maybeFirstLine, + parsedOptions) + } + + verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord) + val actualSchema = + StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + + val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) + }.getOrElse(filteredLines.rdd) + + val parsed = linesWithoutHeader.mapPartitions { iter => + val rawParser = new UnivocityParser(actualSchema, parsedOptions) + val parser = new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + parsedOptions.parseMode, + schema, + parsedOptions.columnNameOfCorruptRecord) + iter.flatMap(parser.parse) + } + + Dataset.ofRows( + sparkSession, + LogicalRDD(schema.toAttributes, parsed)(sparkSession)) + } + + /** + * Loads CSV files and returns the result as a `DataFrame`. * * This function will go through the input once to determine the input schema if `inferSchema` * is enabled. To avoid going through the entire data once, disable `inferSchema` option or @@ -422,9 +488,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • - *
  • `ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces - * from values being read should be skipped.
  • - *
  • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing + *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading + * whitespaces from values being read should be skipped.
  • + *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing * whitespaces from values being read should be skipped.
  • *
  • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
  • @@ -436,19 +502,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed * for any given value being read. By default, it is -1 meaning unlimited length
  • - *
  • `maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows - * Spark will log for each partition. Malformed records beyond this number will be ignored.
  • *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing. + * during parsing. It supports the following case-insensitive modes. *
      *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep @@ -510,7 +572,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads an ORC file and returns the result as a `DataFrame`. + * Loads ORC files and returns the result as a `DataFrame`. * * @param paths input paths * @since 2.0.0 @@ -604,6 +666,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } } + /** + * A convenient function for schema validation in datasources supporting + * `columnNameOfCorruptRecord` as an option. + */ + private def verifyColumnNameOfCorruptRecord( + schema: StructType, + columnNameOfCorruptRecord: String): Unit = { + schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex => + val f = schema(corruptFieldIndex) + if (f.dataType != StringType || !f.nullable) { + throw new AnalysisException( + "The field for corrupt records must be string type and nullable") + } + } + } + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index bdcdf0c61ff36..c856d3099f6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -64,7 +64,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return the approximate quantiles at the given probabilities * * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * the dataframe is empty or the column only contains null or NaN, an empty array is returned. * * @since 2.0.0 */ @@ -72,8 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - val res = approxQuantile(Array(col), probabilities, relativeError) - Option(res).map(_.head).orNull + approxQuantile(Array(col), probabilities, relativeError).head } /** @@ -89,8 +88,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * Note that values greater than 1 are accepted but give the same result as 1. * @return the approximate quantiles at the given probabilities of each column * - * @note Rows containing any null or NaN values will be removed before calculation. If - * the dataframe is empty or all rows contain null or NaN, null is returned. + * @note null and NaN values will be ignored in numerical columns before calculation. For + * columns only containing null or NaN values, an empty array is returned. 
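To make the multi-column `approxQuantile` contract above concrete, a sketch (column names, probabilities and the relative error are illustrative):

{{{
val quantiles: Array[Array[Double]] = df.stat.approxQuantile(
  Array("price", "discount"),   // columns
  Array(0.25, 0.5, 0.75),       // probabilities
  0.01)                         // relative error

// quantiles(i) holds the quantiles of the i-th column; a column containing only
// null/NaN values now yields an empty inner array rather than a null result.
}}}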
* * @since 2.2.0 */ @@ -98,13 +97,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { cols: Array[String], probabilities: Array[Double], relativeError: Double): Array[Array[Double]] = { - // TODO: Update NaN/null handling to keep consistent with the single-column version - try { - StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols, - probabilities, relativeError).map(_.toArray).toArray - } catch { - case e: NoSuchElementException => null - } + StatFunctions.multipleApproxQuantiles( + df.select(cols.map(col): _*), + cols, + probabilities, + relativeError).map(_.toArray).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 49e85dc7b13f6..1732a8e08b73f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.Properties +import java.util.{Locale, Properties} import scala.collection.JavaConverters._ @@ -66,7 +66,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase match { + this.mode = saveMode.toLowerCase(Locale.ROOT) match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore @@ -90,6 +90,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
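The corresponding write-side usage (a sketch; the timezone ID and output path are illustrative):

{{{
df.write
  .option("timeZone", "America/Los_Angeles")
  .csv("path/to/output")
}}}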
      + * * @since 1.4.0 */ def option(key: String, value: String): DataFrameWriter[T] = { @@ -121,6 +127,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
      + * * @since 1.4.0 */ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = { @@ -131,6 +143,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
+ * <ul>
+ * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
+ * to be used to format timestamps in the JSON/CSV datasources or partition values.</li>
+ * </ul>
      + * * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = { @@ -205,7 +223,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def save(): Unit = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } @@ -319,6 +337,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * +---+---+ * }}} * + * In this method, save mode is used to determine the behavior if the data source table exists in + * Spark catalog. We will always overwrite the underlying data of data source (e.g. a table in + * JDBC data source) if the table doesn't exist in Spark catalog, and will always append to the + * underlying data of data source if the table already exists. + * * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC * and Parquet), the table is persisted in a Hive compatible format, which means other systems @@ -454,11 +477,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
    • - *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • - *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps.
    • *
    * * @since 1.4.0 @@ -552,7 +573,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `escapeQuotes` (default `true`): a flag indicating whether values containing * quotes should always be enclosed in quotes. Default is to escape all values containing * a quote character.
  • - *
  • `quoteAll` (default `false`): A flag indicating whether all values should always be + *
  • `quoteAll` (default `false`): a flag indicating whether all values should always be * enclosed in quotes. Default is to only escape values containing a quote character.
  • *
  • `header` (default `false`): writes the names of columns as the first line.
  • *
  • `nullValue` (default empty string): sets the string representation of a null value.
  • @@ -562,11 +583,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to format timestamps.
  • + *
  • `ignoreLeadingWhiteSpace` (default `true`): a flag indicating whether or not leading + * whitespaces from values being written should be skipped.
  • + *
• `ignoreTrailingWhiteSpace` (default `true`): a flag indicating whether or not + * trailing whitespaces from values being written should be skipped.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index db782371ee818..e908ce2cf45dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -36,6 +36,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -239,8 +240,10 @@ class Dataset[T] private[sql]( * @param _numRows Number of rows to show * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). */ - private[sql] def showString(_numRows: Int, truncate: Int = 20): String = { + private[sql] def showString( + _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows @@ -276,46 +279,80 @@ class Dataset[T] private[sql]( val sb = new StringBuilder val numCols = schema.fieldNames.length + // We set a minimum column width at '3' + val minimumColWidth = 3 - // Initialise the width of each column to a minimum value of '3' - val colWidths = Array.fill(numCols)(3) + if (!vertical) { + // Initialise the width of each column to a minimum value + val colWidths = Array.fill(numCols)(minimumColWidth) - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), cell.length) - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), cell.length) + } } - }.addString(sb, "|", "|", "|\n") - sb.append(sep) + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - // data - rows.tail.map { - _.zipWithIndex.map { case (cell, i) => + // column names + rows.head.zipWithIndex.map { case (cell, i) => if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) + StringUtils.leftPad(cell, colWidths(i)) } else { - StringUtils.rightPad(cell.toString, colWidths(i)) + StringUtils.rightPad(cell, colWidths(i)) } }.addString(sb, "|", "|", "|\n") - } - sb.append(sep) + sb.append(sep) - // For Data that has more than "numRows" records - if (hasMoreData) { + // data + rows.tail.foreach { + _.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } + }.addString(sb, "|", "|", "|\n") + } + + sb.append(sep) + } else { + // Extended display mode enabled + val fieldNames = rows.head + val dataRows = rows.tail + + // Compute the width of field name and data columns + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case 
(curMax, fieldName) => + math.max(curMax, fieldName.length) + } + val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => + math.max(curMax, row.map(_.length).reduceLeftOption[Int] { case (cellMax, cell) => + math.max(cellMax, cell) + }.getOrElse(0)) + } + + dataRows.zipWithIndex.foreach { case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad( + s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex.map { case (cell, j) => + val fieldName = StringUtils.rightPad(fieldNames(j), fieldNameColWidth) + val data = StringUtils.rightPad(cell, dataColWidth) + s" $fieldName | $data " + }.addString(sb, "", "\n", "\n") + } + } + + // Print a footer + if (vertical && data.isEmpty) { + // In a vertical mode, print an empty row set explicitly + sb.append("(0 rows)\n") + } else if (hasMoreData) { + // For Data that has more than "numRows" records val rowsString = if (numRows == 1) "row" else "rows" sb.append(s"only showing top $numRows $rowsString\n") } @@ -563,7 +600,7 @@ class Dataset[T] private[sql]( * @param eventTime the name of the column that contains the event time of the row. * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest * record that has been processed in the form of an interval - * (e.g. "1 minute" or "5 hours"). + * (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. * * @group streaming * @since 2.1.0 @@ -576,6 +613,8 @@ class Dataset[T] private[sql]( val parsedDelay = Option(CalendarInterval.fromString("interval " + delayThreshold)) .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) + require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, + s"delay threshold ($delayThreshold) should not be negative.") EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) } @@ -660,8 +699,59 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ + def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false) + + /** + * Displays the Dataset in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * + * If `vertical` enabled, this command prints output rows vertically (one line per column value)? + * + * {{{ + * -RECORD 0------------------- + * year | 1980 + * month | 12 + * AVG('Adj Close) | 0.503218 + * AVG('Adj Close) | 0.595103 + * -RECORD 1------------------- + * year | 1981 + * month | 01 + * AVG('Adj Close) | 0.523289 + * AVG('Adj Close) | 0.570307 + * -RECORD 2------------------- + * year | 1982 + * month | 02 + * AVG('Adj Close) | 0.436504 + * AVG('Adj Close) | 0.475256 + * -RECORD 3------------------- + * year | 1983 + * month | 03 + * AVG('Adj Close) | 0.410516 + * AVG('Adj Close) | 0.442194 + * -RECORD 4------------------- + * year | 1984 + * month | 04 + * AVG('Adj Close) | 0.450090 + * AVG('Adj Close) | 0.483521 + * }}} + * + * @param numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). 
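A quick usage sketch of the new vertical display mode (the DataFrame `df` is assumed):

{{{
// up to 3 rows, no truncation, one line per column value
df.show(3, 0, vertical = true)
}}}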
+ * @group action + * @since 2.3.0 + */ // scalastyle:off println - def show(numRows: Int, truncate: Int): Unit = println(showString(numRows, truncate)) + def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = + println(showString(numRows, truncate, vertical)) // scalastyle:on println /** @@ -1070,6 +1160,22 @@ class Dataset[T] private[sql]( */ def apply(colName: String): Column = col(colName) + /** + * Specifies some hint on the current Dataset. As an example, the following code specifies + * that one of the plan can be broadcasted: + * + * {{{ + * df1.join(df2.hint("broadcast")) + * }}} + * + * @group basic + * @since 2.2.0 + */ + @scala.annotation.varargs + def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { + Hint(name, parameters, logicalPlan) + } + /** * Selects column based on the column name and return it as a [[Column]]. * @@ -1093,7 +1199,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan, None) + SubqueryAlias(alias, logicalPlan) } /** @@ -1723,15 +1829,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } @@ -2441,11 +2555,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset that has exactly `numPartitions` partitions. - * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g. - * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of - * the 100 new partitions will claim 10 of the current partitions. If a larger number of - * partitions is requested, it will stay at the current number of partitions. + * Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions + * are requested. If a larger number of partitions is requested, it will stay at the current + * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in + * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not + * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. * * However, if you're doing a drastic coalesce, e.g. 
to numPartitions = 1, * this may result in your computation taking place on fewer nodes than @@ -2731,6 +2845,8 @@ class Dataset[T] private[sql]( fsBasedRelation.inputFiles case fr: FileRelation => fr.inputFiles + case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) => + r.tableMeta.storage.locationUri.map(_.toString).toArray }.flatten files.toSet.toArray } @@ -2772,7 +2888,7 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[U](body: => U): U = { + private def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 18bccee98f610..582d4a3670b8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability * * To use this, import implicit conversions in SQL: * {{{ - * import sqlContext.implicits._ + * val spark: SparkSession = ... + * import spark.implicits._ * }}} * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index 1e8ba51e59e33..bd8dd6ea3fe0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() { @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + override def clone(): ExperimentalMethods = { + val result = new ExperimentalMethods + result.extraStrategies = extraStrategies + result.extraOptimizations = extraOptimizations + result + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3a548c251f5b1..cb42e9e4560cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** * :: Experimental :: @@ -226,20 +228,66 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. 
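Returning to the `hint` method added a little earlier in this file, a minimal sketch of the broadcast use case its scaladoc mentions (the tables and the `spark` session are assumptions for illustration):

  // Suggest broadcasting the smaller side; the planner applies the hint where it can.
  val large = spark.range(1000000L).toDF("id")
  val small = spark.range(100L).toDF("id")
  val joined = large.join(small.hint("broadcast"), "id")
  joined.explain()  // expected to show a broadcast hash join when the hint is honored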
* * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving def mapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + GroupStateTimeout.NoTimeout, + child = logicalPlan)) + } + + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.2.0 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) + Dataset[U]( + sparkSession, + FlatMapGroupsWithState[K, V, S, U]( + flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + OutputMode.Update, + isMapGroupsWithState = true, + timeoutConf, + child = logicalPlan)) } /** @@ -250,7 +298,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. @@ -259,7 +307,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @param outputEncoder Encoder for the output type. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving @@ -267,8 +315,40 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) + mapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) + )(stateEncoder, outputEncoder) + } + + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See `GroupState` for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.2.0 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + mapGroupsWithState[S, U](timeoutConf)( + (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s) )(stateEncoder, outputEncoder) } @@ -280,25 +360,36 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param outputMode The output mode of the function. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
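To make the `GroupState` API that replaces `KeyedState` concrete, a sketch of a per-key running count over a batch Dataset (for a streaming Dataset the same code applies, with state carried across triggers; `spark` is an existing SparkSession):

  import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
  import spark.implicits._

  // Counts how many values each key has seen; the count is kept in per-key state.
  def countFunc(key: String, values: Iterator[String], state: GroupState[Long]): (String, Long) = {
    val newCount = state.getOption.getOrElse(0L) + values.size
    state.update(newCount)
    (key, newCount)
  }

  val counts = Seq("a", "b", "a", "a").toDS()
    .groupByKey(v => v)
    .mapGroupsWithState(GroupStateTimeout.NoTimeout())(countFunc)
  counts.show()  // key "a" -> 3, key "b" -> 1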
- * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + outputMode: OutputMode, + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } Dataset[U]( sparkSession, - MapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + FlatMapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, - logicalPlan)) + outputMode, + isMapGroupsWithState = false, + timeoutConf, + child = logicalPlan)) } /** @@ -309,26 +400,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( * For a static batch Dataset, the function will be invoked once per group. For a streaming * Dataset, the function will be invoked for each group repeatedly in every trigger, and * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. + * See `GroupState` for more details. * * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. * @tparam U The type of the output objects. Must be encodable to Spark SQL types. * @param func Function to be called on every group. + * @param outputMode The output mode of the function. * @param stateEncoder Encoder for the state type. * @param outputEncoder Encoder for the output type. + * @param timeoutConf Timeout configuration for groups that do not receive data for a while. * * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 + * @since 2.2.0 */ @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala - )(stateEncoder, outputEncoder) + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala + flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala deleted file mode 100644 index 71efa4384211f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. - * - * Detail description on `[map/flatMap]GroupsWithState` operation - * ------------------------------------------------------------ - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] - * will invoke the user-given function on each group (defined by the grouping function in - * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the `streaming.StreamingQuery`, - * the function will be invoked once for each group that has data in the batch. - * - * The function is invoked with following parameters. - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * In case of a batch Dataset, there is only one invocation and state object will be empty as - * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` - * is equivalent to `[map/flatMap]Groups`. - * - * Important points to note about the function. - * - In a trigger, the function will be called only the groups present in the batch. So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - * Important points to note about using KeyedState. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, - * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. - * - * Scala example of using KeyedState in `mapGroupsWithState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { - * // Check if state exists - * if (state.exists) { - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * } - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * } - * ... 
// return something - * } - * - * }}} - * - * Java example of using `KeyedState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * MapGroupsWithStateFunction mappingFunction = - * new MapGroupsWithStateFunction() { - * - * @Override - * public String call(String key, Iterator value, KeyedState state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * ... // return something - * } - * }; - * }}} - * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { - - /** Whether state exists or not. */ - def exists: Boolean - - /** Get the state value if it exists, or throw NoSuchElementException. */ - @throws[NoSuchElementException]("when state does not exist") - def get: S - - /** Get the state value as a scala Option. */ - def getOption: Option[S] - - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") - def update(newState: S): Unit - - /** Remove this keyed state. */ - def remove(): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0fe8d87ebd6ba..64755434784a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Locale + import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -108,7 +110,7 @@ class RelationalGroupedDataset protected[sql]( private[this] def strToExpr(expr: String): (Expression => Expression) = { val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase match { + (inputExpr: Expression) => expr.toLowerCase(Locale.ROOT) match { // We special handle a few cases that have alias that are not in function registry. 
case "avg" | "average" | "mean" => UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 234ef2dffc6bc..cc2983987eb90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.beans.BeanInfo import java.util.Properties import scala.collection.immutable @@ -527,8 +526,9 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable(tableName: String, path: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path) + sparkSession.catalog.createTable(tableName, path) } /** @@ -538,11 +538,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, path: String, source: String): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, path, source) + sparkSession.catalog.createTable(tableName, path, source) } /** @@ -552,11 +553,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -567,11 +569,12 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, options) + sparkSession.catalog.createTable(tableName, source, options) } /** @@ -581,12 +584,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -597,12 +601,13 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group ddl_ops * @since 1.3.0 */ + @deprecated("use sparkSession.catalog.createTable instead.", "2.2.0") def createExternalTable( tableName: String, source: String, schema: StructType, options: Map[String, String]): DataFrame = { - sparkSession.catalog.createExternalTable(tableName, source, schema, options) + sparkSession.catalog.createTable(tableName, source, schema, options) } /** @@ -1089,9 +1094,9 @@ object SQLContext { * method for internal use. 
*/ private[sql] def beansToRows( - data: Iterator[_], - beanClass: Class[_], - attrs: Seq[AttributeReference]): Iterator[InternalRow] = { + data: Iterator[_], + beanClass: Class[_], + attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index bc4cf92584b94..5e2692980195b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ -import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -39,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, SESSION_STATE_IMPLEMENTATION} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -61,21 +60,29 @@ import org.apache.spark.util.Utils * The builder can also be used to create a new session: * * {{{ - * SparkSession.builder() + * SparkSession.builder * .master("local") * .appName("Word Count") * .config("spark.some.config.option", "some-value") * .getOrCreate() * }}} + * + * @param sparkContext The Spark context associated with this Spark session. + * @param existingSharedState If supplied, use the existing shared state + * instead of creating a new one. + * @param parentSessionState If supplied, inherit all session state (i.e. temporary + * views, SQL config, UDFs etc) from parent. */ @InterfaceStability.Stable class SparkSession private( @transient val sparkContext: SparkContext, - @transient private val existingSharedState: Option[SharedState]) + @transient private val existingSharedState: Option[SharedState], + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -108,6 +115,7 @@ class SparkSession private( /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. + * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent. * * This is internal to Spark and there is no guarantee on interface stability. 
* @@ -116,9 +124,13 @@ class SparkSession private( @InterfaceStability.Unstable @transient lazy val sessionState: SessionState = { - SparkSession.reflect[SessionState, SparkSession]( - SparkSession.sessionStateClassName(sparkContext.conf), - self) + parentSessionState + .map(_.clone(this)) + .getOrElse { + SparkSession.instantiateSessionState( + SparkSession.sessionStateClassName(sparkContext.conf), + self) + } } /** @@ -183,7 +195,7 @@ class SparkSession private( * * @since 2.0.0 */ - def udf: UDFRegistration = sessionState.udf + def udf: UDFRegistration = sessionState.udfRegistration /** * :: Experimental :: @@ -208,7 +220,25 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState)) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions) + } + + /** + * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext` + * and shared state. All the state of this session (i.e. SQL configurations, temporary tables, + * registered functions) is copied over, and the cloned session is set up with the same shared + * state as this session. The cloned session is independent of this session, that is, any + * non-global change in either session is not reflected in the other. + * + * @note Other than the `SparkContext`, all shared state is initialized lazily. + * This method will force the initialization of the shared state to ensure that parent + * and child sessions are set up with the same shared state. If the underlying catalog + * implementation is Hive, this will initialize the metastore, which may take some time. + */ + private[sql] def cloneSession(): SparkSession = { + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) + result.sessionState // force copy of SessionState + result } @@ -562,8 +592,13 @@ class SparkSession private( @transient lazy val catalog: Catalog = new CatalogImpl(self) /** - * Returns the specified table as a `DataFrame`. + * Returns the specified table/view as a `DataFrame`. * + * @param tableName is either a qualified or unqualified name that designates a table or view. + * If a database is specified, it identifies the table/view from the database. + * Otherwise, it first attempts to find a temporary view with the given name + * and then match the table/view from the current database. + * Note that, the global temporary view database is also valid here. * @since 2.0.0 */ def table(tableName: String): DataFrame = { @@ -720,6 +755,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -814,6 +851,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -870,7 +918,26 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) + + // Initialize extensions if the user has defined a configurator class. 
+ val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) @@ -962,27 +1029,29 @@ object SparkSession { /** Reference to the root SparkSession. */ private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" + private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = + "org.apache.spark.sql.hive.HiveSessionStateBuilder" private def sessionStateClassName(conf: SparkConf): String = { conf.get(SESSION_STATE_IMPLEMENTATION) match { - case "hive" => HIVE_SESSION_STATE_CLASS_NAME - case "in-memory" => classOf[SessionState].getCanonicalName - case name => name + case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME + case "in-memory" => classOf[SessionStateBuilder].getCanonicalName + case builder => builder } } /** - * Helper method to create an instance of [[T]] using a single-arg constructor that - * accepts an [[Arg]]. + * Helper method to create an instance of `SessionState` based on `className` from conf. + * The result is either `SessionState` or a Hive based `SessionState`. */ - private def reflect[T, Arg <: AnyRef]( + private def instantiateSessionState( className: String, - ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = { + sparkSession: SparkSession): SessionState = { try { + // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])` val clazz = Utils.classForName(className) - val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass) - ctor.newInstance(ctorArg).asInstanceOf[T] + val ctor = clazz.getConstructors.head + ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build() } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Error while instantiating '$className':", e) @@ -994,7 +1063,7 @@ object SparkSession { */ private[spark] def hiveClassesArePresent: Boolean = { try { - Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME) + Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME) Utils.classForName("org.apache.hadoop.hive.conf.HiveConf") true } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 0000000000000..f99c108161f94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. 
+ */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. + */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 7abfa4ea37a74..a57673334c10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -36,7 +36,11 @@ import org.apache.spark.sql.types.{DataType, DataTypes} import org.apache.spark.util.Utils /** - * Functions for registering user-defined functions. Use `SQLContext.udf` to access this. + * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: + * + * {{{ + * spark.udf + * }}} * * @note The user-defined functions must be deterministic. 
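A sketch of how these injection points are meant to be wired up from the builder; the rule below is a deliberate no-op and all names are illustrative:

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.catalyst.rules.Rule

  // A resolution rule that returns the plan unchanged, just to show the wiring.
  case class NoopRule(session: SparkSession) extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan
  }

  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("extensions-demo")
    .withExtensions { extensions =>
      extensions.injectResolutionRule(session => NoopRule(session))
    }
    .getOrCreate()

The same configurator logic could instead be packaged as a `SparkSessionExtensions => Unit` class and picked up through the static configuration that `getOrCreate` consults above.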
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a4c5bf756cd5a..d94e528a3ad47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import java.util.{Map => JMap} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.util.matching.Regex @@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ @@ -47,17 +48,19 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport - && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { - SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() - } else { - if (enableHiveSupport) { - logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + - "falling back to without Hive support.") + val spark = + if (SparkSession.hiveClassesArePresent && enableHiveSupport && + jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == + "hive") { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to " + + "'hive', falling back to without Hive support.") + } + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } - SparkSession.builder().sparkContext(jsc.sc).getOrCreate() - } setSparkContextSessionConf(spark, sparkConfigMap) spark } @@ -81,7 +84,7 @@ private[sql] object SQLUtils extends Logging { new JavaSparkContext(spark.sparkContext) } - def createStructType(fields : Seq[StructField]): StructType = { + def createStructType(fields: Seq[StructField]): StructType = { StructType(fields) } @@ -90,48 +93,8 @@ private[sql] object SQLUtils extends Logging { def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) } - def getSQLDataType(dataType: String): DataType = { - dataType match { - case "byte" => org.apache.spark.sql.types.ByteType - case "integer" => org.apache.spark.sql.types.IntegerType - case "float" => org.apache.spark.sql.types.FloatType - case "double" => org.apache.spark.sql.types.DoubleType - case "numeric" => org.apache.spark.sql.types.DoubleType - case "character" => org.apache.spark.sql.types.StringType - case "string" => org.apache.spark.sql.types.StringType - case "binary" => org.apache.spark.sql.types.BinaryType - case "raw" => org.apache.spark.sql.types.BinaryType - case "logical" => org.apache.spark.sql.types.BooleanType - case "boolean" => org.apache.spark.sql.types.BooleanType - case "timestamp" => 
org.apache.spark.sql.types.TimestampType - case "date" => org.apache.spark.sql.types.DateType - case r"\Aarray<(.+)${elemType}>\Z" => - org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) - case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => - if (keyType != "string" && keyType != "character") { - throw new IllegalArgumentException("Key type of a map must be string or character") - } - org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - case r"\Astruct<(.+)${fieldsStr}>\Z" => - if (fieldsStr(fieldsStr.length - 1) == ',') { - throw new IllegalArgumentException(s"Invalid type $dataType") - } - val fields = fieldsStr.split(",") - val structFields = fields.map { field => - field match { - case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => - createStructField(fieldName, fieldType, true) - - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - createStructType(structFields) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { - val dtObj = getSQLDataType(dataType) + val dtObj = CatalystSqlParser.parseDataType(dataType) StructField(name, dtObj, nullable) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 50252db789d46..7e5da012f84ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -54,16 +54,16 @@ abstract class Catalog { def listDatabases(): Dataset[Database] /** - * Returns a list of tables in the current database. - * This includes all temporary tables. + * Returns a list of tables/views in the current database. + * This includes all temporary views. * * @since 2.0.0 */ def listTables(): Dataset[Table] /** - * Returns a list of tables in the specified database. - * This includes all temporary tables. + * Returns a list of tables/views in the specified database. + * This includes all temporary views. * * @since 2.0.0 */ @@ -88,17 +88,21 @@ abstract class Catalog { def listFunctions(dbName: String): Dataset[Function] /** - * Returns a list of columns for the given table in the current database or - * the given temporary table. + * Returns a list of columns for the given table/view or temporary view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ @throws[AnalysisException]("table does not exist") def listColumns(tableName: String): Dataset[Column] /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table/view. * @since 2.0.0 */ @throws[AnalysisException]("database or table does not exist") @@ -115,9 +119,11 @@ abstract class Catalog { /** * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an AnalysisException when no Table - * can be found. + * table/view. This throws an AnalysisException when no Table can be found. 
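The hand-rolled type-string parser deleted above is replaced by the Catalyst parser; a short sketch of the equivalent calls (results shown as comments are indicative):

  import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

  CatalystSqlParser.parseDataType("array<int>")            // ArrayType(IntegerType, containsNull = true)
  CatalystSqlParser.parseDataType("map<string,double>")    // MapType(StringType, DoubleType)
  CatalystSqlParser.parseDataType("struct<name:string,age:int>")
  // StructType with fields name: StringType and age: IntegerType

Note that the R-specific aliases the deleted code special-cased (e.g. "logical", "raw") are not obviously covered by the SQL grammar, so behavior for those inputs may differ.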
* + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ @throws[AnalysisException]("table does not exist") @@ -134,9 +140,11 @@ abstract class Catalog { /** * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an AnalysisException when the function cannot - * be found. + * function. This throws an AnalysisException when the function cannot be found. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a temporary function + * or a function in the current database. * @since 2.1.0 */ @throws[AnalysisException]("function does not exist") @@ -146,6 +154,8 @@ abstract class Catalog { * Get the function with the specified name. This throws an AnalysisException when the function * cannot be found. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function in the specified database * @since 2.1.0 */ @throws[AnalysisException]("database or function does not exist") @@ -160,8 +170,11 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * view or a table/view. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a table/view in + * the current database. * @since 2.1.0 */ def tableExists(tableName: String): Boolean @@ -169,14 +182,19 @@ abstract class Catalog { /** * Check if the table or view with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param tableName is an unqualified name that designates a table. * @since 2.1.0 */ def tableExists(dbName: String, tableName: String): Boolean /** * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * or a function. * + * @param functionName is either a qualified or unqualified name that designates a function. + * If no database identifier is provided, it refers to a function in + * the current database. * @since 2.1.0 */ def functionExists(functionName: String): Boolean @@ -184,6 +202,8 @@ abstract class Catalog { /** * Check if the function with the specified name exists in the specified database. * + * @param dbName is a name that designates a database. + * @param functionName is an unqualified name that designates a function. * @since 2.1.0 */ def functionExists(dbName: String, functionName: String): Boolean @@ -192,6 +212,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -204,6 +227,9 @@ abstract class Catalog { * Creates a table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. 
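As a usage sketch for the renamed entry point documented in this block (the path and table name are placeholders; `spark` is an existing SparkSession):

  // Deprecated since 2.2.0:
  //   spark.catalog.createExternalTable("events", "/data/events")
  // Preferred replacement (the deprecated overloads now delegate to it), using the default data source:
  val events = spark.catalog.createTable("events", "/data/events")
  events.printSchema()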
* + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -214,6 +240,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -226,6 +255,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and returns the corresponding * DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -236,6 +268,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -248,9 +283,12 @@ abstract class Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -267,6 +305,9 @@ abstract class Catalog { * Creates a table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -280,9 +321,12 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -297,6 +341,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -310,9 +357,12 @@ abstract class Catalog { /** * :: Experimental :: - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. 
* Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -330,6 +380,9 @@ abstract class Catalog { * Create a table from the given path based on a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") @@ -344,9 +397,12 @@ abstract class Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Create a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in + * the current database. * @since 2.2.0 */ @Experimental @@ -368,7 +424,7 @@ abstract class Catalog { * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean * in Spark 2.1. * - * @param viewName the name of the view to be dropped. + * @param viewName the name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.0.0 */ @@ -383,15 +439,19 @@ abstract class Catalog { * preserved database `global_temp`, and we must use the qualified name to refer a global temp * view, e.g. `SELECT * FROM global_temp.view1`. * - * @param viewName the name of the view to be dropped. + * @param viewName the unqualified name of the temporary view to be dropped. * @return true if the view is dropped successfully, false otherwise. * @since 2.1.0 */ def dropGlobalTempView(viewName: String): Boolean /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and update the catalog. + * Only works with a partitioned table, and not a view. * + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @since 2.1.1 */ def recoverPartitions(tableName: String): Unit @@ -399,6 +459,9 @@ abstract class Catalog { /** * Returns true if the table is currently cached in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def isCached(tableName: String): Boolean @@ -406,6 +469,9 @@ abstract class Catalog { /** * Caches the specified table in-memory. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def cacheTable(tableName: String): Unit @@ -413,6 +479,9 @@ abstract class Catalog { /** * Removes the specified table from the in-memory cache. * + * @param tableName is either a qualified or unqualified name that designates a table/view. 
+ * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def uncacheTable(tableName: String): Unit @@ -425,21 +494,24 @@ abstract class Catalog { def clearCache(): Unit /** - * Invalidate and refresh all the cached metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. + * Invalidates and refreshes all the cached data and metadata of the given table. For performance + * reasons, Spark SQL or the external data source library it uses might cache certain metadata + * about a table, such as the location of blocks. When those change outside of Spark SQL, users + * should call this function to invalidate the cache. * * If this table is cached as an InMemoryRelation, drop the original cached version and make the * new version cached lazily. * + * @param tableName is either a qualified or unqualified name that designates a table/view. + * If no database identifier is provided, it refers to a temporary view or + * a table/view in the current database. * @since 2.0.0 */ def refreshTable(tableName: String): Unit /** - * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that - * contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` + * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate * everything that is cached. * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 80138510dc9ee..0ea806d6cb50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging @@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) class CacheManager extends Logging { @transient - private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + private val cachedData = new java.util.LinkedList[CachedData] @transient private val cacheLock = new ReentrantReadWriteLock @@ -70,7 +72,7 @@ class CacheManager extends Logging { /** Clears all cached tables. 
*/ def clearCache(): Unit = writeLock { - cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } @@ -88,46 +90,81 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.queryExecution.analyzed + val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sparkSession = query.sparkSession - cachedData += - CachedData( - planToCache, - InMemoryRelation( - sparkSession.sessionState.conf.useCompression, - sparkSession.sessionState.conf.columnBatchSize, - storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, - tableName)) + cachedData.add(CachedData( + planToCache, + InMemoryRelation( + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, + storageLevel, + sparkSession.sessionState.executePlan(planToCache).executedPlan, + tableName))) } } /** - * Tries to remove the data for the given [[Dataset]] from the cache. - * No operation, if it's already uncached. + * Un-cache all the cache entries that refer to the given plan. + */ + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + } + + /** + * Un-cache all the cache entries that refer to the given plan. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { - val planToCache = query.queryExecution.analyzed - val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) - val found = dataIndex >= 0 - if (found) { - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) - cachedData.remove(dataIndex) + def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + val it = cachedData.iterator() + while (it.hasNext) { + val cd = it.next() + if (cd.plan.find(_.sameResult(plan)).isDefined) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + it.remove() + } } - found + } + + /** + * Tries to re-cache all the cache entries that refer to the given plan. + */ + def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) + } + + private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + val it = cachedData.iterator() + val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] + while (it.hasNext) { + val cd = it.next() + if (condition(cd.plan)) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist() + // Remove the cache entry before we create a new one, so that we can have a different + // physical plan. 
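Because `uncacheQuery` now drops every entry whose plan refers to the given plan (via `find(_.sameResult(plan))`) rather than only an exact match, un-caching a base Dataset also clears caches built on top of it. An illustrative sketch, assuming the cascade is reached through `Dataset.unpersist` and an active `spark` session:

val base = spark.range(0, 1000).toDF("id")
base.cache()
base.count()

val derived = base.filter("id % 2 = 0")
derived.cache()
derived.count()

// With the containment check above, this also removes `derived`'s cache entry,
// since its plan contains `base`'s plan.
base.unpersist()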
+ it.remove() + val newCache = InMemoryRelation( + useCompression = cd.cachedRepresentation.useCompression, + batchSize = cd.cachedRepresentation.batchSize, + storageLevel = cd.cachedRepresentation.storageLevel, + child = spark.sessionState.executePlan(cd.plan).executedPlan, + tableName = cd.cachedRepresentation.tableName) + needToRecache += cd.copy(cachedRepresentation = newCache) + } + } + + needToRecache.foreach(cachedData.add) } /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.queryExecution.analyzed) + lookupCachedData(query.logicalPlan) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { - cachedData.find(cd => plan.sameResult(cd.plan)) + cachedData.asScala.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ @@ -145,40 +182,17 @@ class CacheManager extends Logging { } /** - * Invalidates the cache of any data that contains `plan`. Note that it is possible that this - * function will over invalidate. - */ - def invalidateCache(plan: LogicalPlan): Unit = writeLock { - cachedData.foreach { - case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => - data.cachedRepresentation.recache() - case _ => - } - } - - /** - * Invalidates the cache of any data that contains `resourcePath` in one or more + * Tries to re-cache all the cache entries that contain `resourcePath` in one or more * `HadoopFsRelation` node(s) as part of its logical plan. */ - def invalidateCachedPath( - sparkSession: SparkSession, resourcePath: String): Unit = writeLock { + def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock { val (fs, qualifiedPath) = { val path = new Path(resourcePath) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + (fs, fs.makeQualified(path)) } - cachedData.filter { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true - case _ => false - }.foreach { data => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - } + recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04fba17be4bfa..e86116680a57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -111,17 +111,27 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val numRows = ctx.freshName("numRows") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $idx = $rowidx + 1; return; }" + } else { + 
"// shouldStop check is eliminated" + } s""" |if ($batch == null) { | $nextBatch(); |} |while ($batch != null) { - | int numRows = $batch.numRows(); - | while ($idx < numRows) { - | int $rowidx = $idx++; + | int $numRows = $batch.numRows(); + | int $localEnd = $numRows - $idx; + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | int $rowidx = $idx + $localIdx; | ${consume(ctx, columnsBatchInput).trim} - | if (shouldStop()) return; + | $shouldStop | } + | $idx = $numRows; | $batch = null; | $nextBatch(); |} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2de24f69d52c0..eb8994ae0d5ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -27,23 +27,48 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation val metastoreTableIdentifier: Option[TableIdentifier] + protected val nodeNamePrefix: String = "" + override val nodeName: String = { s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}" } + + override def simpleString: String = { + val metadataEntries = metadata.toSeq.sorted.map { + case (key, value) => + key + ": " + StringUtils.abbreviate(redact(value), 100) + } + val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") + s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" + } + + override def verboseString: String = redact(super.verboseString) + + override def treeString(verbose: Boolean, addSuffix: Boolean): String = { + redact(super.treeString(verbose, addSuffix)) + } + + /** + * Shorthand for calling redactString() without specifying redacting rules + */ + private def redact(text: String): String = { + Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text) + } } /** Physical plan node for scanning data from a relation. 
*/ @@ -85,15 +110,6 @@ case class RowDataSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - - s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + - s"${Utils.truncatedString(metadataEntries, " ", ", ", "")}" - } - override def inputRDDs(): Seq[RDD[InternalRow]] = { rdd :: Nil } @@ -104,7 +120,7 @@ case class RowDataSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row @@ -121,29 +137,27 @@ case class RowDataSourceScanExec( """.stripMargin } - // Ignore rdd when checking results - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata - case _ => false - } + // Only care about `relation` and `metadata` when canonicalizing. + override def preCanonicalized: SparkPlan = + copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None) } /** * Physical plan node for scanning data from HadoopFsRelations. * * @param relation The file-based relation to scan. - * @param output Output attributes of the scan. - * @param outputSchema Output schema of the scan. + * @param output Output attributes of the scan, including data attributes and partition attributes. + * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. - * @param dataFilters Data source filters to use for filtering data within partitions. + * @param dataFilters Filters on non-partition columns. * @param metastoreTableIdentifier identifier for the table in the metastore. 
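The split between `partitionFilters` and `dataFilters` documented above is easy to observe from a query over a partitioned dataset; the path and column names below are made up and `spark` is an active session:

// `dt` becomes a partition column, `status` stays a data column.
spark.range(0, 100)
  .selectExpr("id", "cast(id % 3 as string) as status", "cast(id % 5 as string) as dt")
  .write.mode("overwrite").partitionBy("dt").parquet("/tmp/events_by_dt")

val q = spark.read.parquet("/tmp/events_by_dt")
  .where("dt = '1' AND status = '2'")

// The FileSourceScanExec metadata should list the dt predicate under PartitionFilters
// and the translated status predicate under PushedFilters.
q.explain()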
*/ case class FileSourceScanExec( @transient relation: HadoopFsRelation, output: Seq[Attribute], - outputSchema: StructType, + requiredSchema: StructType, partitionFilters: Seq[Expression], - dataFilters: Seq[Filter], + dataFilters: Seq[Expression], override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -156,7 +170,21 @@ case class FileSourceScanExec( false } - @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters) + @transient private lazy val selectedPartitions: Seq[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = relation.location.listFiles(partitionFilters, dataFilters) + val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000 + + metrics("numFiles").add(ret.map(_.files.size.toLong).sum) + metrics("metadataTime").add(timeTakenMs) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + metrics("numFiles") :: metrics("metadataTime") :: Nil) + + ret + } override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = { val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) { @@ -225,6 +253,10 @@ case class FileSourceScanExec( } } + @transient + private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + // These metadata values make scan plans uniquely identifiable for equality checking. override val metadata: Map[String, String] = { def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]") @@ -234,10 +266,10 @@ case class FileSourceScanExec( val metadata = Map( "Format" -> relation.fileFormat.toString, - "ReadSchema" -> outputSchema.catalogString, + "ReadSchema" -> requiredSchema.catalogString, "Batched" -> supportsBatch.toString, "PartitionFilters" -> seqToString(partitionFilters), - "PushedFilters" -> seqToString(dataFilters), + "PushedFilters" -> seqToString(pushedDownFilters), "Location" -> locationDesc) val withOptPartitionCount = relation.partitionSchemaOption.map { _ => @@ -254,8 +286,8 @@ case class FileSourceScanExec( sparkSession = relation.sparkSession, dataSchema = relation.dataSchema, partitionSchema = relation.partitionSchema, - requiredSchema = outputSchema, - filters = dataFilters, + requiredSchema = requiredSchema, + filters = pushedDownFilters, options = relation.options, hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) @@ -273,6 +305,8 @@ case class FileSourceScanExec( override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"), + "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { @@ -302,13 +336,7 @@ case class FileSourceScanExec( } } - override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { - key + ": " + StringUtils.abbreviate(value, 100) - } - val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") - s"File$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" - } + override val 
nodeNamePrefix: String = "File" override protected def doProduce(ctx: CodegenContext): String = { if (supportsBatch) { @@ -319,7 +347,7 @@ case class FileSourceScanExec( val input = ctx.freshName("input") ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => - new BoundReference(i, a.dataType, a.nullable) + BoundReference(i, a.dataType, a.nullable) } val row = ctx.freshName("row") ctx.INPUT_ROW = row @@ -361,7 +389,8 @@ case class FileSourceScanExec( val format = fsRelation.fileFormat val splitter = format.buildSplitter(session, fsRelation.location, - dataFilters, schema, session.sessionState.newHadoopConf()) + dataFilters.flatMap(DataSourceStrategy.translateFilter), schema, + session.sessionState.newHadoopConf()) val bucketed = partitionFiles.flatMap { case (file, values) => val blockLocations = getBlockLocations(file) @@ -417,7 +446,8 @@ case class FileSourceScanExec( val format = fsRelation.fileFormat val splitter = format.buildSplitter(session, fsRelation.location, - dataFilters, schema, session.sessionState.newHadoopConf()) + dataFilters.flatMap(DataSourceStrategy.translateFilter), schema, + session.sessionState.newHadoopConf()) val splitFiles = partitionFiles.flatMap { case (file, values) => val blockLocations = getBlockLocations(file) @@ -513,14 +543,13 @@ case class FileSourceScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: FileSourceScanExec => - val thisPredicates = partitionFilters.map(cleanExpression) - val otherPredicates = other.partitionFilters.map(cleanExpression) - val result = relation == other.relation && metadata == other.metadata && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: FileSourceScanExec = { + FileSourceScanExec( + relation, + output.map(QueryPlan.normalizeExprId(_, output)), + requiredSchema, + partitionFilters.map(QueryPlan.normalizeExprId(_, output)), + dataFilters.map(QueryPlan.normalizeExprId(_, output)), + None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 49336f424822f..3d1b481a53e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Encoder, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -86,16 +87,9 @@ case class ExternalRDD[T]( override def newInstance(): ExternalRDD.this.type = ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case 
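With `sameResult` now derived from `canonicalized` (expression ids normalized, the metastore identifier dropped), two independently built but logically identical scans should compare equal. A rough check, reusing the hypothetical dataset from the earlier sketch:

val a = spark.read.parquet("/tmp/events_by_dt").where("dt = '1'")
val b = spark.read.parquet("/tmp/events_by_dt").where("dt = '1'")

// Canonicalization strips expression ids, so the two physical plans are expected
// to be treated as producing the same result (the basis for reuse and cache lookups).
assert(a.queryExecution.executedPlan.sameResult(b.queryExecution.executedPlan))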
ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -161,16 +155,9 @@ case class LogicalRDD( )(session).asInstanceOf[this.type] } - override def sameResult(plan: LogicalPlan): Boolean = { - plan.canonicalized match { - case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id - case _ => false - } - } - override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( + @transient override def computeStats(conf: SQLConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala new file mode 100644 index 0000000000000..458ac4ba3637c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + +/** + * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined + * threshold of rows is reached. + * + * Setting spill threshold faces following trade-off: + * + * - If the spill threshold is too high, the in-memory array may occupy more memory than is + * available, resulting in OOM. + * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. 
+ * This may lead to a performance regression compared to the normal case of using an + * [[ArrayBuffer]] or [[Array]]. + */ +private[sql] class ExternalAppendOnlyUnsafeRowArray( + taskMemoryManager: TaskMemoryManager, + blockManager: BlockManager, + serializerManager: SerializerManager, + taskContext: TaskContext, + initialSize: Int, + pageSizeBytes: Long, + numRowsSpillThreshold: Int) extends Logging { + + def this(numRowsSpillThreshold: Int) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsSpillThreshold) + } + + private val initialSizeOfInMemoryBuffer = + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + + private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) + } else { + null + } + + private var spillableArray: UnsafeExternalSorter = _ + private var numRows = 0 + + // A counter to keep track of total modifications done to this array since its creation. + // This helps to invalidate iterators when there are changes done to the backing array. + private var modificationsCount: Long = 0 + + private var numFieldsPerRow = 0 + + def length: Int = numRows + + def isEmpty: Boolean = numRows == 0 + + /** + * Clears up resources (eg. memory) held by the backing storage + */ + def clear(): Unit = { + if (spillableArray != null) { + // The last `spillableArray` of this task will be cleaned up via task completion listener + // inside `UnsafeExternalSorter` + spillableArray.cleanupResources() + spillableArray = null + } else if (inMemoryBuffer != null) { + inMemoryBuffer.clear() + } + numFieldsPerRow = 0 + numRows = 0 + modificationsCount += 1 + } + + def add(unsafeRow: UnsafeRow): Unit = { + if (numRows < numRowsSpillThreshold) { + inMemoryBuffer += unsafeRow.copy() + } else { + if (spillableArray == null) { + logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + s"${classOf[UnsafeExternalSorter].getName}") + + // We will not sort the rows, so prefixComparator and recordComparator are null + spillableArray = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + initialSize, + pageSizeBytes, + numRowsSpillThreshold, + false) + + // populate with existing in-memory buffered rows + if (inMemoryBuffer != null) { + inMemoryBuffer.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryBuffer.clear() + } + numFieldsPerRow = unsafeRow.numFields() + } + + spillableArray.insertRecord( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0, + false) + } + + numRows += 1 + modificationsCount += 1 + } + + /** + * Creates an [[Iterator]] for the current rows in the array starting from a user provided index + * + * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of + * the iterator, then the iterator is invalidated thus saving clients from thinking that they + * have read all the data while there were new rows added to this array. 
+ */ + def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + + if (spillableArray == null) { + new InMemoryBufferIterator(startIndex) + } else { + new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex) + } + } + + def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + + private[this] + abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { + private val expectedModificationsCount = modificationsCount + + protected def isModified(): Boolean = expectedModificationsCount != modificationsCount + + protected def throwExceptionIfModified(): Unit = { + if (expectedModificationsCount != modificationsCount) { + throw new ConcurrentModificationException( + s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + + s"since the creation of this Iterator") + } + } + } + + private[this] class InMemoryBufferIterator(startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private var currentIndex = startIndex + + override def hasNext(): Boolean = !isModified() && currentIndex < numRows + + override def next(): UnsafeRow = { + throwExceptionIfModified() + val result = inMemoryBuffer(currentIndex) + currentIndex += 1 + result + } + } + + private[this] class SpillableArrayIterator( + iterator: UnsafeSorterIterator, + numFieldPerRow: Int, + startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private val currentRow = new UnsafeRow(numFieldPerRow) + + def init(): Unit = { + var i = 0 + while (i < startIndex) { + if (iterator.hasNext) { + iterator.loadNext() + } else { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + i += 1 + } + } + + // Traverse upto the given [[startIndex]] + init() + + override def hasNext(): Boolean = !isModified() && iterator.hasNext + + override def next(): UnsafeRow = { + throwExceptionIfModified() + iterator.loadNext() + currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) + currentRow + } + } +} + +private[sql] object ExternalAppendOnlyUnsafeRowArray { + val DefaultInitialSizeOfInMemoryBuffer = 128 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 69be7094d2c39..c35e5638e9273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) extends Iterator[InternalRow] { - lazy val results = func().toIterator + lazy val results: Iterator[InternalRow] = func().toIterator override def hasNext: Boolean = results.hasNext override def next(): InternalRow = results.next() } @@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param join when true, each output row is implicitly joined with the input tuple that produced * it. 
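A hedged sketch of the intended calling pattern for the new row array. It is `private[sql]` and expects an active `TaskContext`, so this only makes sense from operator code running inside a task; the 4096-row threshold is illustrative:

// Sketch only: assumes it lives in org.apache.spark.sql.execution and runs inside a task.
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

def bufferRows(rows: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
  // Falls back to UnsafeExternalSorter (spilling to disk) beyond 4096 buffered rows.
  val buffer = new ExternalAppendOnlyUnsafeRowArray(4096)
  rows.foreach(buffer.add)      // each row is copied on insert
  buffer.generateIterator()     // invalidated by any later add() or clear()
}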
* @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param generatorOutput the qualified output attributes of the generator of this node, which * constructed in analysis phase, and we can not change it, as the * parent node bound with it already. @@ -78,15 +78,15 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - val rows = if (join) { - child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val rows = if (join) { val joinedRow = new JoinedRow - iter.flatMap { row => // we should always set the left (child output) joinedRow.withLeft(row) @@ -101,25 +101,28 @@ case class GenerateExec( // keep it the same as Hive does joinedRow.withRight(row) } + } else { + iter.flatMap { row => + val outputRows = boundGenerator.eval(row) + if (outer && outputRows.isEmpty) { + Seq(generatorNullRow) + } else { + outputRows + } + } ++ LazyIterator(boundGenerator.terminate) } - } else { - child.execute().mapPartitionsInternal { iter => - iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) - } - } - val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsWithIndexInternal { (index, iter) => + // Convert the rows to unsafe rows. 
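The clarified `outer` semantics above are what `LATERAL VIEW OUTER` exposes in SQL: input rows whose generator output is empty are still emitted once, with nulls for the generated columns. A small sketch with made-up data, assuming an active `spark` session:

import spark.implicits._

Seq((1, Seq("a", "b")), (2, Seq.empty[String])).toDF("id", "xs")
  .createOrReplaceTempView("t")

// Without OUTER, id = 2 disappears because explode(xs) produces no rows for it;
// with OUTER it is kept with x = NULL.
spark.sql("SELECT id, x FROM t LATERAL VIEW OUTER explode(xs) v AS x").show()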
val proj = UnsafeProjection.create(output, output) proj.initialize(index) - iter.map { r => + rows.map { r => numOutputRows += 1 proj(r) } } } - override def supportCodegen: Boolean = generator.supportCodegen + override def supportCodegen: Boolean = false override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index e366b9af35c62..19c68c13262a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -33,7 +33,7 @@ case class LocalTableScanExec( override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - private val unsafeRows: Array[InternalRow] = { + private lazy val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index aa578f4d23133..3c046ce494285 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -98,14 +98,14 @@ case class OptimizeMetadataOnlyQuery( relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(filters = Nil) + val partitionData = fsRelation.location.listFiles(Nil, Nil) LocalRelation(partAttrs, partitionData.map(_.values)) case relation: CatalogRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) - val timeZoneId = caseInsensitiveProperties.get("timeZone") + val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(conf.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 6ec2f4d840862..8e8210e334a1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -46,9 +46,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def planner = sparkSession.sessionState.planner def assertAnalyzed(): Unit = { - try sparkSession.sessionState.analyzer.checkAnalysis(analyzed) catch { + // Analyzer is invoked outside the try block to avoid calling it again from within the + // catch block below. 
+ analyzed + try { + sparkSession.sessionState.analyzer.checkAnalysis(analyzed) + } catch { case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) ae.setStackTrace(e.getStackTrace) throw ae } @@ -122,8 +127,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .map(s => String.format(s"%-20s", s)) .mkString("\t") } - // SHOW TABLES in Hive only output table names, while ours outputs database, table name, isTemp. - case command: ExecutedCommandExec if command.cmd.isInstanceOf[ShowTablesCommand] => + // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index cc576bbc4c802..f98ae82574d20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -177,6 +177,8 @@ case class SortExec( """.stripMargin.trim } + protected override val shouldStopRequired = false + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { s""" |${row.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 981728331d361..1de4f508b89a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -30,9 +30,23 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = super.batches :+ + override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + + /** + * Optimization batches that are executed before the regular optimization batches (also before + * the finish analysis batch). + */ + def preOptimizationBatches: Seq[Batch] = Nil + + /** + * Optimization batches that are executed after the regular optimization batches, but before the + * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add + * custom optimizer batches to the Spark optimizer. 
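The "User Provided Optimizers" batch kept at the end of `batches` is fed from the public `ExperimentalMethods` hook, so custom rules can be registered without subclassing `SparkOptimizer`. A minimal pass-through sketch:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A no-op rule; a real rule would pattern-match on the plan and rewrite it.
object MyOptimization extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// Runs in the "User Provided Optimizers" batch, after the built-in and post-hoc batches.
spark.experimental.extraOptimizations = Seq(MyOptimization)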
+ */ + def postHocOptimizationBatches: Seq[Batch] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 678241656c011..4e718d609c921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -27,21 +27,28 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val extraStrategies: Seq[Strategy]) + val experimentalMethods: ExperimentalMethods) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - extraStrategies ++ ( + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) + /** + * Override to add extra planning strategies to the planner. These strategies are tried after + * the strategies defined in [[ExperimentalMethods]], and before the regular strategies. + */ + def extraPlanningStrategies: Seq[Strategy] = Nil + override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { plan.collect { case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 65df688689397..20dacf88504f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Locale + import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -103,7 +105,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}") } if (ctx.identifier != null) { - if (ctx.identifier.getText.toLowerCase != "noscan") { + if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") { throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx) } AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier)) @@ -134,7 +136,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = false) + isExtended = false, + partitionSpec = None) } /** @@ -146,14 +149,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.partitionSpec != null) { - operationNotAllowed("SHOW TABLE EXTENDED ... 
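Planner strategies have the analogous public hook: `experimentalMethods.extraStrategies` is consulted before `extraPlanningStrategies` and the built-in strategies. A placeholder sketch, assuming an active `spark` session:

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

// Returning Nil defers planning of every operator to the remaining strategies.
object MyStrategy extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

spark.experimental.extraStrategies = Seq(MyStrategy)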
PARTITION", ctx) - } - + val partitionSpec = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec) ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - isExtended = true) + isExtended = true, + partitionSpec = partitionSpec) } /** @@ -323,8 +324,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { DescribeTableCommand( visitTableIdentifier(ctx.tableIdentifier), partitionSpec, - ctx.EXTENDED != null, - ctx.FORMATTED != null) + ctx.EXTENDED != null || ctx.FORMATTED != null) } } @@ -386,7 +386,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + "you can only specify one of them.", ctx) } - val customLocation = storage.locationUri.orElse(location) + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_))) val tableType = if (customLocation.isDefined) { CatalogTableType.EXTERNAL @@ -565,7 +565,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } else if (value.STRING != null) { string(value.STRING) } else if (value.booleanValue != null) { - value.getText.toLowerCase + value.getText.toLowerCase(Locale.ROOT) } else { value.getText } @@ -649,7 +649,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { import ctx._ - val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase(Locale.ROOT)) match { case None | Some("all") => (true, true) case Some("system") => (false, true) case Some("user") => (true, false) @@ -679,7 +679,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) { val resources = ctx.resource.asScala.map { resource => - val resourceType = resource.identifier.getText.toLowerCase + val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT) resourceType match { case "jar" | "file" | "archive" => FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) @@ -742,6 +742,22 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx.VIEW != null) } + /** + * Create a [[AlterTableAddColumnsCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitColTypeList(ctx.columns) + ) + } + /** * Create an [[AlterTableSetPropertiesCommand]] command. 
* @@ -945,7 +961,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .flatMap(_.orderedIdentifier.asScala) .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase != "asc") { + if (dir.toLowerCase(Locale.ROOT) != "asc") { operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -998,13 +1014,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val mayebePaths = remainder(ctx.identifier).trim ctx.op.getType match { case SqlBaseParser.ADD => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "file" => AddFileCommand(mayebePaths) case "jar" => AddJarCommand(mayebePaths) case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) } case SqlBaseParser.LIST => - ctx.identifier.getText.toLowerCase match { + ctx.identifier.getText.toLowerCase(Locale.ROOT) match { case "files" | "file" => if (mayebePaths.length > 0) { ListFilesCommand(mayebePaths.split("\\s+")) @@ -1080,8 +1096,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) } + + val locUri = location.map(CatalogUtils.stringToURI(_)) val storage = CatalogStorageFormat( - locationUri = location, + locationUri = locUri, inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), @@ -1132,7 +1150,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. val newTableDesc = tableDesc.copy( - storage = CatalogStorageFormat.empty.copy(locationUri = location), + storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { @@ -1289,7 +1307,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { (rowFormatCtx, createFileFormatCtx.fileFormat) match { case (_, ffTable: TableFileFormatContext) => // OK case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case ("sequencefile" | "textfile" | "rcfile") => // OK case fmt => operationNotAllowed( @@ -1297,7 +1315,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { parentCtx) } case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => - ffGeneric.identifier.getText.toLowerCase match { + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { case "textfile" => // OK case fmt => operationNotAllowed( s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) @@ -1329,6 +1347,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 20bf4925dbec5..ca2f6dd7a84b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -326,16 +326,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Strategy to convert MapGroupsWithState logical operator to physical operator + * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ - object MapGroupsWithStateStrategy extends Strategy { + object FlatMapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, - planLater(child)) + case FlatMapGroupsWithState( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, + timeout, child) => + val execPlan = FlatMapGroupsWithStateExec( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, + timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil @@ -381,7 +382,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + case logical.FlatMapGroupsWithState( + f, key, value, grouping, data, output, _, _, _, _, child) => execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c58474eba05d4..c1e1a631c677e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution -import org.apache.spark.{broadcast, TaskContext} +import java.util.Locale + +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -43,7 +45,7 @@ trait CodegenSupport extends SparkPlan { case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" case _: DataSourceScanExec => "scan" - case _ => nodeName.toLowerCase + case _ => nodeName.toLowerCase(Locale.ROOT) } /** @@ -206,6 +208,21 @@ trait CodegenSupport extends SparkPlan { def doConsume(ctx: 
CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } + + /** + * For optimization to suppress shouldStop() in a loop of WholeStageCodegen. + * Returning true means we need to insert shouldStop() into the loop producing rows, if any. + */ + def isShouldStopRequired: Boolean = { + return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) + } + + /** + * Set to false if this plan consumes all rows produced by children but doesn't output row + * to buffer by calling append(), so the children don't require shouldStop() + * in the loop of producing rows. + */ + protected def shouldStopRequired: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e565..68c8e6ce62cbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -238,6 +238,8 @@ case class HashAggregateExec( """.stripMargin } + protected override val shouldStopRequired = false + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 87e90ed685cca..64698d5527578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -331,10 +331,11 @@ case class SampleExec( case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - def start: Long = range.start - def step: Long = range.step - def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) - def numElements: BigInt = range.numElements + val start: Long = range.start + val end: Long = range.end + val step: Long = range.step + val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output @@ -342,8 +343,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows")) - // output attributes should not affect the results - override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) + override lazy val canonicalized: SparkPlan = { + RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) + } override def inputRDDs(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) @@ -387,8 +389,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") - // The default size of a batch. 
- val batchSize = 1000L + // The default size of a batch, which must be positive integer + val batchSize = 1000 ctx.addNewFunction("initRange", s""" @@ -434,6 +436,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val input = ctx.freshName("input") // Right now, Range is only used when there is one upstream. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + val localIdx = ctx.freshName("localIdx") + val localEnd = ctx.freshName("localEnd") + val range = ctx.freshName("range") + val shouldStop = if (isShouldStopRequired) { + s"if (shouldStop()) { $number = $value + ${step}L; return; }" + } else { + "// shouldStop check is eliminated" + } s""" | // initialize Range | if (!$initTerm) { @@ -442,16 +453,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | while (true) { - | while ($number != $batchEnd) { - | long $value = $number; - | $number += ${step}L; - | ${consume(ctx, Seq(ev))} - | if (shouldStop()) return; + | long $range = $batchEnd - $number; + | if ($range != 0L) { + | int $localEnd = (int)($range / ${step}L); + | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { + | long $value = ((long)$localIdx * ${step}L) + $number; + | ${consume(ctx, Seq(ev))} + | $shouldStop + | } + | $number = $batchEnd; | } | - | if ($taskContext.isInterrupted()) { - | throw new TaskKilledException(); - | } + | $taskContext.killTaskIfInterrupted(); | | long $nextBatchTodo; | if ($numElementsTodo > ${batchSize}L) { @@ -526,7 +539,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = range.simpleString + override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" } /** @@ -594,11 +607,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def sameResult(o: SparkPlan): Boolean = o match { - case s: SubqueryExec => child.sameResult(s.child) - case _ => false - } - @transient private lazy val relationFuture: Future[Array[InternalRow]] = { // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here. @@ -615,13 +623,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum longMetric("dataSize") += dataSize - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. 
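A quick way to see both the new leaf-node description and the batched codegen path for Range; the numbers are arbitrary:

// The plan should show a leaf like "Range (0, 1000, step=2, splits=4)".
spark.range(0, 1000, 2, 4).explain()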
- if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) rows } }(SubqueryExec.executionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 37bd95e737786..0a9f3e799990f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -21,12 +21,13 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -69,7 +70,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override def computeStats(conf: CatalystConf): Statistics = { + override def computeStats(conf: SQLConf): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. @@ -85,12 +86,6 @@ case class InMemoryRelation( buildBuffers() } - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() - } - private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 9028caa446e8c..7063b08f7c644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType @@ -41,11 +41,28 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes + private def updateAttribute(expr: Expression): Expression = { + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. 
+ val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + expr.transform { + case attr: Attribute => attrMap.getOrElse(attr, attr) + } + } + // The cached version does not change the outputPartitioning of the original SparkPlan. - override def outputPartitioning: Partitioning = relation.child.outputPartitioning + // But the cached version could alias output, so we need to replace output. + override def outputPartitioning: Partitioning = { + relation.child.outputPartitioning match { + case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] + case _ => relation.child.outputPartitioning + } + } // The cached version does not change the outputOrdering of the original SparkPlan. - override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering + // But the cached version could alias output, so we need to replace output. + override def outputOrdering: Seq[SortOrder] = + relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index b89014ed8ef54..0d8db2ff5d5a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -73,10 +73,10 @@ case class AnalyzeColumnCommand( val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - val attributesToAnalyze = AttributeSet(columnNames.map { col => + val attributesToAnalyze = columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) - }).toSeq + } // Make sure the column types are supported for stats gathering. 
attributesToAnalyze.foreach { attr => @@ -99,8 +99,8 @@ case class AnalyzeColumnCommand( val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) }.toMap (rowCount, columnStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7afa4e78a3786..5f12830ee621f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -60,6 +60,23 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Replaced.MAPREDUCE_JOB_REDUCES, Some(value))) => + val runFunc = (sparkSession: SparkSession) => { + logWarning( + s"Property ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} is Hadoop's property, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} for automatically " + + "determining the number of reducers is not supported." + throw new IllegalArgumentException(msg) + } else { + sparkSession.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } + } + (keyValueOutput, runFunc) + case Some((key @ SetCommand.VariableName(name), Some(value))) => val runFunc = (sparkSession: SparkSession) => { sparkSession.conf.set(name, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 5de45b159684c..41d91d877d4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ -import org.apache.spark.sql.execution.streaming.IncrementalExecution +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ @@ -106,7 +106,8 @@ case class ExplainCommand( if (logicalPlan.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
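A usage sketch for the mapreduce.job.reduces handling added to SetCommand above, assuming a hypothetical SparkSession named spark.

// The Hadoop property is accepted with a warning and rewritten to Spark's
// shuffle-partition setting; values below 1 are rejected.
spark.sql("SET mapreduce.job.reduces=10")
assert(spark.conf.get("spark.sql.shuffle.partitions") == "10")
// spark.sql("SET mapreduce.job.reduces=-1")  // throws IllegalArgumentException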
- new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0, 0) + new IncrementalExecution( + sparkSession, logicalPlan, OutputMode.Append(), "", 0, OffsetSeqMetadata(0, 0)) } else { sparkSession.sessionState.executePlan(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index d835b521166a8..2d890118ae0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.command +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -54,7 +58,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo // Create the relation to validate the arguments before writing the metadata to the metastore, // and infer the table schema and partition if users didn't specify schema in CREATE TABLE. - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // Fill in some default table options from the session conf val tableWithDefaultOptions = table.copy( identifier = table.identifier.copy( @@ -69,7 +73,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo className = table.provider.get, bucketSpec = table.bucketSpec, options = table.storage.properties ++ pathOption, - catalogTable = Some(tableWithDefaultOptions)).resolveRelation() + // As discussed in SPARK-19583, we don't check if the location is existed + catalogTable = Some(tableWithDefaultOptions)).resolveRelation(checkFilesExist = false) val partitionColumnNames = if (table.schema.nonEmpty) { table.partitionColumnNames @@ -175,12 +180,12 @@ case class CreateDataSourceTableAsSelectCommand( private def saveDataIntoTable( session: SparkSession, table: CatalogTable, - tableLocation: Option[String], + tableLocation: Option[URI], data: LogicalPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { // Create the relation based on the input logical plan: `data`. - val pathOption = tableLocation.map("path" -> _) + val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, className = table.provider.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index e5a6a5f60b8a6..470c736da98b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.StringType /** * A command for users to list the databases/schemas. - * If a databasePattern is supplied then the databases that only matches the + * If a databasePattern is supplied then the databases that only match the * pattern would be listed. 
* The syntax of using this command in SQL is: * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 82cbb4aa47445..55540563ef911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool @@ -66,7 +68,7 @@ case class CreateDatabaseCommand( CatalogDatabase( databaseName, comment.getOrElse(""), - path.getOrElse(catalog.getDefaultDBPath(databaseName)), + path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)), props), ifNotExists) Seq.empty[Row] @@ -146,7 +148,7 @@ case class DescribeDatabaseCommand( val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: - Row("Location", dbMetadata.locationUri) :: Nil + Row("Location", CatalogUtils.URIToString(dbMetadata.locationUri)) :: Nil if (extended) { val properties = @@ -199,8 +201,7 @@ case class DropTableCommand( } } try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } catch { case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) @@ -426,7 +427,8 @@ case class AlterTableAddPartitionCommand( table.identifier.quotedString, sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) - CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location)) + CatalogTablePartition(normalizedSpec, table.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_)))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] @@ -710,7 +712,7 @@ case class AlterTableRecoverPartitionsCommand( // inherit table storage format (possibly except for location) CatalogTablePartition( spec, - table.storage.copy(locationUri = Some(location.toUri.toString)), + table.storage.copy(locationUri = Some(location.toUri)), params) } spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true) @@ -741,6 +743,7 @@ case class AlterTableSetLocationCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + val locUri = CatalogUtils.stringToURI(location) DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => @@ -748,11 +751,11 @@ case class AlterTableSetLocationCommand( sparkSession, table, "ALTER TABLE ... 
SET LOCATION") // Partition spec is specified, so we set the location only for this partition val part = catalog.getPartition(table.identifier, spec) - val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location))) + val newPart = part.copy(storage = part.storage.copy(locationUri = Some(locUri))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself - catalog.alterTable(table.withNewStorage(locationUri = Some(location))) + catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } Seq.empty[Row] } @@ -763,11 +766,11 @@ object DDLUtils { val HIVE_PROVIDER = "hive" def isHiveTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase == HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER } def isDatasourceTable(table: CatalogTable): Boolean = { - table.provider.isDefined && table.provider.get.toLowerCase != HIVE_PROVIDER + table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index ea5398761c46d..545082324f0d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException} @@ -49,6 +51,7 @@ case class CreateFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val func = CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources) if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + @@ -57,17 +60,13 @@ case class CreateFunctionCommand( // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. catalog.loadFunctionResources(resources) - val info = new ExpressionInfo(className, functionName) - val builder = catalog.makeFunctionBuilder(functionName, className) - catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + catalog.registerFunction(func, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. // TODO: should we also parse "IF NOT EXISTS"? - catalog.createFunction( - CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), - ignoreIfExists = false) + catalog.createFunction(func, ignoreIfExists = false) } Seq.empty[Row] } @@ -100,7 +99,7 @@ case class DescribeFunctionCommand( override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. 
- functionName.funcName.toLowerCase match { + functionName.funcName.toLowerCase(Locale.ROOT) match { case "<>" => Row(s"Function: $functionName") :: Row("Usage: expr1 <> expr2 - " + @@ -208,8 +207,6 @@ case class ShowFunctionsCommand( case (f, "USER") if showUserFunctions => f.unquotedString case (f, "SYSTEM") if showSystemFunctions => f.unquotedString } - // The session catalog caches some persistent functions in the FunctionRegistry - // so there can be duplicates. - functionNames.distinct.sorted.map(Row(_)) + functionNames.sorted.map(Row(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index 20b08946675d0..2e859cf1ef253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -37,7 +37,7 @@ case class AddJarCommand(path: String) extends RunnableCommand { } override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.addJar(path) + sparkSession.sessionState.resourceLoader.addJar(path) Seq(Row(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 3e80916104bd9..ebf03e1bf8869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -79,7 +82,8 @@ case class CreateTableLikeCommand( CatalogTable( identifier = targetTable, tableType = tblType, - storage = sourceTableDesc.storage.copy(locationUri = location), + storage = sourceTableDesc.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, partitionColumnNames = sourceTableDesc.partitionColumnNames, @@ -173,6 +177,77 @@ case class AlterTableRenameCommand( } +/** + * A command that add columns to a table + * The syntax of using this command in SQL is: + * {{{ + * ALTER TABLE table_identifier + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} +*/ +case class AlterTableAddColumnsCommand( + table: TableIdentifier, + columns: Seq[StructField]) extends RunnableCommand { + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val catalogTable = verifyAlterTableAddColumn(catalog, table) + + try { + sparkSession.catalog.uncacheTable(table.quotedString) + } catch { + case NonFatal(e) => + log.warn(s"Exception when attempting to uncache table ${table.quotedString}", e) + } + catalog.refreshTable(table) + + // make sure any partition columns are at 
the end of the fields + val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema + catalog.alterTableSchema( + table, catalogTable.schema.copy(fields = reorderedSchema.toArray)) + + Seq.empty[Row] + } + + /** + * ALTER TABLE ADD COLUMNS command does not support temporary view/table, + * view, or datasource table with text, orc formats or external provider. + * For datasource table, it currently only supports parquet, json, csv. + */ + private def verifyAlterTableAddColumn( + catalog: SessionCatalog, + table: TableIdentifier): CatalogTable = { + val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table) + + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support views. + |You must drop and re-create the views for adding the new columns. Views: $table + """.stripMargin) + } + + if (DDLUtils.isDatasourceTable(catalogTable)) { + DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { + // For datasource table, this command can only support the following File format. + // TextFileFormat only default to one column "value" + // OrcFileFormat can not handle difference between user-specified schema and + // inferred schema yet. TODO, once this issue is resolved , we can add Orc back. + // Hive type is already considered as hive serde table, so the logic will not + // come in here. + case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s => + throw new AnalysisException( + s""" + |ALTER ADD COLUMNS does not support datasource table with type $s. + |You must drop and re-create the table for adding the new columns. Tables: $table + """.stripMargin) + } + } + catalogTable + } +} + + /** * A command that loads data into a Hive table. 
* @@ -425,8 +500,7 @@ case class TruncateTableCommand( case class DescribeTableCommand( table: TableIdentifier, partitionSpec: TablePartitionSpec, - isExtended: Boolean, - isFormatted: Boolean) + isExtended: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -461,14 +535,12 @@ case class DescribeTableCommand( describePartitionInfo(metadata, result) - if (partitionSpec.isEmpty) { - if (isExtended) { - describeExtendedTableInfo(metadata, result) - } else if (isFormatted) { - describeFormattedTableInfo(metadata, result) - } - } else { + if (partitionSpec.nonEmpty) { + // Outputs the partition-specific info for the DDL command: + // "DESCRIBE [EXTENDED|FORMATTED] table_name PARTITION (partitionVal*)" describeDetailedPartitionInfo(sparkSession, catalog, metadata, result) + } else if (isExtended) { + describeFormattedTableInfo(metadata, result) } } @@ -478,74 +550,20 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) describeSchema(table.partitionSchema, buffer) } } - private def describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Detailed Table Information", table.toString, "") - } - private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + // The following information has been already shown in the previous outputs + val excludedTableInfo = Seq( + "Partition Columns", + "Schema" + ) append(buffer, "", "", "") append(buffer, "# Detailed Table Information", "", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Owner:", table.owner, "") - append(buffer, "Create Time:", new Date(table.createTime).toString, "") - append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") - append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "") - append(buffer, "Table Type:", table.tableType.name, "") - table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) - - append(buffer, "Table Parameters:", "", "") - table.properties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - - describeStorageInfo(table, buffer) - - if (table.tableType == CatalogTableType.VIEW) describeViewInfo(table, buffer) - - if (DDLUtils.isDatasourceTable(table) && table.tracksPartitionsInCatalog) { - append(buffer, "Partition Provider:", "Catalog", "") - } - } - - private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "# Storage Information", "", "") - metadata.storage.serde.foreach(serdeLib => append(buffer, "SerDe Library:", serdeLib, "")) - metadata.storage.inputFormat.foreach(format => append(buffer, "InputFormat:", format, "")) - metadata.storage.outputFormat.foreach(format => append(buffer, "OutputFormat:", format, "")) - append(buffer, "Compressed:", if (metadata.storage.compressed) "Yes" else "No", "") - describeBucketingInfo(metadata, buffer) - - append(buffer, "Storage Desc Parameters:", "", "") - val maskedProperties = CatalogUtils.maskCredentials(metadata.storage.properties) - maskedProperties.foreach { case (key, value) => - append(buffer, s" $key", value, "") - } - } - - private def describeViewInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - 
append(buffer, "", "", "") - append(buffer, "# View Information", "", "") - append(buffer, "View Text:", metadata.viewText.getOrElse(""), "") - append(buffer, "View Default Database:", metadata.viewDefaultDatabase.getOrElse(""), "") - append(buffer, "View Query Output Columns:", - metadata.viewQueryColumnNames.mkString("[", ", ", "]"), "") - } - - private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { - metadata.bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - append(buffer, "Num Buckets:", numBuckets.toString, "") - append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "") - append(buffer, "Sort Columns:", sortColumnNames.mkString("[", ", ", "]"), "") - - case _ => + table.toLinkedHashMap.filterKeys(!excludedTableInfo.contains(_)).foreach { + s => append(buffer, s._1, s._2, "") } } @@ -560,21 +578,7 @@ case class DescribeTableCommand( } DDLUtils.verifyPartitionProviderIsHive(spark, metadata, "DESC PARTITION") val partition = catalog.getPartition(table, partitionSpec) - if (isExtended) { - describeExtendedDetailedPartitionInfo(table, metadata, partition, result) - } else if (isFormatted) { - describeFormattedDetailedPartitionInfo(table, metadata, partition, result) - describeStorageInfo(metadata, result) - } - } - - private def describeExtendedDetailedPartitionInfo( - tableIdentifier: TableIdentifier, - table: CatalogTable, - partition: CatalogTablePartition, - buffer: ArrayBuffer[Row]): Unit = { - append(buffer, "", "", "") - append(buffer, "Detailed Partition Information " + partition.toString, "", "") + if (isExtended) describeFormattedDetailedPartitionInfo(table, metadata, partition, result) } private def describeFormattedDetailedPartitionInfo( @@ -584,17 +588,21 @@ case class DescribeTableCommand( buffer: ArrayBuffer[Row]): Unit = { append(buffer, "", "", "") append(buffer, "# Detailed Partition Information", "", "") - append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") - append(buffer, "Database:", table.database, "") - append(buffer, "Table:", tableIdentifier.table, "") - append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") - append(buffer, "Partition Parameters:", "", "") - partition.parameters.foreach { case (key, value) => - append(buffer, s" $key", value, "") + append(buffer, "Database", table.database, "") + append(buffer, "Table", tableIdentifier.table, "") + partition.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + append(buffer, "", "", "") + append(buffer, "# Storage Information", "", "") + table.bucketSpec match { + case Some(spec) => + spec.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) + case _ => } + table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } @@ -613,13 +621,15 @@ case class DescribeTableCommand( * The syntax of using this command in SQL is: * {{{ * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; - * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards'; + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards' + * [PARTITION(partition_spec)]; * }}} */ case class ShowTablesCommand( databaseName: 
Option[String], tableIdentifierPattern: Option[String], - isExtended: Boolean = false) extends RunnableCommand { + isExtended: Boolean = false, + partitionSpec: Option[TablePartitionSpec] = None) extends RunnableCommand { // The result of SHOW TABLES/SHOW TABLE has three basic columns: database, tableName and // isTemporary. If `isExtended` is true, append column `information` to the output columns. @@ -639,18 +649,34 @@ case class ShowTablesCommand( // instead of calling tables in sparkSession. val catalog = sparkSession.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val tables = - tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) - tables.map { tableIdent => - val database = tableIdent.database.getOrElse("") - val tableName = tableIdent.table - val isTemp = catalog.isTemporaryTable(tableIdent) - if (isExtended) { - val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).toString - Row(database, tableName, isTemp, s"${information}\n") - } else { - Row(database, tableName, isTemp) + if (partitionSpec.isEmpty) { + // Show the information of tables. + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { tableIdent => + val database = tableIdent.database.getOrElse("") + val tableName = tableIdent.table + val isTemp = catalog.isTemporaryTable(tableIdent) + if (isExtended) { + val information = catalog.getTempViewOrPermanentTableMetadata(tableIdent).simpleString + Row(database, tableName, isTemp, s"$information\n") + } else { + Row(database, tableName, isTemp) + } } + } else { + // Show the information of partitions. + // + // Note: tableIdentifierPattern should be non-empty, otherwise a [[ParseException]] + // should have been thrown by the sql parser. + val tableIdent = TableIdentifier(tableIdentifierPattern.get, Some(db)) + val table = catalog.getTableMetadata(tableIdent).identifier + val partition = catalog.getPartition(tableIdent, partitionSpec.get) + val database = table.database.getOrElse("") + val tableName = table.table + val isTemp = catalog.isTemporaryTable(table) + val information = partition.simpleString + Seq(Row(database, tableName, isTemp, s"$information\n")) } } } @@ -953,7 +979,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman // when the table creation DDL contains the PATH option. 
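A usage sketch for the partition-aware SHOW TABLE EXTENDED syntax handled above, assuming a hypothetical partitioned table t with partition column ds and a SparkSession named spark.

// Table-level listing, unchanged:
spark.sql("SHOW TABLES").show()
// New: partition-level information for one partition of a named table; the details
// land in the `information` column that is added when isExtended is true.
spark.sql("SHOW TABLE EXTENDED LIKE 't' PARTITION (ds='2017-03-01')")
  .select("information")
  .show(truncate = false)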
None } else { - Some(s"path '${escapeSingleQuotedString(location)}'") + Some(s"path '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 921c84895598c..00f0acab21aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder @@ -154,6 +154,10 @@ case class CreateViewCommand( } else if (tableMetadata.tableType != CatalogTableType.VIEW) { throw new AnalysisException(s"$name is not a view") } else if (replace) { + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val viewIdent = tableMetadata.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) } else { @@ -283,6 +287,10 @@ case class AlterViewAsCommand( throw new AnalysisException(s"${viewMeta.identifier} is not a view.") } + // Detect cyclic view reference on ALTER VIEW. + val viewIdent = viewMeta.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + val newProperties = generateViewProperties(viewMeta.properties, session, analyzedPlan) val updatedViewMeta = viewMeta.copy( @@ -358,4 +366,53 @@ object ViewHelper { generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) } + + /** + * Recursively search the logical plan to detect cyclic view references, throw an + * AnalysisException if cycle detected. + * + * A cyclic view reference is a cycle of reference dependencies, for example, if the following + * statements are executed: + * CREATE VIEW testView AS SELECT id FROM tbl + * CREATE VIEW testView2 AS SELECT id FROM testView + * ALTER VIEW testView AS SELECT * FROM testView2 + * The view `testView` references `testView2`, and `testView2` also references `testView`, + * therefore a reference cycle (testView -> testView2 -> testView) exists. + * + * @param plan the logical plan we detect cyclic view references from. + * @param path the path between the altered view and current node. + * @param viewIdent the table identifier of the altered view, we compare two views by the + * `desc.identifier`. + */ + def checkCyclicViewReference( + plan: LogicalPlan, + path: Seq[TableIdentifier], + viewIdent: TableIdentifier): Unit = { + plan match { + case v: View => + val ident = v.desc.identifier + val newPath = path :+ ident + // If the table identifier equals to the `viewIdent`, current view node is the same with + // the altered view. We detect a view reference cycle, should throw an AnalysisException. 
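The reference cycle described in the checkCyclicViewReference comment above, written out as a usage sketch (hypothetical base table tbl and SparkSession spark).

spark.sql("CREATE VIEW testView AS SELECT id FROM tbl")
spark.sql("CREATE VIEW testView2 AS SELECT id FROM testView")
// testView -> testView2 -> testView would be cyclic, so ALTER VIEW now fails with an
// AnalysisException along the lines of
//   "Recursive view ... detected (cycle: testView -> testView2 -> testView)"
spark.sql("ALTER VIEW testView AS SELECT * FROM testView2")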
+ if (ident == viewIdent) { + throw new AnalysisException(s"Recursive view $viewIdent detected " + + s"(cycle: ${newPath.mkString(" -> ")})") + } else { + v.children.foreach { child => + checkCyclicViewReference(child, newPath, viewIdent) + } + } + case _ => + plan.children.foreach(child => checkCyclicViewReference(child, path, viewIdent)) + } + + // Detect cyclic references from subqueries. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 2068811661fec..4046396d0e614 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -46,14 +48,15 @@ class CatalogFileIndex( assert(table.identifier.database.isDefined, "The table identifier must be qualified in CatalogFileIndex") - private val baseLocation: Option[String] = table.storage.locationUri + private val baseLocation: Option[URI] = table.storage.locationUri override def partitionSchema: StructType = table.partitionSchema override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { - filterPartitions(filters).listFiles(Nil) + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { + filterPartitions(partitionFilters).listFiles(Nil, dataFilters) } override def refresh(): Unit = fileStatusCache.invalidateAll() @@ -66,6 +69,7 @@ class CatalogFileIndex( */ def filterPartitions(filters: Seq[Expression]): InMemoryFileIndex = { if (table.partitionColumnNames.nonEmpty) { + val startTime = System.nanoTime() val selectedPartitions = sparkSession.sessionState.catalog.listPartitionsByFilter( table.identifier, filters) val partitions = selectedPartitions.map { p => @@ -76,8 +80,9 @@ class CatalogFileIndex( path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) + val timeNs = System.nanoTime() - startTime new PrunedInMemoryFileIndex( - sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) + sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( sparkSession, rootPaths, table.storage.properties, partitionSchema = None) @@ -108,7 +113,8 @@ private class PrunedInMemoryFileIndex( sparkSession: SparkSession, tableBasePath: Path, fileStatusCache: FileStatusCache, - override val partitionSpec: PartitionSpec) + override val partitionSpec: PartitionSpec, + override val metadataOpsTimeNs: Option[Long]) extends InMemoryFileIndex( sparkSession, partitionSpec.partitions.map(_.path), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 4947dfda6fc7e..f3b209deaae5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -17,19 +17,18 @@ package org.apache.spark.sql.execution.datasources -import java.util.{ServiceConfigurationError, ServiceLoader} +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -539,15 +538,16 @@ object DataSource { // Found the data source using fully qualified path dataSource case Failure(error) => - if (provider1.toLowerCase == "orc" || + if (provider1.toLowerCase(Locale.ROOT) == "orc" || provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw new AnalysisException( "The ORC data source must be used with Hive support enabled") - } else if (provider1.toLowerCase == "avro" || + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || provider1 == "com.databricks.spark.avro") { throw new AnalysisException( - s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " + - "package at http://spark.apache.org/third-party-projects.html") + s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " + + "Please find an Avro package at " + + "http://spark.apache.org/third-party-projects.html") } else { throw new ClassNotFoundException( s"Failed to find data source: $provider1. 
Please find packages at " + @@ -596,7 +596,8 @@ object DataSource { */ def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") - val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") - CatalogStorageFormat.empty.copy(locationUri = path, properties = optionsWithoutPath) + val optionsWithoutPath = options.filterKeys(_.toLowerCase(Locale.ROOT) != "path") + CatalogStorageFormat.empty.copy( + locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index bc5734f36891c..d307122b5c70d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPa import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. 
*/ -case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -215,12 +215,10 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val table = r.tableMeta val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) val cache = sparkSession.sessionState.catalog.tableRelationCache - val inMemory = - sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "in-memory" val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( sparkSession, @@ -231,19 +229,19 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] bucketSpec = table.bucketSpec, className = table.provider.get, options = table.storage.properties ++ pathOption, - // TODO: improve `InMemoryCatalog` and remove this limitation. - catalogTable = if (inMemory) None else Some(table)) + catalogTable = Some(table)) - LogicalRelation( - dataSource.resolveRelation(checkFilesExist = false), - catalogTable = Some(table)) + LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) } }).asInstanceOf[LogicalRelation] - // It's possible that the table schema is empty and need to be inferred at runtime. We should - // not specify expected outputs for this case. - val expectedOutputs = if (r.output.isEmpty) None else Some(r.output) - plan.copy(expectedOutputAttributes = expectedOutputs) + if (r.output.isEmpty) { + // It's possible that the table schema is empty and need to be inferred at runtime. For this + // case, we don't need to change the output of the cached plan. + plan + } else { + plan.copy(output = r.output) + } } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -260,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. */ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -300,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. 
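A conceptual sketch of the bucket id computed by getBucketId below: HashPartitioning.partitionIdExpression is, in effect, a positive modulus of the hash of the filter value, so an equality filter on the bucketing column maps to exactly one bucket and only that bucket's files need to be scanned. The helper below is hypothetical and simplified, for illustration only.

// Hypothetical, simplified view of pmod(hash(value), numBuckets); the real code builds
// the same hash expression the writer used, via HashPartitioning.partitionIdExpression.
def conceptualBucketId(valueHash: Int, numBuckets: Int): Int = {
  val m = valueHash % numBuckets
  if (m < 0) m + numBuckets else m   // pmod: always a non-negative bucket index
}

// e.g. a pushed-down filter `id = 42` on a table bucketed by `id` into 8 buckets
// selects the single bucket conceptualBucketId(hash(42), 8).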
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -438,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -529,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala new file mode 100644 index 0000000000000..159aef220be15 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class FailureSafeParser[IN]( + rawParser: IN => Seq[InternalRow], + mode: ParseMode, + schema: StructType, + columnNameOfCorruptRecord: String) { + + private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord) + private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord)) + private val resultRow = new GenericInternalRow(schema.length) + private val nullResult = new GenericInternalRow(schema.length) + + // This function takes 2 parameters: an optional partial result, and the bad record. 
If the given + // schema doesn't contain a field for corrupted record, we just return the partial result or a + // row with all fields null. If the given schema contains a field for corrupted record, we will + // set the bad record to this field, and set other fields according to the partial result or null. + private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = { + if (corruptFieldIndex.isDefined) { + (row, badRecord) => { + var i = 0 + while (i < actualSchema.length) { + val from = actualSchema(i) + resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull + i += 1 + } + resultRow(corruptFieldIndex.get) = badRecord() + resultRow + } + } else { + (row, _) => row.getOrElse(nullResult) + } + } + + def parse(input: IN): Iterator[InternalRow] = { + try { + rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null)) + } catch { + case e: BadRecordException => mode match { + case PermissiveMode => + Iterator(toResultRow(e.partialResult(), e.record)) + case DropMalformedMode => + Iterator.empty + case FailFastMode => + throw e.cause + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 984f0b9f8cd9f..d2adba2da9478 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -103,7 +103,7 @@ trait FileFormat { * @param options A set of string -> string configuration options. * @return */ - def buildReader( + protected def buildReader( sparkSession: SparkSession, dataSchema: StructType, partitionSchema: StructType, @@ -111,8 +111,6 @@ trait FileFormat { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - // TODO: Remove this default implementation when the other formats have been ported - // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. throw new UnsupportedOperationException(s"buildReader is not supported for $this") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 950e5ca0d6210..4ec09bff429c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -80,6 +80,9 @@ object FileFormatWriter extends Logging { """.stripMargin) } + /** The result of a successful write task. */ + private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String]) + /** * Basic work flow of this command is: * 1. 
Driver side setup, including output committer initialization and data source specific @@ -141,7 +144,7 @@ object FileFormatWriter extends Logging { customPartitionLocations = outputSpec.customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), - timeZoneId = caseInsensitiveOptions.get("timeZone") + timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) @@ -172,8 +175,9 @@ object FileFormatWriter extends Logging { global = false, child = queryExecution.executedPlan).execute() } - - val ret = sparkSession.sparkContext.runJob(rdd, + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -182,10 +186,16 @@ object FileFormatWriter extends Logging { sparkAttemptNumber = taskContext.attemptNumber(), committer, iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res }) - val commitMsgs = ret.map(_._1) - val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment) + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") @@ -205,7 +215,7 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = { + iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -238,7 +248,7 @@ object FileFormatWriter extends Logging { // Execute the task to write rows out and commit the task. val outputPartitions = writeTask.execute(iterator) writeTask.releaseResources() - (committer.commitTask(taskAttemptContext), outputPartitions) + WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions) })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -314,8 +324,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } @@ -335,14 +348,11 @@ object FileFormatWriter extends Logging { /** Expressions that given partition columns build a path string like: col1=val/col2=val/... 
*/ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val escaped = ScalaUDF( - ExternalCatalogUtils.escapePathName _, + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, StringType, - Seq(Cast(c, StringType, Option(desc.timeZoneId))), - Seq(StringType)) - val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil - if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) } } @@ -452,8 +462,11 @@ object FileFormatWriter extends Logging { override def releaseResources(): Unit = { if (currentWriter != null) { - currentWriter.close() - currentWriter = null + try { + currentWriter.close() + } finally { + currentWriter = null + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala index 277223d52ec52..094a66a2820f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndex.scala @@ -46,12 +46,17 @@ trait FileIndex { * Returns all valid files grouped into partitions when the data is partitioned. If the data is * unpartitioned, this will return a single partition with no partition values. * - * @param filters The filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. + * @param partitionFilters The filters used to prune which partitions are returned. These filters + * must only refer to partition columns and this method will only return + * files where these predicates are guaranteed to evaluate to `true`. + * Thus, these filters will not need to be evaluated again on the + * returned data. + * @param dataFilters Filters that can be applied on non-partitioned columns. The implementation + * does not need to guarantee these filters are applied, i.e. the execution + * engine will ensure these filters are still applied on the returned files. */ - def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] + def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] /** * Returns the list of files that will be read when scanning this relation. This call may be @@ -67,4 +72,14 @@ trait FileIndex { /** Schema of the partitioning columns, or the empty schema if the table is not partitioned. */ def partitionSchema: StructType + + /** + * Returns an optional metadata operation time, in nanoseconds, for listing files. + * + * We do file listing in query optimization (in order to get the proper statistics) and we want + * to account for file listing time in physical execution (as metrics). To do that, we save the + * file listing time in some implementations and physical execution calls it in this method + * to update the metrics. 
+ */ + def metadataOpsTimeNs: Option[Long] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 14f721d6a790a..9df20731c71d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import scala.collection.mutable @@ -44,7 +44,7 @@ case class PartitionedFile( filePath: String, start: Long, length: Long, - locations: Array[String] = Array.empty) { + @transient locations: Array[String] = Array.empty) { override def toString: String = { s"path: $filePath, range: $start-${start + length}, partition values: $partitionValues" } @@ -101,9 +101,7 @@ class FileScanRDD( // Kill the task in case it has been marked as killed. This logic is from // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order // to avoid performance overhead. - if (context.isInterrupted()) { - throw new TaskKilledException - } + context.killTaskIfInterrupted() (currentIterator != null && currentIterator.hasNext) || nextIterator() } def next(): Object = { @@ -121,6 +119,20 @@ class FileScanRDD( nextElement } + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { updateBytesReadWithFileSize() @@ -130,54 +142,36 @@ class FileScanRDD( // Sets InputFileBlockHolder for the file block's information InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) - try { - if (ignoreCorruptFiles) { - currentIterator = new NextIterator[Object] { - private val internalIter = { - try { - // The readFunction may read files before consuming the iterator. - // E.g., vectorized Parquet reader. - readFunction(currentFile) - } catch { - case e @(_: RuntimeException | _: IOException) => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - Iterator.empty + if (ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + // The readFunction may read some bytes before consuming the iterator, e.g., + // vectorized Parquet reader. Here we use lazy val to delay the creation of + // iterator so that we will throw exception in `getNext`. 
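The ignoreCorruptFiles branch here defers reader creation into a lazy val, so an exception thrown while opening the file (e.g. a vectorized reader that eagerly reads a footer) surfaces inside getNext() and is handled by the same catch as a mid-read failure. A self-contained sketch of that pattern; SimpleNextIterator is a cut-down stand-in for org.apache.spark.util.NextIterator and corruptTolerant is an invented helper.

    import java.io.IOException

    object CorruptTolerantRead {
      // Minimal stand-in for org.apache.spark.util.NextIterator.
      abstract class SimpleNextIterator[U] extends Iterator[U] {
        protected var finished = false
        private var gotNext = false
        private var nextValue: U = _

        protected def getNext(): U

        override def hasNext: Boolean = {
          if (!finished && !gotNext) {
            nextValue = getNext()
            gotNext = !finished
          }
          !finished
        }

        override def next(): U = {
          if (!hasNext) throw new NoSuchElementException("end of stream")
          gotNext = false
          nextValue
        }
      }

      /** Wraps a file-opening function so that both open-time and read-time
       *  failures are caught in getNext() and simply end the iterator. */
      def corruptTolerant(openFile: () => Iterator[String]): Iterator[String] =
        new SimpleNextIterator[String] {
          // lazy: if openFile() itself throws, the exception surfaces inside
          // getNext() and hits the catch block below.
          private lazy val underlying = openFile()

          override def getNext(): String =
            try {
              if (underlying.hasNext) underlying.next()
              else { finished = true; null }
            } catch {
              case e: IOException =>
                println(s"skipping the rest of a corrupted input: ${e.getMessage}")
                finished = true
                null
            }
        }

      def main(args: Array[String]): Unit = {
        val ok = corruptTolerant(() => Iterator("a", "b"))
        println(ok.toList)                                     // List(a, b)
        val broken = corruptTolerant(() => throw new IOException("bad magic number"))
        println(broken.toList)                                 // List() - the open failure is swallowed
      }
    }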
+ private lazy val internalIter = readCurrentFile() + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null } + } catch { + // Throw FileNotFoundException even `ignoreCorruptFiles` is true + case e: FileNotFoundException => throw e + case e @ (_: RuntimeException | _: IOException) => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + finished = true + null } - - override def getNext(): AnyRef = { - try { - if (internalIter.hasNext) { - internalIter.next() - } else { - finished = true - null - } - } catch { - case e: IOException => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - finished = true - null - } - } - - override def close(): Unit = {} } - } else { - currentIterator = readFunction(currentFile) + + override def close(): Unit = {} } - } catch { - case e: IOException if ignoreCorruptFiles => - logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) - currentIterator = Iterator.empty - case e: java.io.FileNotFoundException => - throw new java.io.FileNotFoundException( - e.getMessage + "\n" + - "It is possible the underlying files have been updated. " + - "You can explicitly invalidate the cache in Spark by " + - "running 'REFRESH TABLE tableName' command in SQL or " + - "by recreating the Dataset/DataFrame involved." - ) + } else { + currentIterator = readCurrentFile() } hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 26e1380eca499..17f7e0e601c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -100,9 +100,6 @@ object FileSourceStrategy extends Strategy with Logging { val outputSchema = readDataColumns.toStructType logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}") - val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) - logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - val outputAttributes = readDataColumns ++ partitionColumns val scan = @@ -111,7 +108,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, - pushedDownFilters, + dataFilters, table.map(_.identifier)) val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index 5d97558633146..aea27bd4c4d7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -94,27 +94,48 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { // Opaque object that uniquely identifies a shared cache user private type ClientId = Object + private val warnedAboutEviction = new AtomicBoolean(false) // we use a composite cache key in order to distinguish entries inserted by different clients - private val cache: Cache[(ClientId, Path), Array[FileStatus]] = CacheBuilder.newBuilder() - .weigher(new Weigher[(ClientId, Path), Array[FileStatus]] { + private val cache: 
Cache[(ClientId, Path), Array[FileStatus]] = { + // [[Weigher]].weigh returns Int so we could only cache objects < 2GB + // instead, the weight is divided by this factor (which is smaller + // than the size of one [[FileStatus]]). + // so it will support objects up to 64GB in size. + val weightScale = 32 + val weigher = new Weigher[(ClientId, Path), Array[FileStatus]] { override def weigh(key: (ClientId, Path), value: Array[FileStatus]): Int = { - (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt - }}) - .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { - override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) - : Unit = { + val estimate = (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)) / weightScale + if (estimate > Int.MaxValue) { + logWarning(s"Cached table partition metadata size is too big. Approximating to " + + s"${Int.MaxValue.toLong * weightScale}.") + Int.MaxValue + } else { + estimate.toInt + } + } + } + val removalListener = new RemovalListener[(ClientId, Path), Array[FileStatus]]() { + override def onRemoval( + removed: RemovalNotification[(ClientId, Path), + Array[FileStatus]]): Unit = { if (removed.getCause == RemovalCause.SIZE && - warnedAboutEviction.compareAndSet(false, true)) { + warnedAboutEviction.compareAndSet(false, true)) { logWarning( "Evicting cached table partition metadata from memory due to size constraints " + - "(spark.sql.hive.filesourcePartitionFileCacheSize = " + maxSizeInBytes + " bytes). " + - "This may impact query planning performance.") + "(spark.sql.hive.filesourcePartitionFileCacheSize = " + + maxSizeInBytes + " bytes). This may impact query planning performance.") } - }}) - .maximumWeight(maxSizeInBytes) - .build[(ClientId, Path), Array[FileStatus]]() + } + } + CacheBuilder.newBuilder() + .weigher(weigher) + .removalListener(removalListener) + .maximumWeight(maxSizeInBytes / weightScale) + .build[(ClientId, Path), Array[FileStatus]]() + } + /** * @return a FileStatusCache that does not share any entries with any other client, but does diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index ee4d0863d9771..91e31650617ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -17,32 +17,48 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import scala.collection.mutable +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.streaming.FileStreamSink import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. 
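Guava's Weigher returns an Int, so a cache weighed in raw bytes caps out near 2GB; the SharedInMemoryCache change above divides every estimate by a fixed weightScale and shrinks maximumWeight by the same factor, which is how it stretches to roughly 64GB of tracked metadata. A standalone version of that trick, assuming only Guava on the classpath; the byte-length weigher below stands in for SizeEstimator.

    import com.google.common.cache.{Cache, CacheBuilder, Weigher}

    object ScaledWeightCache {
      // Dividing by 32 lets the cache track up to 32 * Int.MaxValue (~64GB) of weight.
      private val weightScale = 32

      def build(maxSizeInBytes: Long): Cache[String, Array[Byte]] = {
        val weigher = new Weigher[String, Array[Byte]] {
          override def weigh(key: String, value: Array[Byte]): Int = {
            val estimate = (key.length.toLong + value.length.toLong) / weightScale
            if (estimate > Int.MaxValue) Int.MaxValue else math.max(estimate.toInt, 1)
          }
        }
        CacheBuilder.newBuilder()
          .weigher(weigher)
          .maximumWeight(maxSizeInBytes / weightScale)   // scaled to match the scaled weights
          .build[String, Array[Byte]]()
      }

      def main(args: Array[String]): Unit = {
        val cache = build(maxSizeInBytes = 1024 * 1024)
        cache.put("a", new Array[Byte](512 * 1024))
        cache.put("b", new Array[Byte](768 * 1024))      // may evict "a" to respect the weight cap
        println(Option(cache.getIfPresent("a")).map(_.length))
      }
    }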
* - * @param rootPaths the list of root table paths to scan + * @param rootPathsSpecified the list of root table paths to scan (some of which might be + * filtered out later) * @param parameters as set of options to control discovery * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, - override val rootPaths: Seq[Path], + rootPathsSpecified: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { + // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) + // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain + // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath" + // is the output of a streaming query. + override val rootPaths = + rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf)) + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ @volatile private var cachedPartitionSpec: PartitionSpec = _ @@ -84,4 +100,222 @@ class InMemoryFileIndex( } override def hashCode(): Int = rootPaths.toSet.hashCode() + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. + */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + val output = mutable.LinkedHashSet[FileStatus]() + val pathsToFetch = mutable.ArrayBuffer[Path]() + for (path <- paths) { + fileStatusCache.getLeafFiles(path) match { + case Some(files) => + HiveCatalogMetrics.incrementFileCacheHits(files.length) + output ++= files + case None => + pathsToFetch += path + } + } + val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) + val discovered = InMemoryFileIndex.bulkListLeafFiles( + pathsToFetch, hadoopConf, filter, sparkSession) + discovered.foreach { case (path, leafFiles) => + HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) + fileStatusCache.putLeafFiles(path, leafFiles.toArray) + output ++= leafFiles + } + output + } +} + +object InMemoryFileIndex extends Logging { + + /** A serializable variant of HDFS's BlockLocation. */ + private case class SerializableBlockLocation( + names: Array[String], + hosts: Array[String], + offset: Long, + length: Long) + + /** A serializable variant of HDFS's FileStatus. */ + private case class SerializableFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long, + blockLocations: Array[SerializableBlockLocation]) + + /** + * Lists a collection of paths recursively. Picks the listing strategy adaptively depending + * on the number of paths to list. + * + * This may only be called on the driver. 
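listLeafFiles above is a cache-aside loop: serve what the FileStatusCache already has, collect the misses, hand them to bulkListLeafFiles in a single call, and write the discovered entries back for next time. A Spark-free sketch of the same loop; the map-backed cache and the discover callback are stand-ins for FileStatusCache and the bulk listing.

    import scala.collection.mutable

    object CacheAsideListing {
      /** Returns files for all paths, consulting `cache` first and bulk-listing only the misses. */
      def listWithCache(
          paths: Seq[String],
          cache: mutable.Map[String, Seq[String]],
          discover: Seq[String] => Map[String, Seq[String]]): Seq[String] = {
        val output = mutable.LinkedHashSet[String]()
        val misses = mutable.ArrayBuffer[String]()

        for (path <- paths) {
          cache.get(path) match {
            case Some(files) => output ++= files   // cache hit
            case None        => misses += path     // remember for bulk discovery
          }
        }

        // One bulk call for every miss, then populate the cache.
        discover(misses.toSeq).foreach { case (path, files) =>
          cache.put(path, files)
          output ++= files
        }
        output.toSeq
      }

      def main(args: Array[String]): Unit = {
        val cache = mutable.Map("p1" -> Seq("p1/a.parquet"))
        val listed = listWithCache(
          Seq("p1", "p2"), cache,
          misses => misses.map(p => p -> Seq(s"$p/part-0.parquet")).toMap)
        println(listed)   // p1 comes from the cache, p2 is freshly discovered
      }
    }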
+ * + * @return for each input path, the set of discovered files for the path + */ + private def bulkListLeafFiles( + paths: Seq[Path], + hadoopConf: Configuration, + filter: PathFilter, + sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { + + // Short-circuits parallel listing when serial listing is likely to be faster. + if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + return paths.map { path => + (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) + } + } + + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + HiveCatalogMetrics.incrementParallelListingJobCount(1) + + val sparkContext = sparkSession.sparkContext + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + val parallelPartitionDiscoveryParallelism = + sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) + + val statusMap = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + + // turn SerializableFileStatus back to Status + statusMap.map { case (path, serializableStatuses) => + val statuses = serializableStatuses.map { f => + val blockLocations = f.blockLocations.map { loc => + new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) + } + new LocatedFileStatus( + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, + new Path(f.path)), + blockLocations) + } + (new Path(path), statuses) + } + } + + /** + * Lists a single filesystem path recursively. If a SparkSession object is specified, this + * function may launch Spark jobs to parallelize listing. + * + * If sessionOpt is None, this may be called on executors. + * + * @return all children of path that match the specified filter. + */ + private def listLeafFiles( + path: Path, + hadoopConf: Configuration, + filter: PathFilter, + sessionOpt: Option[SparkSession]): Seq[FileStatus] = { + logTrace(s"Listing $path") + val fs = path.getFileSystem(hadoopConf) + + // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist + // Note that statuses only include FileStatus for the files and dirs directly under path, + // and does not include anything else recursively. + val statuses = try fs.listStatus(path) catch { + case _: FileNotFoundException => + logWarning(s"The directory $path was not found. 
Was it deleted very recently?") + Array.empty[FileStatus] + } + + val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) + + val allLeafStatuses = { + val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) + val nestedFiles: Seq[FileStatus] = sessionOpt match { + case Some(session) => + bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) + case _ => + dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) + } + val allFiles = topLevelFiles ++ nestedFiles + if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles + } + + allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + case f: LocatedFileStatus => + f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), + // which is very slow on some file system (RawLocalFileSystem, which is launch a + // subprocess and parse the stdout). + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + lfs + } + } + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter follow paths: + // 1. everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we + // should skip this file in case of double reading. + val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || + pathName.startsWith(".") || pathName.endsWith("._COPYING_") + val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") + exclude && !include + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand( val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) - // Invalidate the cache. - sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this + // data source relation. 
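shouldFilterOut, shown a little earlier in this hunk, hides dot-files, in-flight "._COPYING_" files, and most underscore-prefixed names, while keeping _metadata/_common_metadata (Parquet needs them) and anything containing "=" (partition directories). The small demo below copies that rule verbatim into a hypothetical helper and exercises it on a few representative names.

    object PathFilterDemo {
      // Same rule as InMemoryFileIndex.shouldFilterOut, copied here for illustration.
      def shouldFilterOut(pathName: String): Boolean = {
        val exclude = (pathName.startsWith("_") && !pathName.contains("=")) ||
          pathName.startsWith(".") || pathName.endsWith("._COPYING_")
        val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata")
        exclude && !include
      }

      def main(args: Array[String]): Unit = {
        Seq(
          "part-00000.parquet",   // kept: ordinary data file
          "_SUCCESS",             // filtered: underscore-prefixed marker file
          "_metadata",            // kept: Parquet summary file
          "_common_metadata",     // kept: Parquet summary file
          "_col=1",               // kept: underscore name, but a partition dir (contains '=')
          ".hidden",              // filtered: dot-file
          "data.csv._COPYING_"    // filtered: in-flight copy
        ).foreach { name =>
          val verdict = if (shouldFilterOut(name)) "filtered" else "kept"
          println(f"$name%-22s -> $verdict")
        }
      }
    }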
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 04a764bee2ef2..3813f953e06a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,41 +16,23 @@ */ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. - * - * Note that sometimes we need to use `LogicalRelation` to replace an existing leaf node without - * changing the output attributes' IDs. The `expectedOutputAttributes` parameter is used for - * this purpose. See https://issues.apache.org/jira/browse/SPARK-10741 for more details. */ case class LogicalRelation( relation: BaseRelation, - expectedOutputAttributes: Option[Seq[Attribute]] = None, - catalogTable: Option[CatalogTable] = None) + output: Seq[AttributeReference], + catalogTable: Option[CatalogTable]) extends LeafNode with MultiInstanceRelation { - override val output: Seq[AttributeReference] = { - val attrs = relation.schema.toAttributes - expectedOutputAttributes.map { expectedAttrs => - assert(expectedAttrs.length == attrs.length) - attrs.zip(expectedAttrs).map { - // We should respect the attribute names provided by base relation and only use the - // exprId in `expectedOutputAttributes`. - // The reason is that, some relations(like parquet) will reconcile attribute names to - // workaround case insensitivity issue. - case (attr, expected) => attr.withExprId(expected.exprId) - } - }.getOrElse(attrs) - } - // Logical Relations are distinct if they have different output for the sake of transformations. override def equals(other: Any): Boolean = other match { case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output @@ -61,19 +43,10 @@ case class LogicalRelation( com.google.common.base.Objects.hashCode(relation, output) } - override def sameResult(otherPlan: LogicalPlan): Boolean = { - otherPlan.canonicalized match { - case LogicalRelation(otherRelation, _, _) => relation == otherRelation - case _ => false - } - } - - // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need - // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's - // expId can be different but the relation is still the same. - override lazy val cleanArgs: Seq[Any] = Seq(relation) + // Only care about relation when canonicalizing. 
+ override def preCanonicalized: LogicalPlan = copy(catalogTable = None) - @transient override def computeStats(conf: CatalystConf): Statistics = { + @transient override def computeStats(conf: SQLConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } @@ -87,11 +60,8 @@ case class LogicalRelation( * unique expression ids. We respect the `expectedOutputAttributes` and create * new instances of attributes in it. */ - override def newInstance(): this.type = { - LogicalRelation( - relation, - expectedOutputAttributes.map(_.map(_.newInstance())), - catalogTable).asInstanceOf[this.type] + override def newInstance(): LogicalRelation = { + this.copy(output = output.map(_.newInstance())) } override def refresh(): Unit = relation match { @@ -101,3 +71,11 @@ case class LogicalRelation( override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" } + +object LogicalRelation { + def apply(relation: BaseRelation): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, None) + + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = + LogicalRelation(relation, relation.schema.toAttributes, Some(table)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 75b5384d355e5..6b6f6388d54e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -17,22 +17,17 @@ package org.apache.spark.sql.execution.datasources -import java.io.FileNotFoundException - import scala.collection.mutable import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration /** * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables. 
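LogicalRelation now canonicalizes itself by blanking out the one field that does not affect semantics (preCanonicalized = copy(catalogTable = None)) and letting ordinary case-class equality do the rest, replacing the old sameResult/cleanArgs overrides. A toy illustration of that idea, with invented Relation/TableMeta classes rather than Spark's QueryPlan machinery:

    object CanonicalizeDemo {
      final case class TableMeta(name: String, lastAnalyzed: Long)
      final case class Relation(path: String, format: String, catalogTable: Option[TableMeta]) {
        // Only the data location and format matter for "same result"; metadata is noise.
        def canonicalized: Relation = copy(catalogTable = None)
        def sameResult(other: Relation): Boolean = this.canonicalized == other.canonicalized
      }

      def main(args: Array[String]): Unit = {
        val a = Relation("/data/events", "parquet", Some(TableMeta("events", lastAnalyzed = 1L)))
        val b = Relation("/data/events", "parquet", Some(TableMeta("events", lastAnalyzed = 2L)))
        println(a == b)            // false: raw case-class equality still sees the metadata
        println(a.sameResult(b))   // true: the canonical forms agree
      }
    }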
@@ -54,17 +49,19 @@ abstract class PartitioningAwareFileIndex( override def partitionSchema: StructType = partitionSpec().partitionColumns - protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + protected val hadoopConf: Configuration = + sparkSession.sessionState.newHadoopConfWithOptions(parameters) protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] - override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = { + override def listFiles( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { PartitionDirectory(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil } else { - prunePartitions(filters, partitionSpec()).map { + prunePartitions(partitionFilters, partitionSpec()).map { case PartitionPath(values, path) => val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { case Some(existingDir) => @@ -127,7 +124,7 @@ abstract class PartitioningAwareFileIndex( }.keys.toSeq val caseInsensitiveOptions = CaseInsensitiveMap(parameters) - val timeZoneId = caseInsensitiveOptions.get("timeZone") + val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) userPartitionSchema match { @@ -186,7 +183,8 @@ abstract class PartitioningAwareFileIndex( val total = partitions.length val selectedSize = selected.length val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + s"Selected $selectedSize partitions out of $total, " + + s"pruned ${if (total == 0) "0" else s"$percentPruned%"} partitions." } selected @@ -238,224 +236,8 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } - - /** - * List leaf files of given paths. This method will submit a Spark job to do parallel - * listing whenever there is a path having more files than the parallel partition discovery - * discovery threshold. - * - * This is publicly visible for testing. - */ - def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - val output = mutable.LinkedHashSet[FileStatus]() - val pathsToFetch = mutable.ArrayBuffer[Path]() - for (path <- paths) { - fileStatusCache.getLeafFiles(path) match { - case Some(files) => - HiveCatalogMetrics.incrementFileCacheHits(files.length) - output ++= files - case None => - pathsToFetch += path - } - } - val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass)) - val discovered = PartitioningAwareFileIndex.bulkListLeafFiles( - pathsToFetch, hadoopConf, filter, sparkSession) - discovered.foreach { case (path, leafFiles) => - HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size) - fileStatusCache.putLeafFiles(path, leafFiles.toArray) - output ++= leafFiles - } - output - } } -object PartitioningAwareFileIndex extends Logging { +object PartitioningAwareFileIndex { val BASE_PATH_PARAM = "basePath" - - /** A serializable variant of HDFS's BlockLocation. */ - private case class SerializableBlockLocation( - names: Array[String], - hosts: Array[String], - offset: Long, - length: Long) - - /** A serializable variant of HDFS's FileStatus. 
*/ - private case class SerializableFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long, - blockLocations: Array[SerializableBlockLocation]) - - /** - * Lists a collection of paths recursively. Picks the listing strategy adaptively depending - * on the number of paths to list. - * - * This may only be called on the driver. - * - * @return for each input path, the set of discovered files for the path - */ - private def bulkListLeafFiles( - paths: Seq[Path], - hadoopConf: Configuration, - filter: PathFilter, - sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { - - // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - return paths.map { path => - (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) - } - } - - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - HiveCatalogMetrics.incrementParallelListingJobCount(1) - - val sparkContext = sparkSession.sparkContext - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - val parallelPartitionDiscoveryParallelism = - sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism - - // Set the number of parallelism to prevent following file listing from generating many tasks - // in case of large #defaultParallelism. - val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) - } - (path.toString, serializableStatuses) - }.collect() - - // turn SerializableFileStatus back to Status - statusMap.map { case (path, serializableStatuses) => - val statuses = serializableStatuses.map { f => - val blockLocations = f.blockLocations.map { loc => - new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length) - } - new LocatedFileStatus( - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, - new Path(f.path)), - blockLocations) - } - (new Path(path), statuses) - } - } - - /** - * Lists a single filesystem path recursively. If a SparkSession object is specified, this - * function may launch Spark jobs to parallelize listing. - * - * If sessionOpt is None, this may be called on executors. - * - * @return all children of path that match the specified filter. 
- */ - private def listLeafFiles( - path: Path, - hadoopConf: Configuration, - filter: PathFilter, - sessionOpt: Option[SparkSession]): Seq[FileStatus] = { - logTrace(s"Listing $path") - val fs = path.getFileSystem(hadoopConf) - val name = path.getName.toLowerCase - - // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist - // Note that statuses only include FileStatus for the files and dirs directly under path, - // and does not include anything else recursively. - val statuses = try fs.listStatus(path) catch { - case _: FileNotFoundException => - logWarning(s"The directory $path was not found. Was it deleted very recently?") - Array.empty[FileStatus] - } - - val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) - - val allLeafStatuses = { - val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory) - val nestedFiles: Seq[FileStatus] = sessionOpt match { - case Some(session) => - bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2) - case _ => - dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt)) - } - val allFiles = topLevelFiles ++ nestedFiles - if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles - } - - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { - case f: LocatedFileStatus => - f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. - // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not - // be a big deal since we always use to `listLeafFilesInParallel` when the number of - // paths exceeds threshold. - case f => - // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), - // which is very slow on some file system (RawLocalFileSystem, which is launch a - // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) - } - lfs - } - } - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // We filter follow paths: - // 1. everything that starts with _ and ., except _common_metadata and _metadata - // because Parquet needs to find those metadata files from leaf files returned by this method. - // We should refactor this logic to not mix metadata files with data files. - // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we - // should skip this file in case of double reading. 
- val exclude = (pathName.startsWith("_") && !pathName.contains("=")) || - pathName.startsWith(".") || pathName.endsWith("._COPYING_") - val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata") - exclude && !include - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 09876bbc2f85d..2d70172487e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -129,7 +128,7 @@ object PartitioningUtils { // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" // TODO: Selective case sensitivity. - val discoveredBasePaths = optDiscoveredBasePaths.flatMap(x => x).map(_.toString.toLowerCase()) + val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase()) assert( discoveredBasePaths.distinct.size == 1, "Conflicting directory structures detected. Suspicious paths:\b" + @@ -195,7 +194,7 @@ object PartitioningUtils { while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left // uncleaned. Here we simply ignore them. 
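Several hunks in this patch replace bare toLowerCase with toLowerCase(Locale.ROOT) when comparing path fragments such as "_temporary" or option values. The default-locale form is locale-sensitive, and the classic failure is the Turkish dotted/dotless "i", demonstrated below (plain JVM, no Spark; "PERMISSIVE" is just a convenient example string, not one of the literals in these hunks).

    import java.util.Locale

    object LocaleLowerCaseDemo {
      def main(args: Array[String]): Unit = {
        val turkish = new Locale("tr", "TR")
        println("I".toLowerCase(Locale.ROOT))                        // i
        println("I".toLowerCase(turkish))                            // dotless i (U+0131)
        // With a Turkish default locale, checks like s.toLowerCase == "permissive"
        // silently fail for strings containing 'I'; Locale.ROOT keeps them stable.
        println("PERMISSIVE".toLowerCase(turkish) == "permissive")   // false
        println("PERMISSIVE".toLowerCase(Locale.ROOT) == "permissive") // true
      }
    }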
- if (currentPath.getName.toLowerCase == "_temporary") { + if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") { return (None, None) } @@ -244,7 +243,7 @@ object PartitioningUtils { if (equalSignIndex == -1) { None } else { - val columnName = columnSpec.take(equalSignIndex) + val columnName = unescapePathName(columnSpec.take(equalSignIndex)) assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") val rawColumnValue = columnSpec.drop(equalSignIndex + 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 8566a8061034b..905b8683e10bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -59,9 +59,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy( - relation = prunedFsRelation, - expectedOutputAttributes = Some(logicalRelation.output)) + val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 73e6abc6dad37..83bdf6fe224be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources.csv -import java.io.InputStream import java.nio.charset.{Charset, StandardCharsets} -import com.univocity.parsers.csv.{CsvParser, CsvParserSettings} +import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -50,15 +49,26 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] + schema: StructType): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. */ - def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] + parsedOptions: CSVOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + Some(infer(sparkSession, inputPaths, parsedOptions)) + } else { + None + } + } + + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: CSVOptions): StructType /** * Generates a header from the given row which is null-safe and duplicate-safe. 
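makeSafeHeader, referenced by the doc comment above, has to turn the first parsed line into usable column names even when entries are null, empty, or duplicated. Its body is not part of this diff, so the helper below is only a plausible stand-in for what "null-safe and duplicate-safe" can mean (positional _c<i> fallbacks plus index suffixes for duplicates), not Spark's actual implementation.

    object SafeHeaderSketch {
      /** Hypothetical: positional fallback for blank names, index suffix for duplicates. */
      def makeSafeHeader(row: Array[String], caseSensitive: Boolean): Array[String] = {
        val named = row.zipWithIndex.map {
          case (name, i) if name == null || name.trim.isEmpty => s"_c$i"
          case (name, _) => name
        }
        val keyOf = (s: String) => if (caseSensitive) s else s.toLowerCase(java.util.Locale.ROOT)
        val counts = named.groupBy(keyOf).map { case (k, v) => k -> v.length }
        named.zipWithIndex.map { case (name, i) =>
          if (counts(keyOf(name)) > 1) s"$name$i" else name   // disambiguate duplicates
        }
      }

      def main(args: Array[String]): Unit = {
        val header = makeSafeHeader(Array("id", "", "name", "NAME", null), caseSensitive = false)
        println(header.mkString(", "))   // id, _c1, name2, NAME3, _c4
      }
    }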
@@ -116,37 +126,51 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) linesReader.map { line => - new String(line.getBytes, 0, line.getLength, parsedOptions.charset) + new String(line.getBytes, 0, line.getLength, parser.options.charset) } } - val shouldDropHeader = parsedOptions.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser) + val shouldDropHeader = parser.options.headerFlag && file.start == 0 + UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) } override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { - val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) - } + parsedOptions: CSVOptions): StructType = { + val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) + val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) + } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + /** + * Infers the schema from `Dataset` that stores CSV string records. + */ + def inferFromDataset( + sparkSession: SparkSession, + csv: Dataset[String], + maybeFirstLine: Option[String], + parsedOptions: CSVOptions): StructType = maybeFirstLine match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case None => + // If the first line could not be read, just return the empty schema. 
+ StructType(Nil) } private def createBaseDataset( @@ -179,39 +203,40 @@ object WholeFileCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - parsedOptions: CSVOptions): Iterator[InternalRow] = { + schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - parsedOptions.headerFlag, - parser) + parser.options.headerFlag, + parser, + schema) } override def infer( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - parsedOptions: CSVOptions): Option[StructType] = { - val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) - val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + parsedOptions: CSVOptions): StructType = { + val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) + csv.flatMap { lines => UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - false, + shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) - }.take(1).headOption - - if (maybeFirstRow.isDefined) { - val firstRow = maybeFirstRow.get - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.flatMap { lines => - UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) - } else { - // If the first row could not be read, just return the empty schema. - Some(StructType(Nil)) + }.take(1).headOption match { + case Some(firstRow) => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource( + lines.getConfiguration, + lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + CSVInferSchema.infer(tokenRDD, header, parsedOptions) + case None => + // If the first row could not be read, just return the empty schema. 
+ StructType(Nil) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 29c41455279e6..a99bdfee5d6e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -51,12 +51,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - require(files.nonEmpty, "Cannot infer schema from an empty set of files") - val parsedOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - CSVDataSource(parsedOptions).infer(sparkSession, files, parsedOptions) + CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } override def prepareWrite( @@ -113,8 +111,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value - val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions) + val parser = new UnivocityParser( + StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), + parsedOptions) + CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 50503385ad6d1..62e4c6e4b4ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -24,9 +24,9 @@ import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, Unescape import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes} +import org.apache.spark.sql.catalyst.util._ -private[csv] class CSVOptions( +class CSVOptions( @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) @@ -71,9 +71,9 @@ private[csv] class CSVOptions( val param = parameters.getOrElse(paramName, default.toString) if (param == null) { default - } else if (param.toLowerCase == "true") { + } else if (param.toLowerCase(Locale.ROOT) == "true") { true - } else if (param.toLowerCase == "false") { + } else if (param.toLowerCase(Locale.ROOT) == "false") { false } else { throw new Exception(s"$paramName flag can be true or false") @@ -82,7 +82,8 @@ private[csv] class CSVOptions( val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) @@ -92,17 +93,13 @@ private[csv] class CSVOptions( val headerFlag = getBool("header") val inferSchemaFlag = 
getBool("inferSchema") - val ignoreLeadingWhiteSpaceFlag = getBool("ignoreLeadingWhiteSpace") - val ignoreTrailingWhiteSpaceFlag = getBool("ignoreTrailingWhiteSpace") + val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) + val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) - // Parse mode flags - if (!ParseModes.isValidMode(parseMode)) { - logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") - } - - val failFast = ParseModes.isFailFastMode(parseMode) - val dropMalformed = ParseModes.isDropMalformedMode(parseMode) - val permissive = ParseModes.isPermissiveMode(parseMode) + // For write, both options were `true` by default. We leave it as `true` for + // backwards compatibility. + val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) @@ -120,7 +117,8 @@ private[csv] class CSVOptions( name.map(CompressionCodecs.getCodecClassName) } - val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) + val timeZone: TimeZone = TimeZone.getTimeZone( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. val dateFormat: FastDateFormat = @@ -128,7 +126,7 @@ private[csv] class CSVOptions( val timestampFormat: FastDateFormat = FastDateFormat.getInstance( - parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) @@ -138,8 +136,6 @@ private[csv] class CSVOptions( val escapeQuotes = getBool("escapeQuotes", true) - val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10) - val quoteAll = getBool("quoteAll", false) val inputBufferSize = 128 @@ -153,6 +149,8 @@ private[csv] class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) + writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) + writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) writerSettings.setEmptyValue(nullValue) writerSettings.setSkipEmptyLines(true) @@ -168,8 +166,8 @@ private[csv] class CSVOptions( format.setQuote(quote) format.setQuoteEscape(escape) format.setComment(comment) - settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlag) - settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlag) + settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead) + settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead) settings.setReadInputOnSeparateThread(false) settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3b3b87e4354d6..c3657acb7d867 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -30,14 +30,15 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} +import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[csv] class UnivocityParser( +class UnivocityParser( schema: StructType, requiredSchema: StructType, - private val options: CSVOptions) extends Logging { + val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") @@ -46,39 +47,26 @@ private[csv] class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) - corruptFieldIndex.foreach { corrFieldIndex => - require(schema(corrFieldIndex).dataType == StringType) - require(schema(corrFieldIndex).nullable) - } - - private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val tokenizer = new CsvParser(options.asParserSettings) - private var numMalformedRecords = 0 - private val row = new GenericInternalRow(requiredSchema.length) - // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field - // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. - private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd + // Retrieve the raw record string. + private def getCurrentInput: UTF8String = { + UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) + } - // This parser loads an `tokenIndexArr`-th position value in input tokens, - // then put the value in `row(rowIndexArr)`. + // This parser first picks some tokens from the input tokens, according to the required schema, + // then parse these tokens and put the values in a row, with the order specified by the required + // schema. // // For example, let's say there is CSV data as below: // // a,b,c // 1,2,A // - // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true` - // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need - // to map those values below: - // - // required schema - ["c", "b", "_unparsed", "a"] - // CSV data schema - ["a", "b", "c"] - // required CSV data schema - ["c", "b", "a"] + // So the CSV data schema is: ["a", "b", "c"] + // And let's say the required schema is: ["c", "b"] // // with the input tokens, // @@ -86,45 +74,12 @@ private[csv] class UnivocityParser( // // Each input token is placed in each output row's position by mapping these. In this case, // - // output row - ["A", 2, null, 1] - // - // In more details, - // - `valueConverters`, input tokens - CSV data schema - // `valueConverters` keeps the positions of input token indices (by its index) to each - // value's converter (by its value) in an order of CSV data schema. In this case, - // [string->int, string->int, string->string]. 
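After the rewrite, UnivocityParser keeps just two pieces of bookkeeping: one converter per column of the full CSV schema, and tokenIndexArr, which maps each required-schema position to its token position in the full schema (requiredSchema.map(f => schema.indexOf(f))). The plain-Scala sketch below walks through that projection for the a,b,c example in the comment, using strings in place of Catalyst types; all names are illustrative.

    object TokenProjectionDemo {
      def main(args: Array[String]): Unit = {
        // Full CSV schema and the columns the query actually needs.
        val schema = Seq("a" -> "int", "b" -> "int", "c" -> "string")
        val requiredSchema = Seq("c" -> "string", "b" -> "int")

        // One converter per column of the *full* schema, indexed by token position.
        val valueConverters: Array[String => Any] = schema.map {
          case (_, "int") => (s: String) => s.toInt
          case _          => (s: String) => s
        }.toArray

        // Required-schema position -> token position in the full schema.
        val tokenIndexArr: Array[Int] = requiredSchema.map(schema.indexOf(_)).toArray

        val tokens = Array("1", "2", "A")        // one parsed line of "a,b,c" data
        val row = new Array[Any](requiredSchema.length)
        var i = 0
        while (i < requiredSchema.length) {
          val from = tokenIndexArr(i)            // which token this output column needs
          row(i) = valueConverters(from)(tokens(from))
          i += 1
        }
        println(row.mkString("[", ", ", "]"))    // [A, 2]
      }
    }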
- // - // - `tokenIndexArr`, input tokens - required CSV data schema - // `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered - // fields given the required CSV data schema (by its value). In this case, [2, 1, 0]. - // - // - `rowIndexArr`, input tokens - required schema - // `rowIndexArr` keeps the positions of input token indices (by its index) to reordered - // field indices given the required schema (by its value). In this case, [0, 1, 3]. + // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means - // the fields that we should try to convert. - private val reorderedFields = if (options.dropMalformed) { - // If `dropMalformed` is enabled, then it needs to parse all the values - // so that we can decide which row is malformed. - requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) - } else { - requiredSchema - } + schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray private val tokenIndexArr: Array[Int] = { - reorderedFields - .filter(_.name != options.columnNameOfCorruptRecord) - .map(f => dataSchema.indexOf(f)).toArray - } - - private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { - val corrFieldIndex = corruptFieldIndex.get - reorderedFields.indices.filter(_ != corrFieldIndex).toArray - } else { - reorderedFields.indices.toArray + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -205,7 +160,7 @@ private[csv] class UnivocityParser( } case _: StringType => (d: String) => - nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_)) + nullSafeDatum(d, name, nullable, options)(UTF8String.fromString) case udt: UserDefinedType[_] => (datum: String) => makeConverter(name, udt.sqlType, nullable, options) @@ -233,81 +188,41 @@ private[csv] class UnivocityParser( * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input)) - - private def convert(tokens: Array[String]): Option[InternalRow] = { - convertWithParseMode(tokens) { tokens => - var i: Int = 0 - while (i < tokenIndexArr.length) { - // It anyway needs to try to parse since it decides if this row is malformed - // or not after trying to cast in `DROPMALFORMED` mode even if the casted - // value is not stored in the row. - val from = tokenIndexArr(i) - val to = rowIndexArr(i) - val value = valueConverters(from).apply(tokens(from)) - if (i < requiredSchema.length) { - row(to) = value - } - i += 1 - } - row - } - } - - private def convertWithParseMode( - tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = { - if (options.dropMalformed && dataSchema.length != tokens.length) { - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") + def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + + private def convert(tokens: Array[String]): InternalRow = { + if (tokens.length != schema.length) { + // If the number of tokens doesn't match the schema, we should treat it as a malformed record. + // However, we still have chance to parse some of the tokens, by adding extra null tokens in + // the tail if the number is smaller, or by dropping extra tokens if the number is larger. + val checkedTokens = if (schema.length > tokens.length) { + tokens ++ new Array[String](schema.length - tokens.length) + } else { + tokens.take(schema.length) } - numMalformedRecords += 1 - None - } else if (options.failFast && dataSchema.length != tokens.length) { - throw new RuntimeException(s"Malformed line in FAILFAST mode: " + - s"${tokens.mkString(options.delimiter.toString)}") - } else { - // If a length of parsed tokens is not equal to expected one, it makes the length the same - // with the expected. If the length is shorter, it adds extra tokens in the tail. - // If longer, it drops extra tokens. - // - // TODO: Revisit this; if a length of tokens does not match an expected length in the schema, - // we probably need to treat it as a malformed record. - // See an URL below for related discussions: - // https://github.com/apache/spark/pull/16928#discussion_r102657214 - val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) { - if (dataSchema.length > tokens.length) { - tokens ++ new Array[String](dataSchema.length - tokens.length) - } else { - tokens.take(dataSchema.length) + def getPartialResult(): Option[InternalRow] = { + try { + Some(convert(checkedTokens)) + } catch { + case _: BadRecordException => None } - } else { - tokens } - + throw BadRecordException( + () => getCurrentInput, + getPartialResult, + new RuntimeException("Malformed CSV record")) + } else { try { - Some(convert(checkedTokens)) + var i = 0 + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) + i += 1 + } + row } catch { - case NonFatal(e) if options.permissive => - val row = new GenericInternalRow(requiredSchema.length) - corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput())) - Some(row) - case NonFatal(e) if options.dropMalformed => - if (numMalformedRecords < options.maxMalformedLogPerPartition) { - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}") - } - if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) { - logWarning( - s"More than ${options.maxMalformedLogPerPartition} malformed records have been " + - "found on this partition. 
Malformed records from now on will not be logged.") - } - numMalformedRecords += 1 - None + case NonFatal(e) => + throw BadRecordException(() => getCurrentInput, () => None, e) } } } @@ -331,10 +246,16 @@ private[csv] object UnivocityParser { def parseStream( inputStream: InputStream, shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val tokenizer = parser.tokenizer + val safeParser = new FailureSafeParser[Array[String]]( + input => Seq(parser.convert(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => - parser.convert(tokens) + safeParser.parse(tokens) }.flatten } @@ -368,7 +289,8 @@ private[csv] object UnivocityParser { def parseIterator( lines: Iterator[String], shouldDropHeader: Boolean, - parser: UnivocityParser): Iterator[InternalRow] = { + parser: UnivocityParser, + schema: StructType): Iterator[InternalRow] = { val options = parser.options val linesWithoutHeader = if (shouldDropHeader) { @@ -381,6 +303,12 @@ private[csv] object UnivocityParser { val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) - filteredLines.flatMap(line => parser.parse(line)) + + val safeParser = new FailureSafeParser[String]( + input => Seq(parser.parse(input)), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + filteredLines.flatMap(safeParser.parse) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 110d503f91cf4..f8d4a9bb5b81a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} @@ -75,7 +77,7 @@ case class CreateTempViewUsing( } def run(sparkSession: SparkSession): Seq[Row] = { - if (provider.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d4d34646545ba..591096d5efd22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager} -import java.util.Properties +import java.util.{Locale, Properties} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -55,7 +55,7 @@ class JDBCOptions( */ val asConnectionProperties: Properties = { val properties = new Properties() - parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase)) + parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase(Locale.ROOT))) .foreach { case (k, v) => properties.setProperty(k, v) } 
properties } @@ -119,6 +119,7 @@ class JDBCOptions( // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES) val batchSize = { val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, @@ -140,7 +141,7 @@ object JDBCOptions { private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - jdbcOptionNames += name.toLowerCase + jdbcOptionNames += name.toLowerCase(Locale.ROOT) name } @@ -154,6 +155,7 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 88f6cb0021305..74dcfb06f5c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table) - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider // Therefore, it is okay to do nothing here and then just return the relation below. 
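The `toLowerCase(Locale.ROOT)` changes in the ddl.scala and JDBCOptions hunks above guard against locale-sensitive lowering when matching option names and provider names. A small standalone Scala illustration of the failure mode they prevent (not part of the patch):

    import java.util.Locale

    // Under the Turkish locale, "I" lowers to the dotless "ı", so a user-supplied key such as
    // "ISOLATIONLEVEL" would no longer match the registered lower-case option name.
    val turkish = new Locale("tr", "TR")
    println("ISOLATIONLEVEL".toLowerCase(turkish))      // ısolatıonlevel (dotless ı)
    println("ISOLATIONLEVEL".toLowerCase(Locale.ROOT))  // isolationlevel
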
} } else { - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d89f600874177..0183805d56257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.util.Try @@ -30,7 +31,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -541,7 +543,7 @@ object JdbcUtils extends Logging { case ArrayType(et, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition - .toLowerCase.split("\\(")(0) + .toLowerCase(Locale.ROOT).split("\\(")(0) (stmt: PreparedStatement, row: Row, pos: Int) => val array = conn.createArrayOf( typeName, @@ -650,8 +652,17 @@ object JdbcUtils extends Logging { case e: SQLException => val cause = e.getNextException if (cause != null && e.getCause != cause) { + // If there is no cause already, set 'next exception' as cause. If cause is null, + // it *may* be because no cause was set yet if (e.getCause == null) { - e.initCause(cause) + try { + e.initCause(cause) + } catch { + // Or it may be null because the cause *was* explicitly initialized, to *null*, + // in which case this fails. There is no other way to detect it. + // addSuppressed in this case as well. + case _: IllegalStateException => e.addSuppressed(cause) + } } else { e.addSuppressed(cause) } @@ -680,18 +691,70 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. 
*/ - def schemaString(schema: StructType, url: String): String = { + def schemaString( + df: DataFrame, + url: String, + createTableColumnTypes: Option[String] = None): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - schema.fields foreach { field => + val userSpecifiedColTypesMap = createTableColumnTypes + .map(parseUserSpecifiedCreateTableColumnTypes(df, _)) + .getOrElse(Map.empty[String, String]) + df.schema.fields.foreach { field => val name = dialect.quoteIdentifier(field.name) - val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition + val typ = userSpecifiedColTypesMap + .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") } if (sb.length < 2) "" else sb.substring(2) } + /** + * Parses the user specified createTableColumnTypes option value string specified in the same + * format as create table ddl column types, and returns Map of field name and the data type to + * use in-place of the default data type. + */ + private def parseUserSpecifiedCreateTableColumnTypes( + df: DataFrame, + createTableColumnTypes: String): Map[String, String] = { + def typeName(f: StructField): String = { + // char/varchar gets translated to string type. Real data type specified by the user + // is available in the field metadata as HIVE_TYPE_STRING + if (f.metadata.contains(HIVE_TYPE_STRING)) { + f.metadata.getString(HIVE_TYPE_STRING) + } else { + f.dataType.catalogString + } + } + + val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val nameEquality = df.sparkSession.sessionState.conf.resolver + + // checks duplicate columns in the user specified column types. + userSchema.fieldNames.foreach { col => + val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) + if (duplicatesCols.size >= 2) { + throw new AnalysisException( + "Found duplicate column(s) in createTableColumnTypes option value: " + + duplicatesCols.mkString(", ")) + } + } + + // checks if user specified column names exist in the DataFrame schema + userSchema.fieldNames.foreach { col => + df.schema.find(f => nameEquality(f.name, col)).getOrElse { + throw new AnalysisException( + s"createTableColumnTypes option column $col not found in schema " + + df.schema.catalogString) + } + } + + val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis + if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) + } + /** * Saves the RDD to the database in a single transaction. */ @@ -726,9 +789,10 @@ object JdbcUtils extends Logging { */ def createTable( conn: Connection, - schema: StructType, + df: DataFrame, options: JDBCOptions): Unit = { - val strSchema = schemaString(schema, options.url) + val strSchema = schemaString( + df, options.url, options.createTableColumnTypes) val table = options.table val createTableOptions = options.createTableOptions // Create the table if the table does not exist. 
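The new `createTableColumnTypes` option registered in `JDBCOptions` and consumed by `schemaString` above lets callers override the default JDBC column types when Spark creates the target table. A minimal usage sketch, where `df` is an existing DataFrame and the URL, table name and types are placeholders only:

    // Columns listed in the option get the given database types; unlisted columns keep the
    // dialect's default JdbcType mapping.
    df.write
      .option("createTableColumnTypes", "name VARCHAR(200), comments TEXT")
      .jdbc("jdbc:postgresql:testdb", "people", new java.util.Properties())
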
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 18843bfc307b3..4f2963da9ace9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -17,32 +17,32 @@ package org.apache.spark.sql.execution.datasources.json -import scala.reflect.ClassTag +import java.io.InputStream import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** * Common functions for parsing JSON files - * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] */ -abstract class JsonDataSource[T] extends Serializable { +abstract class JsonDataSource extends Serializable { def isSplitable: Boolean /** @@ -51,30 +51,15 @@ abstract class JsonDataSource[T] extends Serializable { def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] - /** - * Create an [[RDD]] that handles the preliminary parsing of [[T]] records - */ - protected def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[T] - - /** - * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] - * for an instance of [[T]] - */ - def createParser(jsonFactory: JsonFactory, value: T): JsonParser - - final def infer( + final def inferSchema( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Option[StructType] = { if (inputPaths.nonEmpty) { - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, inputPaths), - parsedOptions, - createParser) + val jsonSchema = infer(sparkSession, inputPaths, parsedOptions) checkConstraints(jsonSchema) Some(jsonSchema) } else { @@ -82,6 +67,11 @@ abstract class JsonDataSource[T] extends Serializable { } } + protected def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType + /** Constraints to be imposed on schema to be stored. 
*/ private def checkConstraints(schema: StructType): Unit = { if (schema.fieldNames.length != schema.fieldNames.distinct.length) { @@ -95,96 +85,102 @@ abstract class JsonDataSource[T] extends Serializable { } object JsonDataSource { - def apply(options: JSONOptions): JsonDataSource[_] = { + def apply(options: JSONOptions): JsonDataSource = { if (options.wholeFile) { WholeFileJsonDataSource } else { TextInputJsonDataSource } } - - /** - * Create a new [[RDD]] via the supplied callback if there is at least one file to process, - * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. - */ - def createBaseRdd[T : ClassTag]( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus])( - fn: (Configuration, String) => RDD[T]): RDD[T] = { - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - FileInputFormat.setInputPaths(job, paths: _*) - fn(job.getConfiguration, paths.mkString(",")) - } else { - sparkSession.sparkContext.emptyRDD[T] - } - } } -object TextInputJsonDataSource extends JsonDataSource[Text] { +object TextInputJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { // splittable if the underlying source is true } - override protected def createBaseRdd( + override def infer( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[Text] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - sparkSession.sparkContext.newAPIHadoopRDD( - conf, - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]) - .setName(s"JsonLines: $name") - .values // get the text column - } + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + inferFromDataset(json, parsedOptions) + } + + def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { + val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) + val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) + JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + } + + private def createBaseDataset( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as(Encoders.STRING) } override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + val safeParser = new FailureSafeParser[Text]( + input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + linesReader.flatMap(safeParser.parse) } private def textToUTF8String(value: Text): UTF8String = { UTF8String.fromBytes(value.getBytes, 0, value.getLength) } - - override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { - CreateJacksonParser.text(jsonFactory, value) - } } -object WholeFileJsonDataSource extends 
JsonDataSource[PortableDataStream] { +object WholeFileJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } - override protected def createBaseRdd( + override def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): StructType = { + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) + JsonInferSchema.infer(sampled, parsedOptions, createParser) + } + + private def createBaseRdd( sparkSession: SparkSession, inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { - JsonDataSource.createBaseRdd(sparkSession, inputPaths) { - case (conf, name) => - new BinaryFileRDD( - sparkSession.sparkContext, - classOf[StreamInputFormat], - classOf[String], - classOf[PortableDataStream], - conf, - sparkSession.sparkContext.defaultMinPartitions) - .setName(s"JsonFile: $name") - .values - } + val paths = inputPaths.map(_.getPath) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val conf = job.getConfiguration + val name = paths.mkString(",") + FileInputFormat.setInputPaths(job, paths: _*) + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values } - override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { CreateJacksonParser.inputStream( jsonFactory, CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) @@ -193,7 +189,8 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { override def readFile( conf: Configuration, file: PartitionedFile, - parser: JacksonParser): Iterator[InternalRow] = { + parser: JacksonParser, + schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) @@ -202,9 +199,13 @@ object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { } } - parser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), - CreateJacksonParser.inputStream, - partitionedFileString).toIterator + val safeParser = new FailureSafeParser[InputStream]( + input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + parser.options.parseMode, + schema, + parser.options.columnNameOfCorruptRecord) + + safeParser.parse( + CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 902fee5a7e3f7..53d62d88b04c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) - JsonDataSource(parsedOptions).infer( + JsonDataSource(parsedOptions).inferSchema( sparkSession, files, parsedOptions) } @@ 
-102,6 +102,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val actualSchema = + StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) // Check a field requirement for corrupt records here to throw an exception in a driver side dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex => val f = dataSchema(corruptFieldIndex) @@ -112,11 +114,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } (file: PartitionedFile) => { - val parser = new JacksonParser(requiredSchema, parsedOptions) + val parser = new JacksonParser(actualSchema, parsedOptions) JsonDataSource(parsedOptions).readFile( broadcastedHadoopConf.value.value, file, - parser) + parser, + requiredSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index ab09358115c0a..fb632cf2bb70e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,18 +41,11 @@ private[sql] object JsonInferSchema { json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { - require(configOptions.samplingRatio > 0, - s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") - val shouldHandleCorruptRecord = configOptions.permissive + val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val schemaData = if (configOptions.samplingRatio > 0.99) { - json - } else { - json.sample(withReplacement = false, configOptions.samplingRatio, 1) - } // perform schema inference on each row and merge afterwards - val rootType = schemaData.mapPartitions { iter => + val rootType = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -61,20 +55,24 @@ private[sql] object JsonInferSchema { Some(inferField(parser, configOptions)) } } catch { - case _: JsonParseException if shouldHandleCorruptRecord => - Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) - case _: JsonParseException => - None + case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { + case PermissiveMode => + Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw e + } } } - }.fold(StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord)) + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep - StructType(Seq()) + 
StructType(Nil) } } @@ -208,19 +206,33 @@ private[sql] object JsonInferSchema { private def withCorruptField( struct: StructType, - columnNameOfCorruptRecords: String): StructType = { - if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { - // If this given struct does not have a column used for corrupt records, - // add this field. - val newFields: Array[StructField] = - StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields - // Note: other code relies on this sorting for correctness, so don't remove it! - java.util.Arrays.sort(newFields, structFieldComparator) - StructType(newFields) - } else { - // Otherwise, just return this struct. + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode) = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. struct - } + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw new RuntimeException("Failed to infer a common schema. Struct types are expected" + + s" but ${other.catalogString} was found.") } /** @@ -228,21 +240,20 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - // If we see any other data type at the root level, we get records that cannot be - // parsed. So, we use the struct as the data type and add the corrupt field to the schema. 
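To make the mode-aware schema inference in JsonInferSchema above concrete, here is roughly how the three parse modes surface to a user reading JSON that contains a malformed document. `spark` is an existing SparkSession, the path is illustrative, and the exact inferred schema may vary with the configured corrupt-record column name:

    // Suppose /tmp/people.json has one malformed line:
    //   {"name": "A", "age": 30}
    //   {"name": "B", "age":
    spark.read.json("/tmp/people.json").printSchema()
    // PERMISSIVE (default): the schema gains a string _corrupt_record column alongside age/name
    spark.read.option("mode", "DROPMALFORMED").json("/tmp/people.json").printSchema()
    // DROPMALFORMED: the schema is inferred from the well-formed lines only
    spark.read.option("mode", "FAILFAST").json("/tmp/people.json")
    // FAILFAST: schema inference itself throws on the malformed line
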
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => - withCorruptField(struct, columnNameOfCorruptRecords) + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. case (ty1, ty2) => compatibleType(ty1, ty2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala new file mode 100644 index 0000000000000..d511594c5de1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions + +object JsonUtils { + /** + * Sample JSON dataset as configured by `samplingRatio`. + */ + def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample JSON RDD as configured by `samplingRatio`. 
+ */ + def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + json + } else { + json.sample(withReplacement = false, options.samplingRatio, 1) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 0bfae3e267158..7854099e85e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -136,6 +136,10 @@ class ParquetFileFormat SQLConf.PARQUET_TIMESTAMP_AS_INT96.key, sparkSession.sessionState.conf.isParquetTimestampAsINT96.toString) + conf.set( + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis.toString) + // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) @@ -357,20 +361,6 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - // For Parquet data source, `buildReader` already handles partition values appending. Here we - // simply delegate to `buildReader`. - buildReader( - sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) - } - - override def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -386,15 +376,17 @@ class ParquetFileFormat hadoopConf.setBoolean( SQLConf.PARQUET_BINARY_AS_STRING.key, sparkSession.sessionState.conf.isParquetBinaryAsString) + hadoopConf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, int96AsTimestamp) hadoopConf.setBoolean( - SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, - int96AsTimestamp) + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) // By default, disable record level filtering. if (hadoopConf.get(ParquetInputFormat.RECORD_FILTERING_ENABLED) == null) { hadoopConf.setBoolean(ParquetInputFormat.RECORD_FILTERING_ENABLED, false) } + // Try to push down filters when filter push-down is enabled. 
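The `PARQUET_INT64_AS_TIMESTAMP_MILLIS` flag threaded through ParquetFileFormat above is driven by a session conf. A usage sketch, assuming `spark` and `df` exist; the key spelling is inferred from the constant name and may differ between branches:

    // With the flag on, TimestampType columns are written as INT64 annotated with
    // TIMESTAMP_MILLIS instead of legacy INT96, and read back through the new converter.
    spark.conf.set("spark.sql.parquet.int64AsTimestampMillis", "true")  // assumed key spelling
    df.write.mode("overwrite").parquet("/tmp/ts_millis")
    spark.read.parquet("/tmp/ts_millis").printSchema()
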
val pushed = if (sparkSession.sessionState.conf.parquetFilterPushDown) { @@ -519,8 +511,10 @@ object ParquetFileFormat extends Logging { def parseParquetSchema(schema: MessageType): StructType = { val converter = new ParquetSchemaConverter( sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.isParquetBinaryAsString, - sparkSession.sessionState.conf.writeLegacyParquetFormat) + sparkSession.sessionState.conf.isParquetINT96AsTimestamp, + sparkSession.sessionState.conf.writeLegacyParquetFormat, + sparkSession.sessionState.conf.isParquetTimestampAsINT96, + sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis) converter.convert(schema) } @@ -571,71 +565,6 @@ object ParquetFileFormat extends Logging { } } - /** - * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore - * schema and Parquet schema. - * - * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the - * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't - * distinguish binary and string). This method generates a correct schema by merging Metastore - * schema data types and Parquet schema field names. - */ - def mergeMetastoreParquetSchema( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - def schemaConflictMessage: String = - s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: - |${metastoreSchema.prettyJson} - | - |Parquet schema: - |${parquetSchema.prettyJson} - """.stripMargin - - val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) - - assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) - - val ordinalMap = metastoreSchema.zipWithIndex.map { - case (field, index) => field.name.toLowerCase -> index - }.toMap - - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => - ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) - - StructType(metastoreSchema.zip(reorderedParquetSchema).map { - // Uses Parquet field names but retains Metastore data types. - case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => - mSchema.copy(name = pSchema.name) - case _ => - throw new SparkException(schemaConflictMessage) - }) - } - - /** - * Returns the original schema from the Parquet file with any missing nullable fields from the - * Hive Metastore schema merged in. - * - * When constructing a DataFrame from a collection of structured data, the resulting object has - * a schema corresponding to the union of the fields present in each element of the collection. - * Spark SQL simply assigns a null value to any field that isn't present for a particular row. - * In some cases, it is possible that a given table partition stored as a Parquet file doesn't - * contain a particular nullable field in its schema despite that field being present in the - * table schema obtained from the Hive Metastore. This method returns a schema representing the - * Parquet file schema along with any additional nullable fields from the Metastore schema - * merged in. 
- */ - private[parquet] def mergeMissingNullableFields( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap - val missingFields = metastoreSchema - .map(_.name.toLowerCase) - .diff(parquetSchema.map(_.name.toLowerCase)) - .map(fieldMap(_)) - .filter(_.nullable) - StructType(parquetSchema ++ missingFields) - } - /** * Reads Parquet footers in multi-threaded manner. * If the config "spark.sql.files.ignoreCorruptFiles" is set to true, we will ignore the corrupted @@ -685,7 +614,9 @@ object ParquetFileFormat extends Logging { sparkSession: SparkSession): Option[StructType] = { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp + val writeTimestampInMillis = sparkSession.sessionState.conf.isParquetINT64AsTimestampMillis val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat + val writeTimestampAsInt96 = sparkSession.sessionState.conf.isParquetTimestampAsINT96 val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) // !! HACK ALERT !! @@ -729,7 +660,9 @@ object ParquetFileFormat extends Logging { new ParquetSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampAsInt96 = writeTimestampAsInt96, + writeTimestampInMillis = writeTimestampInMillis) if (footers.isEmpty) { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index bdda299a621ac..772d4565de548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -40,9 +42,11 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + val codecName = parameters.getOrElse("compression", + sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = + shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. 
Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 56495f487d693..8c6a21de05b8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.{GroupType, MessageType, Type} +import org.apache.parquet.schema.{GroupType, MessageType, OriginalType, Type} import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} @@ -252,11 +252,17 @@ private[parquet] class ParquetRowConverter( case StringType => new ParquetStringConverter(updater) - case _: TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => new ParquetPrimitiveConverter(updater) + case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS => + new ParquetPrimitiveConverter(updater) { + override def addLong(value: Long): Unit = { + updater.setLong(DateTimeUtils.fromMillis(value)) + } + } + case TimestampType => - // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 44d9f25012bd8..1a14f4756c67c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -53,18 +53,22 @@ import org.apache.spark.sql.types._ * affects Parquet write path. * @param writeTimestampAsInt96 Whether to use serialize Timestamps in legacy INT96 format for * forwards compatibility. + * @param writeTimestampInMillis Whether to write timestamp values as INT64 annotated by logical + * type TIMESTAMP_MILLIS. 
*/ private[parquet] class ParquetSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get, - writeTimestampAsInt96: Boolean = SQLConf.PARQUET_TIMESTAMP_AS_INT96.defaultValue.get) { + writeTimestampAsInt96: Boolean = SQLConf.PARQUET_TIMESTAMP_AS_INT96.defaultValue.get, + writeTimestampInMillis: Boolean = SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, writeLegacyParquetFormat = conf.writeLegacyParquetFormat, - writeTimestampAsInt96 = conf.isParquetTimestampAsINT96) + writeTimestampAsInt96 = conf.isParquetTimestampAsINT96, + writeTimestampInMillis = conf.isParquetINT64AsTimestampMillis) def this(conf: Configuration) = this( assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean, @@ -72,7 +76,9 @@ private[parquet] class ParquetSchemaConverter( writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get.toString).toBoolean, writeTimestampAsInt96 = conf.get(SQLConf.PARQUET_TIMESTAMP_AS_INT96.key, - SQLConf.PARQUET_TIMESTAMP_AS_INT96.defaultValue.get.toString).toBoolean) + SQLConf.PARQUET_TIMESTAMP_AS_INT96.defaultValue.get.toString).toBoolean, + writeTimestampInMillis = conf.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key, + SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.defaultValue.get.toString).toBoolean) /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. @@ -165,7 +171,7 @@ private[parquet] class ParquetSchemaConverter( case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MICROS => TimestampType - case TIMESTAMP_MILLIS => typeNotImplemented() + case TIMESTAMP_MILLIS => TimestampType case _ => illegalType() } @@ -375,6 +381,9 @@ private[parquet] class ParquetSchemaConverter( // from Spark 1.5.0, we resort to a timestamp type with 100 ns precision so that we can store // a timestamp into a `Long`. This design decision is subject to change though, for example, // we may resort to microsecond precision in the future. + case TimestampType if writeTimestampInMillis => + Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name) + case TimestampType if writeTimestampAsInt96 => Types.primitive(INT96, repetition).named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 3d9d8ff8841cb..d856bb6b907e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -72,6 +72,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit // Whether to write timestamps as int96 private var writeTimestampAsInt96: Boolean = _ + // Whether to write timestamp value with milliseconds precision. 
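A conceptual sketch of the precision trade-off behind the `DateTimeUtils.toMillis`/`fromMillis` calls in the converter and write-support hunks here; the arithmetic mirrors, but is not, the Spark implementation:

    // Catalyst keeps timestamps as microseconds since the epoch. TIMESTAMP_MILLIS storage
    // truncates to milliseconds on write and scales back up on read, so sub-millisecond
    // digits are lost; INT96 and TIMESTAMP_MICROS preserve the full microsecond value.
    val micros       = 1490000000123456L   // Catalyst value: ...123456 microseconds
    val storedMillis = micros / 1000       // what TIMESTAMP_MILLIS keeps on disk: ...123
    val readMicros   = storedMillis * 1000 // value after reading back: ...123000
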
+ private var writeTimestampInMillis: Boolean = _ + // Reusable byte array used to write decimal values private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) @@ -87,6 +90,12 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit assert(configuration.get(SQLConf.PARQUET_TIMESTAMP_AS_INT96.key) != null) configuration.get(SQLConf.PARQUET_TIMESTAMP_AS_INT96.key).toBoolean } + + this.writeTimestampInMillis = { + assert(configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key) != null) + configuration.get(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key).toBoolean + } + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) val messageType = new ParquetSchemaConverter(configuration).convert(schema) @@ -188,6 +197,9 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit } private def makeTimestampWriter(): ValueWriter = { + val millisWriter = (row: SpecializedGetters, ordinal: Int) => { + recordConsumer.addLong(DateTimeUtils.toMillis(row.getLong(ordinal))) + } val int96Writer = (row: SpecializedGetters, ordinal: Int) => { val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(row.getLong(ordinal)) val buf = ByteBuffer.wrap(timestampBuffer) @@ -196,7 +208,8 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit } val longWriter = (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal)) - if (writeTimestampAsInt96) int96Writer else longWriter + if (writeTimestampAsInt96) int96Writer else + if (writeTimestampInMillis) millisWriter else longWriter } private def makeDecimalWriter(precision: Int, scale: Int): ValueWriter = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 4d781b96abace..3f4a78580f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.datasources +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -48,7 +50,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { // will catch it and return the original plan, so that the analyzer can report table not // found later. val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass) - if (!isFileFormat || dataSource.className.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (!isFileFormat || + dataSource.className.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Unsupported data source type for direct query on files: " + s"${u.tableIdentifier.database.get}") } @@ -66,7 +69,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { * Preprocess [[CreateTable]], to do some normalization and checking. 
*/ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private val catalog = sparkSession.sessionState.catalog + // catalog is a def and not a val/lazy val as the latter would introduce a circular reference + private def catalog = sparkSession.sessionState.catalog def apply(plan: LogicalPlan): LogicalPlan = plan transform { // When we CREATE TABLE without specifying the table schema, we should fail the query if @@ -311,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -363,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a765..9c859e41f8762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -48,10 +48,8 @@ case class BroadcastExchangeExec( override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) - override def sameResult(plan: SparkPlan): Boolean = plan match { - case p: BroadcastExchangeExec => - mode.compatibleWith(p.mode) && child.sameResult(p.child) - case _ => false + override lazy val canonicalized: SparkPlan = { + BroadcastExchangeExec(mode.canonicalized, child.canonicalized) } @transient @@ -97,13 +95,7 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. 
- if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { case oe: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f17049949aa47..b91d077442557 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } else { requiredOrdering.zip(child.outputOrdering).forall { case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) + childOutputOrder.satisfies(requiredOrder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9a9597d3733e0..d993ea6c6cef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode { case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange) extends LeafExecNode { - override def sameResult(plan: SparkPlan): Boolean = { - // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here. - plan.sameResult(child) - } + // Ignore this wrapper for canonicalizing. 
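Several hunks in this area (BroadcastExchangeExec, ReusedExchangeExec below, and HashedRelationBroadcastMode further down) replace hand-written sameResult / compatibleWith overrides with a canonicalized form. A toy model of the idea, using made-up classes rather than Spark's SparkPlan:

```scala
// Toy model of comparing plans via a canonical form: normalize result-irrelevant
// details (here an `id`), make wrappers transparent, and define sameResult as
// equality of the canonical forms.
sealed trait PlanLike {
  def canonicalized: PlanLike
  final def sameResult(other: PlanLike): Boolean = canonicalized == other.canonicalized
}

case class ScanLike(table: String) extends PlanLike {
  override def canonicalized: PlanLike = this
}

case class ExchangeLike(mode: String, id: Long, child: PlanLike) extends PlanLike {
  // zero out the id so two exchanges over equivalent children compare equal
  override def canonicalized: PlanLike = ExchangeLike(mode, 0L, child.canonicalized)
}

case class ReuseLike(child: PlanLike) extends PlanLike {
  // like ReusedExchangeExec: the wrapper is ignored for canonicalization
  override def canonicalized: PlanLike = child.canonicalized
}

object CanonicalizedDemo extends App {
  val a = ExchangeLike("broadcast", 1L, ScanLike("t"))
  val b = ReuseLike(ExchangeLike("broadcast", 2L, ScanLike("t")))
  println(a.sameResult(b))  // true: ids differ and one side is wrapped, but results match
}
```

Normalizing once and comparing canonical forms keeps the equivalence logic in one place instead of scattering pairwise comparisons across operators.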
+ override lazy val canonicalized: SparkPlan = child.canonicalized def doExecute(): RDD[InternalRow] = { child.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 125a4930c6528..f06544ea8ed04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -46,7 +46,7 @@ case class ShuffleExchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) => - s"(coordinator id: ${System.identityHashCode(coordinator)})" + s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})" case None => "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8341fe2ffd078..f380986951317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). */ -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD( + left : RDD[UnsafeRow], + right : RDD[UnsafeRow], + numFieldsOfRight: Int, + spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - // We will not sort the rows, so prefixComparator and recordComparator are null. 
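The sorter being torn out of UnsafeCartesianRDD here (and the ArrayBuffer in SortMergeJoinScanner below) is replaced by ExternalAppendOnlyUnsafeRowArray. A rough, self-contained stand-in for the API shape those operators rely on (add / length / generateIterator / clear), assuming the real class hands rows to a spilling sorter once the threshold is crossed:

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified stand-in, not the real ExternalAppendOnlyUnsafeRowArray: an
// append-only buffer that keeps rows on heap up to a threshold. The real class
// spills to disk past the threshold; this sketch only models the API shape.
class AppendOnlyRowArraySketch[T](spillThreshold: Int) {
  private val inMemory = new ArrayBuffer[T]
  private var overflow: List[T] = Nil    // placeholder for rows pushed off the fast path

  def add(row: T): Unit = {
    if (inMemory.length < spillThreshold) inMemory += row
    else overflow = row :: overflow      // the real implementation spills here
  }

  def length: Int = inMemory.length + overflow.length
  def isEmpty: Boolean = length == 0

  // A fresh iterator over everything added so far, in insertion order.
  def generateIterator(): Iterator[T] = inMemory.iterator ++ overflow.reverse.iterator

  def clear(): Unit = { inMemory.clear(); overflow = Nil }
}
```

The operators only ever append, take a fresh iterator per streamed row, and clear between key groups, which is why the index-based loops in this file could become generateIterator() calls.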
- val sorter = UnsafeExternalSorter.create( - context.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - context, - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) val partition = split.asInstanceOf[CartesianPartition] - for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false) - } + rdd2.iterator(partition.s2, context).foreach(rowArray.add) - // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] - def createIter(): Iterator[UnsafeRow] = { - val iter = sorter.getIterator - val unsafeRow = new UnsafeRow(numFieldsOfRight) - new Iterator[UnsafeRow] { - override def hasNext: Boolean = { - iter.hasNext - } - override def next(): UnsafeRow = { - iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - unsafeRow - } - } - } + // Create an iterator from rowArray + def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator() val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources()) + resultIter, rowArray.clear()) } } @@ -97,7 +71,9 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold + + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b9f6601ea87fe..2dd1dc3da96c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + HashedRelation(rows.iterator, canonicalized.key, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } - } - - override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey - case _ => false + override lazy val canonicalized: HashedRelationBroadcastMode = { + this.copy(key = key.map(_.canonicalized)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ca9c0ed8cec32..c6aae1a4db2e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, +ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -79,7 +80,37 @@ case class SortMergeJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + // For inner join, orders of both sides keys should be kept. + case Inner => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + // Also add the right key and its `sameOrderExpressions` + SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey + .sameOrderExpressions) + } + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) + // There are null rows in both streams, so there is no order. + case FullOuter => Nil + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } + + /** + * For SMJ, child's output must have been sorted on key or expressions with the same order as + * key, so we can get ordering for key from child's output ordering. 
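To make the new outputOrdering / getKeyOrdering bookkeeping concrete, here is a toy rendering with expressions modeled as plain strings (not Spark's SortOrder class); it only mirrors the zip-and-merge of sameOrderExpressions described in the scaladoc above.

```scala
// Toy model: a sort order has a primary expression plus a set of expressions
// known to be ordered the same way ("sameOrderExpressions").
case class SortOrderSketch(child: String, sameOrderExpressions: Set[String])

object KeyOrderingSketch {
  // Mirror of getKeyOrdering: the child is already sorted on the key (or on an
  // expression with the same order), so re-anchor the order on the key itself
  // and remember the child's expression as an equivalent ordering.
  def getKeyOrdering(
      keys: Seq[String],
      childOutputOrdering: Seq[SortOrderSketch]): Seq[SortOrderSketch] = {
    keys.zip(childOutputOrdering).map { case (key, childOrder) =>
      SortOrderSketch(key, childOrder.sameOrderExpressions + childOrder.child - key)
    }
  }

  // For an inner sort-merge join, both sides' key orderings hold on the output,
  // so the right key (and its equivalents) is folded into the left key's order.
  def innerJoinOrdering(
      leftKeyOrdering: Seq[SortOrderSketch],
      rightKeyOrdering: Seq[SortOrderSketch]): Seq[SortOrderSketch] = {
    leftKeyOrdering.zip(rightKeyOrdering).map { case (l, r) =>
      SortOrderSketch(l.child, l.sameOrderExpressions + r.child ++ r.sameOrderExpressions)
    }
  }
}

object KeyOrderingDemo extends App {
  val leftOrdering = KeyOrderingSketch.getKeyOrdering(
    Seq("l.key"), Seq(SortOrderSketch("l.sortCol", Set.empty)))
  val rightOrdering = KeyOrderingSketch.getKeyOrdering(
    Seq("r.key"), Seq(SortOrderSketch("r.key", Set.empty)))
  println(KeyOrderingSketch.innerJoinOrdering(leftOrdering, rightOrdering))
  // List(SortOrderSketch(l.key,Set(l.sortCol, r.key)))
}
```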
+ */ + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key) + } + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil @@ -95,9 +126,13 @@ case class SortMergeJoinExec( private def createRightKeyGenerator(): Projection = UnsafeProjection.create(rightKeys, right.output) + private def getSpillThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferSpillThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - + val spillThreshold = getSpillThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -115,39 +150,39 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } override def advanceNext(): Boolean = { - while (currentMatchIdx >= 0) { - if (currentMatchIdx == currentRightMatches.length) { + while (rightMatchesIterator != null) { + if (!rightMatchesIterator.hasNext) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } else { currentRightMatches = null currentLeftRow = null - currentMatchIdx = -1 + rightMatchesIterator = null return false } } - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { numOutputRows += 1 return true @@ -165,7 +200,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + bufferedIter = RowIterator.fromScala(rightIter), + spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -177,7 +213,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + bufferedIter = RowIterator.fromScala(leftIter), + spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -209,7 +246,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + 
spillThreshold ) private[this] val joinRow = new JoinedRow @@ -217,14 +255,15 @@ case class SortMergeJoinExec( while (smjScanner.findNextInnerJoinRows()) { val currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - var i = 0 - while (i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } - i += 1 } } false @@ -241,7 +280,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -249,17 +289,16 @@ case class SortMergeJoinExec( while (smjScanner.findNextOuterJoinRows()) { currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches - if (currentRightMatches == null) { + if (currentRightMatches == null || currentRightMatches.length == 0) { return true } - var i = 0 var found = false - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } if (!found) { numOutputRows += 1 @@ -281,7 +320,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -290,14 +330,13 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches var found = false - if (currentRightMatches != null) { - var i = 0 - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } } result.setBoolean(0, found) @@ -376,8 +415,11 @@ case class SortMergeJoinExec( // A list to hold all matched rows from right side. val matches = ctx.freshName("matches") - val clsName = classOf[java.util.ArrayList[InternalRow]].getName - ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + + ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") // Copy the left keys as class members so they could be used in next function call. 
val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -428,7 +470,7 @@ case class SortMergeJoinExec( | } | $leftRow = null; | } else { - | $matches.add($rightRow.copy()); + | $matches.add((UnsafeRow) $rightRow); | $rightRow = null;; | } | } while ($leftRow != null); @@ -517,8 +559,7 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) - val size = ctx.freshName("size") - val i = ctx.freshName("i") + val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. @@ -551,10 +592,10 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | int $size = $matches.size(); | ${beforeLoop.trim} - | for (int $i = 0; $i < $size; $i ++) { - | InternalRow $rightRow = (InternalRow) $matches.get($i); + | scala.collection.Iterator $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $rightRow = (InternalRow) $iterator.next(); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} @@ -589,7 +630,8 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + bufferedIter: RowIterator, + bufferThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -600,7 +642,7 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -609,7 +651,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. @@ -755,7 +797,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } @@ -819,7 +861,7 @@ private abstract class OneSideOuterIterator( protected[this] val joinedRow: JoinedRow = new JoinedRow() // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) @@ -833,7 +875,7 @@ private abstract class OneSideOuterIterator( * @return whether there are more rows in the stream to consume. 
*/ private def advanceStream(): Boolean = { - bufferIndex = 0 + rightMatchesIterator = null if (smjScanner.findNextOuterJoinRows()) { setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -858,10 +900,13 @@ private abstract class OneSideOuterIterator( */ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + if (rightMatchesIterator == null) { + rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator() + } + + while (!foundMatch && rightMatchesIterator.hasNext) { + setBufferedSideOutput(rightMatchesIterator.next()) foundMatch = boundCondition(joinedRow) - bufferIndex += 1 } foundMatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index dbc27d8b237f3..ef982a4ebd10d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -22,9 +22,15 @@ import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +/** + * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. Updates on + * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates + * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. + */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will @@ -126,4 +132,18 @@ object SQLMetrics { s"\n$sum ($min, $med, $max)" } } + + /** + * Updates metrics based on the driver side value. This is useful for certain metrics that + * are only updated on the driver, e.g. subquery execution time, or number of files. + */ + def postDriverMetricUpdates( + sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = { + // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant to it. 
+ if (executionId != null) { + sc.listenerBus.post( + SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce6969b..48c7b80bffe03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState -import org.apache.spark.sql.execution.streaming.KeyedStateImpl -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -219,7 +221,7 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - case _ => classOf[Any => Any] -> "apply" + case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) @@ -353,14 +355,14 @@ case class MapGroupsExec( object MapGroupsExec { def apply( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], + func: (Any, Iterator[Any], LogicalGroupState[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 46fd54e5c7420..fcd84705f7e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -112,6 +112,8 @@ object EvaluatePython { case (c: Int, DateType) => c case (c: Long, TimestampType) => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case (c: Int, TimestampType) => c.toLong case (c, StringType) => UTF8String.fromString(c.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c3d8859cb7a92..1debad03c93fa 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -54,6 +54,9 @@ object StatFunctions extends Logging { * Note that values greater than 1 are accepted but give the same result as 1. * * @return for each column, returns the requested approximations + * + * @note null and NaN values will be ignored in numerical columns before calculation. For + * a column only containing null or NaN values, an empty array is returned. */ def multipleApproxQuantiles( df: DataFrame, @@ -78,7 +81,10 @@ object StatFunctions extends Logging { def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = { var i = 0 while (i < summaries.length) { - summaries(i) = summaries(i).insert(row.getDouble(i)) + if (!row.isNullAt(i)) { + val v = row.getDouble(i) + if (!v.isNaN) summaries(i) = summaries(i).insert(v) + } i += 1 } summaries @@ -91,7 +97,7 @@ object StatFunctions extends Logging { } val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge) - summaries.map { summary => probabilities.map(summary.query) } + summaries.map { summary => probabilities.flatMap(summary.query) } } /** Calculate the Pearson Correlation Coefficient for the given columns */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala new file mode 100644 index 0000000000000..a34938f911f76 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BatchCommitLog.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, OutputStream} +import java.nio.charset.StandardCharsets._ + +import scala.io.{Source => IOSource} + +import org.apache.spark.sql.SparkSession + +/** + * Used to write log files that represent batch commit points in structured streaming. + * A commit log file will be written immediately after the successful completion of a + * batch, and before processing the next batch. Here is an execution summary: + * - trigger batch 1 + * - obtain batch 1 offsets and write to offset log + * - process batch 1 + * - write batch 1 to completion log + * - trigger batch 2 + * - obtain bactch 2 offsets and write to offset log + * - process batch 2 + * - write batch 2 to completion log + * .... 
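Returning to the StatFunctions change at the top of this hunk: null and NaN values are now skipped before they reach the quantile summaries, and summary.query is flatMapped so a column with no usable values produces an empty result. A self-contained sketch of that behaviour, with a stand-in summary class (the real QuantileSummaries is an approximate, mergeable structure whose query now answers with an Option):

```scala
import scala.collection.mutable.ArrayBuffer

// Stand-in for QuantileSummaries, just to show the control flow.
class QuantileSummariesSketch {
  private val values = ArrayBuffer.empty[Double]
  def insert(v: Double): QuantileSummariesSketch = { values += v; this }
  def query(probability: Double): Option[Double] = {
    if (values.isEmpty) None
    else {
      val sorted = values.sorted
      Some(sorted(((sorted.length - 1) * probability).toInt))
    }
  }
}

object ApproxQuantileSketch extends App {
  // One input column where some rows are null (None) or NaN.
  val column: Seq[Option[Double]] = Seq(Some(1.0), None, Some(Double.NaN), Some(3.0), Some(2.0))

  val summary = new QuantileSummariesSketch
  column.foreach {
    case Some(v) if !v.isNaN => summary.insert(v)  // mirrors the isNullAt / isNaN guard
    case _ => ()                                   // nulls and NaNs are ignored
  }

  // flatMap drops unanswerable queries, so an all-null/NaN column yields an empty result.
  val result: Seq[Double] = Seq(0.0, 0.5, 1.0).flatMap(summary.query)
  println(result)  // List(1.0, 2.0, 3.0)
}
```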
+ * + * The current format of the batch completion log is: + * line 1: version + * line 2: metadata (optional json string) + */ +class BatchCommitLog(sparkSession: SparkSession, path: String) + extends HDFSMetadataLog[String](sparkSession, path) { + + import BatchCommitLog._ + + def add(batchId: Long): Unit = { + super.add(batchId, EMPTY_JSON) + } + + override def add(batchId: Long, metadata: String): Boolean = { + throw new UnsupportedOperationException( + "BatchCommitLog does not take any metadata, use 'add(batchId)' instead") + } + + override protected def deserialize(in: InputStream): String = { + // called inside a try-finally where the underlying stream is closed in the caller + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { + throw new IllegalStateException("Incomplete log file in the offset commit log") + } + parseVersion(lines.next.trim, VERSION) + EMPTY_JSON + } + + override protected def serialize(metadata: String, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(s"v${VERSION}".getBytes(UTF_8)) + out.write('\n') + + // write metadata + out.write(EMPTY_JSON.getBytes(UTF_8)) + } +} + +object BatchCommitLog { + private val VERSION = 1 + private val EMPTY_JSON = "{}" +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 5a6f9e87f6eaa..408c8f81f17ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.SparkSession * doing a compaction, it will read all old log files and merge them with the new batch. 
*/ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends HDFSMetadataLog[Array[T]](sparkSession, path) { @@ -134,7 +134,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( override def serialize(logData: Array[T], out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(metadataLogVersion.getBytes(UTF_8)) + out.write(("v" + metadataLogVersion).getBytes(UTF_8)) logData.foreach { data => out.write('\n') out.write(Serialization.write(data).getBytes(UTF_8)) @@ -146,10 +146,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != metadataLogVersion) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + val version = parseVersion(lines.next(), metadataLogVersion) lines.map(Serialization.read[T]).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 5a9a99e11188e..25cf609fc336e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -84,10 +84,7 @@ case class EventTimeWatermarkExec( child: SparkPlan) extends SparkPlan { val eventTimeStats = new EventTimeStatsAccum() - val delayMs = { - val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 - delay.milliseconds + delay.months * millisPerMonth - } + val delayMs = EventTimeWatermark.getDelayMs(delay) sparkContext.register(eventTimeStats) @@ -105,10 +102,16 @@ case class EventTimeWatermarkExec( override val output: Seq[Attribute] = child.output.map { a => if (a semanticEquals eventTime) { val updatedMetadata = new MetadataBuilder() - .withMetadata(a.metadata) - .putLong(EventTimeWatermark.delayKey, delayMs) - .build() - + .withMetadata(a.metadata) + .putLong(EventTimeWatermark.delayKey, delayMs) + .build() + a.withMetadata(updatedMetadata) + } else if (a.metadata.contains(EventTimeWatermark.delayKey)) { + // Remove existing watermark + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(EventTimeWatermark.delayKey) + .build() a.withMetadata(updatedMetadata) } else { a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index 2f802d782f5ad..d54ed44b43bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -38,7 +38,10 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging } /** - * Maximum age of a file that can be found in this directory, before it is deleted. + * Maximum age of a file that can be found in this directory, before it is ignored. For the + * first batch all files will be considered valid. If `latestFirst` is set to `true` and + * `maxFilesPerTrigger` is set, then this parameter will be ignored, because old files that are + * valid, and should be processed, may be ignored. Please refer to SPARK-19813 for details. 
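The maxFileAge note just above describes an interaction that the source enforces a little further down: the effective age limit becomes Long.MaxValue when latestFirst is combined with a file-count cap. A minimal sketch of that effective-option computation, with illustrative names:

```scala
// Minimal sketch of the documented interplay: combining latestFirst with a
// maxFilesPerTrigger cap effectively disables the age filter, so valid old
// files are not silently dropped across triggers.
case class FileSourceOptionsSketch(
    maxFileAgeMs: Long,
    latestFirst: Boolean,
    maxFilesPerTrigger: Option[Int]) {
  // Mirrors the effective value the source computes before building its SeenFilesMap.
  def effectiveMaxFileAgeMs: Long =
    if (latestFirst && maxFilesPerTrigger.isDefined) Long.MaxValue else maxFileAgeMs
}

object FileSourceOptionsDemo extends App {
  val sevenDaysMs = 7L * 24 * 60 * 60 * 1000
  println(FileSourceOptionsSketch(
    maxFileAgeMs = sevenDaysMs, latestFirst = true, maxFilesPerTrigger = Some(100))
    .effectiveMaxFileAgeMs)  // Long.MaxValue: age check disabled
  println(FileSourceOptionsSketch(
    maxFileAgeMs = sevenDaysMs, latestFirst = false, maxFilesPerTrigger = None)
    .effectiveMaxFileAgeMs)  // 604800000: the configured 7 days
}
```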
* * The max age is specified with respect to the timestamp of the latest file, and not the * timestamp of the current system. That this means if the last file has timestamp 1000, and the @@ -58,13 +61,29 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging * Whether to scan latest files first. If it's true, when the source finds unprocessed files in a * trigger, it will first process the latest files. */ - val latestFirst: Boolean = parameters.get("latestFirst").map { str => - try { - str.toBoolean - } catch { - case _: IllegalArgumentException => - throw new IllegalArgumentException( - s"Invalid value '$str' for option 'latestFirst', must be 'true' or 'false'") - } - }.getOrElse(false) + val latestFirst: Boolean = withBooleanParameter("latestFirst", false) + + /** + * Whether to check new files based on only the filename instead of on the full path. + * + * With this set to `true`, the following files would be considered as the same file, because + * their filenames, "dataset.txt", are the same: + * - "file:///dataset.txt" + * - "s3://a/dataset.txt" + * - "s3n://a/b/dataset.txt" + * - "s3a://a/b/c/dataset.txt" + */ + val fileNameOnly: Boolean = withBooleanParameter("fileNameOnly", false) + + private def withBooleanParameter(name: String, default: Boolean) = { + parameters.get(name).map { str => + try { + str.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid value '$str' for option '$name', must be 'true' or 'false'") + } + }.getOrElse(default) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 07ec4e9429e42..6885d0bf67ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -53,6 +53,26 @@ object FileStreamSink extends Logging { case _ => false } } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. + * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val fs = path.getFileSystem(hadoopConf) + var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index eb6eed87eca7b..8d718b2164d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -77,7 +77,7 @@ object SinkFileStatus { * (drops the deleted files). 
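The ancestorIsMetadataDirectory helper added above walks up the path hierarchy looking for a _spark_metadata component. An equivalent sketch using java.nio.file instead of Hadoop's Path (the real method also qualifies the path against its FileSystem first):

```scala
import java.nio.file.{Path, Paths}

// Sketch of the ancestor check with java.nio.file types (illustrative only).
object MetadataDirSketch {
  val metadataDir = "_spark_metadata"

  def ancestorIsMetadataDirectory(path: Path): Boolean = {
    var current: Path = path.toAbsolutePath.normalize()
    while (current != null) {
      // getFileName is null for the root, so guard before comparing
      if (current.getFileName != null && current.getFileName.toString == metadataDir) {
        return true
      }
      current = current.getParent
    }
    false
  }
}

object MetadataDirDemo extends App {
  import MetadataDirSketch._
  println(ancestorIsMetadataDirectory(Paths.get("/data/out/_spark_metadata/0")))  // true
  println(ancestorIsMetadataDirectory(Paths.get("/data/out/part-00000")))         // false
}
```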
*/ class FileStreamSinkLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) { @@ -106,7 +106,7 @@ class FileStreamSinkLog( } object FileStreamSinkLog { - val VERSION = "v1" + val VERSION = 1 val DELETE_ACTION = "delete" val ADD_ACTION = "add" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 6a7263ca45d85..a9e64c640042a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.net.URI + import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{FileStatus, Path} @@ -66,23 +68,36 @@ class FileStreamSource( private val fileSortOrder = if (sourceOptions.latestFirst) { logWarning( - """'latestFirst' is true. New files will be processed first. - |It may affect the watermark value""".stripMargin) + """'latestFirst' is true. New files will be processed first, which may affect the watermark + |value. In addition, 'maxFileAge' will be ignored.""".stripMargin) implicitly[Ordering[Long]].reverse } else { implicitly[Ordering[Long]] } + private val maxFileAgeMs: Long = if (sourceOptions.latestFirst && maxFilesPerBatch.isDefined) { + Long.MaxValue + } else { + sourceOptions.maxFileAgeMs + } + + private val fileNameOnly = sourceOptions.fileNameOnly + if (fileNameOnly) { + logWarning("'fileNameOnly' is enabled. Make sure your file names are unique (e.g. using " + + "UUID), otherwise, files with the same name but under different paths will be considered " + + "the same and causes data lost.") + } + /** A mapping from a file that we have processed to some timestamp it was last modified. */ // Visible for testing and debugging in production. - val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) + val seenFiles = new SeenFilesMap(maxFileAgeMs, fileNameOnly) metadataLog.allFiles().foreach { entry => seenFiles.add(entry.path, entry.timestamp) } seenFiles.purge() - logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = ${sourceOptions.maxFileAgeMs}") + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAgeMs = $maxFileAgeMs") /** * Returns the maximum offset that can be retrieved from the source. @@ -262,7 +277,7 @@ object FileStreamSource { * To prevent the hash map from growing indefinitely, a purge function is available to * remove files "maxAgeMs" older than the latest file. */ - class SeenFilesMap(maxAgeMs: Long) { + class SeenFilesMap(maxAgeMs: Long, fileNameOnly: Boolean) { require(maxAgeMs >= 0) /** Mapping from file to its timestamp. */ @@ -274,9 +289,13 @@ object FileStreamSource { /** Timestamp for the last purge operation. */ private var lastPurgeTimestamp: Timestamp = 0L + @inline private def stripPathIfNecessary(path: String) = { + if (fileNameOnly) new Path(new URI(path)).getName else path + } + /** Add a new file to the map. 
*/ def add(path: String, timestamp: Timestamp): Unit = { - map.put(path, timestamp) + map.put(stripPathIfNecessary(path), timestamp) if (timestamp > latestTimestamp) { latestTimestamp = timestamp } @@ -289,7 +308,7 @@ object FileStreamSource { def isNewFile(path: String, timestamp: Timestamp): Boolean = { // Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that // is older than (latestTimestamp - maxAgeMs) but has not been purged yet. - timestamp >= lastPurgeTimestamp && !map.containsKey(path) + timestamp >= lastPurgeTimestamp && !map.containsKey(stripPathIfNecessary(path)) } /** Removes aged entries and returns the number of files removed. */ @@ -308,9 +327,5 @@ object FileStreamSource { } def size: Int = map.size() - - def allEntries: Seq[(String, Timestamp)] = { - map.asScala.toSeq - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 81908c0cefdfa..33e6a1d5d6e18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry import org.apache.spark.sql.internal.SQLConf class FileStreamSourceLog( - metadataLogVersion: String, + metadataLogVersion: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[FileEntry](metadataLogVersion, sparkSession, path) { @@ -120,5 +120,5 @@ class FileStreamSourceLog( } object FileStreamSourceLog { - val VERSION = "v1" + val VERSION = 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala new file mode 100644 index 0000000000000..e42df5dd61c70 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
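To illustrate what fileNameOnly changes in the SeenFilesMap above: with it enabled, only the file name is used as the dedup key, so the same name under different schemes or directories is treated as already seen (hence the data-loss warning logged earlier). A simplified, heap-only stand-in (the real key stripping goes through Hadoop's Path and a URI):

```scala
import scala.collection.mutable

// Simplified stand-in for SeenFilesMap; no Hadoop types, illustrative only.
class SeenFilesSketch(maxAgeMs: Long, fileNameOnly: Boolean) {
  private val seen = mutable.Map.empty[String, Long]
  private var latestTimestamp = 0L
  private var lastPurgeTimestamp = 0L

  // The real implementation strips via `new Path(new URI(path)).getName`.
  private def key(path: String): String =
    if (fileNameOnly) path.substring(path.lastIndexOf('/') + 1) else path

  def add(path: String, timestamp: Long): Unit = {
    seen(key(path)) = timestamp
    if (timestamp > latestTimestamp) latestTimestamp = timestamp
  }

  def isNewFile(path: String, timestamp: Long): Boolean =
    timestamp >= lastPurgeTimestamp && !seen.contains(key(path))

  // Drop entries older than maxAgeMs relative to the latest file seen.
  def purge(): Int = {
    lastPurgeTimestamp = latestTimestamp - maxAgeMs
    val before = seen.size
    seen.retain { case (_, ts) => ts >= lastPurgeTimestamp }
    before - seen.size
  }
}

object SeenFilesDemo extends App {
  val seen = new SeenFilesSketch(maxAgeMs = 1000L, fileNameOnly = true)
  seen.add("s3://bucket-a/dataset.txt", timestamp = 100L)
  println(seen.isNewFile("file:///tmp/dataset.txt", timestamp = 150L))  // false: same file name
}
```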
+ */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `FlatMapGroupsWithState.` + * + * @param func function called on each group + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateEncoder used to serialize/deserialize state before calling `func` + * @param outputMode the output mode of `func` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. + */ +case class FlatMapGroupsWithStateExec( + func: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateId: Option[OperatorStateId], + stateEncoder: ExpressionEncoder[Any], + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + override val eventTimeWatermark: Option[Long], + child: SparkPlan + ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { + + import GroupStateImpl._ + + private val isTimeoutEnabled = timeoutConf != NoTimeout + private val timestampTimeoutAttribute = + AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() + private val stateAttributes: Seq[Attribute] = { + val encSchemaAttribs = stateEncoder.schema.toAttributes + if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + } + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val encoderSerializer = stateEncoder.namedExpressions + if (isTimeoutEnabled) { + encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + } else { + encoderSerializer + } + } + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. 
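Before diving into the state-store plumbing below, a user-level sketch of the kind of function this operator executes may help. It exercises only the GroupState methods visible in GroupStateImpl further down (getOption, update, remove, setTimeoutDuration, hasTimedOut); the key/value types are hypothetical and the surrounding Dataset call is omitted.

```scala
import org.apache.spark.sql.streaming.GroupState

// Hedged sketch of a function that could be handed to [flat]mapGroupsWithState.
object RunningCountSketch {
  def updateCount(
      key: String,
      values: Iterator[Int],
      state: GroupState[Long]): Iterator[(String, Long)] = {
    if (state.hasTimedOut) {
      // The group was idle past its timeout: emit a final count and drop the state.
      val finalCount = state.getOption.getOrElse(0L)
      state.remove()
      Iterator((key, finalCount))
    } else {
      val newCount = state.getOption.getOrElse(0L) + values.size
      state.update(newCount)
      state.setTimeoutDuration("10 minutes")  // only valid with processing-time timeout enabled
      Iterator((key, newCount))
    }
  }
}
```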
+ private val stateDeserializer = stateEncoder.resolveAndBind().deserializer + + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + /** Ordering needed for using GroupingIterator */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override def keyExpressions: Seq[Attribute] = groupingAttributes + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + groupingAttributes.toStructType, + stateAttributes.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val updater = new StateStoreUpdater(store) + + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + iter.filter(row => !predicate.eval(row)) + case None => + iter + } + + // Generate a iterator that returns the rows grouped by the grouping function + // Note that this code ensures that the filtering for timeout occurs only after + // all the data has been processed. This is to ensure that the timeout information of all + // the keys with data is updated before they are processed for timeouts. 
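The comment just above relies on Iterator#++ being lazy in its right-hand side, so the timed-out-key scan only runs after the data iterator has been drained and every per-key state update is already in place. A generic illustration of that ordering guarantee:

```scala
// Why `dataUpdates() ++ timedOutKeys()` runs the timeout work last: Iterator#++
// takes its argument by name and the concatenation is lazy, so the right-hand
// iterator is only built once the left-hand one is exhausted.
object LazyConcatDemo extends App {
  val processed = scala.collection.mutable.ArrayBuffer.empty[Int]

  def dataUpdates(): Iterator[String] =
    Iterator(1, 2, 3).map { i => processed += i; s"updated key $i" }

  def timedOutKeys(): Iterator[String] = {
    // Runs only once the data iterator is exhausted, so it sees every update.
    Iterator.single(s"timeout scan sees ${processed.size} updated keys")
  }

  val output = dataUpdates() ++ timedOutKeys()
  output.foreach(println)
  // updated key 1
  // updated key 2
  // updated key 3
  // timeout scan sees 3 updated keys
}
```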
+ val outputIterator = + updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterator, + { + store.commit() + longMetric("numTotalStateRows") += store.numKeys() + } + ) + } + } + + /** Helper class to update the state store */ + class StateStoreUpdater(store: StateStore) { + + // Converters for translating input keys, values, output data between rows and Java objects + private val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + private val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + + // Converters for translating state between rows and Java objects + private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateAttributes) + private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Index of the additional metadata fields in the state row + private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) + + // Metrics + private val numUpdatedStateRows = longMetric("numUpdatedStateRows") + private val numOutputRows = longMetric("numOutputRows") + + /** + * For every group, get the key, values and corresponding state and call the function, + * and return an iterator of rows + */ + def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + groupedIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + callFunctionAndUpdateState( + keyUnsafeRow, + valueRowIter, + store.get(keyUnsafeRow), + hasTimedOut = false) + } + } + + /** Find the groups that have timeout set and are timing out right now, and call the function */ + def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutKeys = store.filter { case (_, stateRow) => + val timeoutTimestamp = getTimeoutTimestamp(stateRow) + timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + } + timingOutKeys.flatMap { case (keyRow, stateRow) => + callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) + } + } else Iterator.empty + } + + /** + * Call the user function on a key's data, update the state store, and return the return data + * iterator. Note that the store updating is lazy, that is, the store will be updated only + * after the returned iterator is fully consumed. 
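Both the per-partition store.commit() above and the per-key write-back described in the scaladoc just above rely on the CompletionIterator pattern: nothing is persisted until the returned iterator has actually been drained. A generic sketch of such a wrapper (Spark's own CompletionIterator lives in org.apache.spark.util; this is just the shape of the idea):

```scala
// Wrap an iterator and run a completion action exactly once, when the wrapped
// iterator is exhausted (e.g. committing a state store or writing back state).
class CompletionIteratorSketch[A](sub: Iterator[A], completion: => Unit) extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) { completed = true; completion }
    more
  }
  override def next(): A = sub.next()
}

object CompletionDemo extends App {
  val rows = new CompletionIteratorSketch(Iterator("a", "b"), println("state committed"))
  rows.foreach(println)  // prints a, b, then "state committed"
}
```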
+ */ + private def callFunctionAndUpdateState( + keyRow: UnsafeRow, + valueRowIter: Iterator[InternalRow], + prevStateRowOption: Option[UnsafeRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + + val keyObj = getKeyObj(keyRow) // convert key to objects + val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects + val stateObjOption = getStateObj(prevStateRowOption) + val keyedState = new GroupStateImpl( + stateObjOption, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut) + + // Call function, get the returned objects and convert them to rows + val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + numOutputRows += 1 + getOutputRow(obj) + } + + // When the iterator is consumed, then write changes to state + def onIteratorCompletion: Unit = { + if (keyedState.hasRemoved) { + store.remove(keyRow) + numUpdatedStateRows += 1 + + } else { + val previousTimeoutTimestamp = prevStateRowOption match { + case Some(row) => getTimeoutTimestamp(row) + case None => NO_TIMESTAMP + } + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp + val stateRowToWrite = if (keyedState.hasUpdated) { + getStateRow(keyedState.get) + } else { + prevStateRowOption.orNull + } + + val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp + val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + + if (shouldWriteState) { + if (stateRowToWrite == null) { + // This should never happen because checks in GroupStateImpl should avoid cases + // where empty state would need to be written + throw new IllegalStateException("Attempting to write empty state") + } + setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) + store.put(keyRow.copy(), stateRowToWrite.copy()) + numUpdatedStateRows += 1 + } + } + } + + // Return an iterator of rows such that fully consumed, the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + } + + /** Returns the state as Java object if defined */ + def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = { + stateRowOption.map(getStateObjFromRow) + } + + /** Returns the row for an updated state */ + def getStateRow(obj: Any): UnsafeRow = { + getStateRowFromObj(obj) + } + + /** Returns the timeout timestamp of a state row is set */ + def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala new file mode 100644 index 0000000000000..148d92247d6f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.sql.Date + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Internal implementation of the [[GroupState]] interface. Methods are not thread-safe. + * + * @param optionalValue Optional value of the state + * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp + * for processing time timeouts + * @param timeoutConf Type of timeout configured. Based on this, different operations will + * be supported. + * @param hasTimedOut Whether the key for which this state wrapped is being created is + * getting timed out or not. + */ +private[sql] class GroupStateImpl[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + eventTimeWatermarkMs: Long, + timeoutConf: GroupStateTimeout, + override val hasTimedOut: Boolean) extends GroupState[S] { + + // Constructor to create dummy state when using mapGroupsWithState in a batch query + def this(optionalValue: Option[S]) = this( + optionalValue, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf = GroupStateTimeout.NoTimeout, + hasTimedOut = false) + private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) + private var defined: Boolean = optionalValue.isDefined + private var updated: Boolean = false // whether value has been updated (but not removed) + private var removed: Boolean = false // whether value has been removed + private var timeoutTimestamp: Long = NO_TIMESTAMP + + // ========= Public API ========= + override def exists: Boolean = defined + + override def get: S = { + if (defined) { + value + } else { + throw new NoSuchElementException("State is either not defined or has already been removed") + } + } + + override def getOption: Option[S] = { + if (defined) { + Some(value) + } else { + None + } + } + + override def update(newValue: S): Unit = { + if (newValue == null) { + throw new IllegalArgumentException("'null' is not a valid state value") + } + value = newValue + defined = true + updated = true + removed = false + } + + override def remove(): Unit = { + defined = false + updated = false + removed = true + timeoutTimestamp = NO_TIMESTAMP + } + + override def setTimeoutDuration(durationMs: Long): Unit = { + if (timeoutConf != ProcessingTimeTimeout) { + throw new UnsupportedOperationException( + "Cannot set timeout duration without enabling processing time timeout in " + + "map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout information without any state value, " + + "state has either not been initialized, or has already been removed") + } + + if (durationMs <= 0) { + throw new IllegalArgumentException("Timeout duration must be positive") + } + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + 
timeoutTimestamp = durationMs + batchProcessingTimeMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. + } + } + + override def setTimeoutDuration(duration: String): Unit = { + setTimeoutDuration(parseDuration(duration)) + } + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long): Unit = { + checkTimeoutTimestampAllowed() + if (timestampMs <= 0) { + throw new IllegalArgumentException("Timeout timestamp must be positive") + } + if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < eventTimeWatermarkMs) { + throw new IllegalArgumentException( + s"Timeout timestamp ($timestampMs) cannot be earlier than the " + + s"current watermark ($eventTimeWatermarkMs)") + } + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + timeoutTimestamp = timestampMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. + } + } + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) + } + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime) + } + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) + } + + override def toString: String = { + s"GroupState(${getOption.map(_.toString).getOrElse("")})" + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def hasRemoved: Boolean = removed + + /** Whether the state has been updated */ + def hasUpdated: Boolean = updated + + /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ + def getTimeoutTimestamp: Long = timeoutTimestamp + + private def parseDuration(duration: String): Long = { + if (StringUtils.isBlank(duration)) { + throw new IllegalArgumentException( + "Provided duration is null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"Provided duration ($duration) 
is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException(s"Provided duration ($duration) is not positive") + } + + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + + private def checkTimeoutTimestampAllowed(): Unit = { + if (timeoutConf != EventTimeTimeout) { + throw new UnsupportedOperationException( + "Cannot set timeout timestamp without enabling event time timeout in " + + "map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout timestamp without any state value, " + + "state has either not been initialized, or has already been removed") + } + } +} + + +private[sql] object GroupStateImpl { + // Value used represent the lack of valid timestamp as a long + val NO_TIMESTAMP = -1L +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index f9e1f7de9ec08..46bfc297931fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -106,6 +106,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: * metadata has already been stored, this method will return `false`. */ override def add(batchId: Long, metadata: T): Boolean = { + require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written writeBatch(batchId, metadata) @@ -195,6 +196,11 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: val input = fileManager.open(batchMetadataFile) try { Some(deserialize(input)) + } catch { + case ise: IllegalStateException => + // re-throw the exception with the log file path added + throw new IllegalStateException( + s"Failed to read log file $batchMetadataFile. ${ise.getMessage}", ise) } finally { IOUtils.closeQuietly(input) } @@ -268,6 +274,37 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: new FileSystemManager(metadataPath, hadoopConf) } } + + /** + * Parse the log version from the given `text` -- will throw exception when the parsed version + * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", + * "v123xyz" etc.) + */ + private[sql] def parseVersion(text: String, maxSupportedVersion: Int): Int = { + if (text.length > 0 && text(0) == 'v') { + val version = + try { + text.substring(1, text.length).toInt + } catch { + case _: NumberFormatException => + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } + if (version > 0) { + if (version > maxSupportedVersion) { + throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + + s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + + s"by a newer version of Spark and cannot be read by this version. 
Please upgrade.") + } else { + return version + } + } + } + + // reaching here means we failed to read the correct log version + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } } object HDFSMetadataLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ffdcd9b19d058..622e049630db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} @@ -37,23 +37,20 @@ class IncrementalExecution( val outputMode: OutputMode, val checkpointLocation: String, val currentBatchId: Long, - val currentEventTimeWatermark: Long) + offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { - // TODO: make this always part of planning. - val streamingExtraStrategies = - sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: - sparkSession.sessionState.planner.StreamingRelationStrategy +: - sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: - sparkSession.sessionState.experimentalMethods.extraStrategies - // Modified planner with stateful operations. 
- override def planner: SparkPlanner = - new SparkPlanner( + override val planner: SparkPlanner = new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - streamingExtraStrategies) + sparkSession.sessionState.experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + StatefulAggregationStrategy :: + FlatMapGroupsWithStateStrategy :: + StreamingRelationStrategy :: + StreamingDeduplicationStrategy :: Nil + } /** * See [SPARK-18339] @@ -88,12 +85,13 @@ class IncrementalExecution( keys, Some(stateId), Some(outputMode), - Some(currentEventTimeWatermark), + Some(offsetSeqMetadata.batchWatermarkMs), agg.withNewChildren( StateStoreRestoreExec( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) @@ -102,13 +100,15 @@ class IncrementalExecution( keys, child, Some(stateId), - Some(currentEventTimeWatermark)) - case MapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => + Some(offsetSeqMetadata.batchWatermarkMs)) + + case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - MapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) + m.copy( + stateId = Some(stateId), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala deleted file mode 100644 index eee7ec45dd77b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.KeyedState - -/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. 
*/ -private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { - private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) - private var defined: Boolean = optionalValue.isDefined - private var updated: Boolean = false - // whether value has been updated (but not removed) - private var removed: Boolean = false // whether value has been removed - - // ========= Public API ========= - override def exists: Boolean = defined - - override def get: S = { - if (defined) { - value - } else { - throw new NoSuchElementException("State is either not defined or has already been removed") - } - } - - override def getOption: Option[S] = { - if (defined) { - Some(value) - } else { - None - } - } - - override def update(newValue: S): Unit = { - if (newValue == null) { - throw new IllegalArgumentException("'null' is not a valid state value") - } - value = newValue - defined = true - updated = true - removed = false - } - - override def remove(): Unit = { - defined = false - updated = false - removed = true - } - - override def toString: String = { - s"KeyedState(${getOption.map(_.toString).getOrElse("")})" - } - - // ========= Internal API ========= - - /** Whether the state has been marked for removing */ - def isRemoved: Boolean = { - removed - } - - /** Whether the state has been been updated */ - def isUpdated: Boolean = { - updated - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index e5a1997d6b808..8249adab4bba8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming import org.json4s.NoTypeHints import org.json4s.jackson.Serialization - /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance @@ -70,8 +69,12 @@ object OffsetSeq { * bound the lateness of data that will processed. Time unit: milliseconds * @param batchTimestampMs: The current batch processing timestamp. * Time unit: milliseconds + * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions. 
*/ -case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) { +case class OffsetSeqMetadata( + batchWatermarkMs: Long = 0, + batchTimestampMs: Long = 0, + conf: Map[String, String] = Map.empty) { def json: String = Serialization.write(this)(OffsetSeqMetadata.format) } @@ -79,4 +82,3 @@ object OffsetSeqMetadata { private implicit val format = Serialization.formats(NoTypeHints) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 3210d8ad64e22..4f8cd116f610e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -55,10 +55,8 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) if (!lines.hasNext) { throw new IllegalStateException("Incomplete log file") } - val version = lines.next() - if (version != OffsetSeqLog.VERSION) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } + + val version = parseVersion(lines.next(), OffsetSeqLog.VERSION) // read metadata val metadata = lines.next().trim match { @@ -70,7 +68,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) override protected def serialize(offsetSeq: OffsetSeq, out: OutputStream): Unit = { // called inside a try-finally where the underlying stream is closed in the caller - out.write(OffsetSeqLog.VERSION.getBytes(UTF_8)) + out.write(("v" + OffsetSeqLog.VERSION).getBytes(UTF_8)) // write metadata out.write('\n') @@ -88,6 +86,6 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) } object OffsetSeqLog { - private val VERSION = "v1" + private[streaming] val VERSION = 1 private val SERIALIZED_VOID_OFFSET = "-" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 75ffe90f2bb70..311942f6dbd84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructType * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. 
*/ -trait Source { +trait Source { /** Returns the schema of the data from this source */ def schema: StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 70912d13ae458..b6ddf7437ea13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock +import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -35,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -117,7 +119,9 @@ class StreamExecution( } /** Metadata associated with the offset seq of a batch in the query. */ - protected var offsetSeqMetadata = OffsetSeqMetadata() + protected var offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS).toString)) override val id: UUID = UUID.fromString(streamMetadata.id) @@ -145,15 +149,18 @@ class StreamExecution( "logicalPlan must be initialized in StreamExecutionThread " + s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. 
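          // The enclosing getOrElseUpdate memoizes StreamingRelation -> StreamingExecutionRelation,
          // so a query that refers to the same streaming source several times (for example a
          // self-union or self-join) materializes one Source and one metadata path instead of
          // several. Hedged illustration (the "rate" source is just an example input):
          //
          //   val input   = spark.readStream.format("rate").load()
          //   val doubled = input.union(input)  // same StreamingRelation appears twice in the plan
          //   // both occurrences resolve to the single memoized StreamingExecutionRelation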
+ StreamingExecutionRelation(source, output) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct @@ -162,6 +169,8 @@ class StreamExecution( private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) + case OneTimeTrigger => OneTimeExecutor() + case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } /** Defines the internal state of execution */ @@ -206,6 +215,13 @@ class StreamExecution( */ val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets")) + /** + * A log that records the batch ids that have completed. This is used to check if a batch was + * fully processed, and its output was committed to the sink, hence no need to process it again. + * This is used (for instance) during restart, to help identify which batch to run next. + */ + val batchCommitLog = new BatchCommitLog(sparkSession, checkpointFile("commits")) + /** Whether all fields of the query have been initialized */ private def isInitialized: Boolean = state.get != INITIALIZING @@ -240,6 +256,8 @@ class StreamExecution( */ private def runBatches(): Unit = { try { + sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, + interruptOnCancel = true) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -256,6 +274,15 @@ class StreamExecution( updateStatusMessage("Initializing sources") // force initialization of the logical plan so that the sources can be created logicalPlan + + // Isolated spark session to run the batches with. + val sparkSessionToRunBatches = sparkSession.cloneSession() + // Adaptive execution can change num shuffle partitions, disallow + sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") + offsetSeqMetadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, + conf = Map(SQLConf.SHUFFLE_PARTITIONS.key -> + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS.key))) + if (state.compareAndSet(INITIALIZING, ACTIVE)) { // Unblock `awaitInitialization` initializationLatch.countDown() @@ -263,42 +290,40 @@ class StreamExecution( triggerExecutor.execute(() => { startTrigger() - val continueToRun = - if (isActive) { - reportTimeTaken("triggerExecution") { - if (currentBatchId < 0) { - // We'll do this initialization only once - populateStartOffsets() - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() - } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") - runBatch() - } + if (isActive) { + reportTimeTaken("triggerExecution") { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets(sparkSessionToRunBatches) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() } - - // Report trigger as finished and construct progress object. 
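          // After this patch a batch leaves two markers behind:
          //   offsetLog(batchId)      -- written before the batch runs (the write-ahead intent)
          //   batchCommitLog(batchId) -- written after the sink has accepted the batch's data
          // which is what lets restart logic tell a planned-but-unfinished batch from a fully
          // committed one. A rough sketch of the recovery decision made in populateStartOffsets:
          //
          //   (offsetLog.getLatest(), batchCommitLog.getLatest()) match {
          //     case (Some((latest, _)), Some((committed, _))) if latest == committed =>
          //       // fully committed: re-issue getBatch for the sources, then start latest + 1
          //     case (Some((latest, _)), _) =>
          //       // offsets logged but not committed: re-run batch `latest` from those offsets
          //     case (None, _) =>
          //       // no offset log: brand new query, start at batch 0
          //   }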
- finishTrigger(dataAvailable) if (dataAvailable) { - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) + currentStatus = currentStatus.copy(isDataAvailable = true) + updateStatusMessage("Processing new data") + runBatch(sparkSessionToRunBatches) } - true + } + // Report trigger as finished and construct progress object. + finishTrigger(dataAvailable) + if (dataAvailable) { + // Update committed offsets. + batchCommitLog.add(currentBatchId) + committedOffsets ++= availableOffsets + logDebug(s"batch ${currentBatchId} committed") + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) } else { - false + currentStatus = currentStatus.copy(isDataAvailable = false) + updateStatusMessage("Waiting for data to arrive") + Thread.sleep(pollingDelayMs) } - - // Update committed offsets. - committedOffsets ++= availableOffsets + } updateStatusMessage("Waiting for next trigger") - continueToRun + isActive }) updateStatusMessage("Stopped") } else { @@ -361,6 +386,13 @@ class StreamExecution( } } } finally { + awaitBatchLock.lock() + try { + // Wake up any threads that are waiting for the stream to progress. + awaitBatchLockCondition.signalAll() + } finally { + awaitBatchLock.unlock() + } terminationLatch.countDown() } } @@ -373,22 +405,84 @@ class StreamExecution( * - currentBatchId * - committedOffsets * - availableOffsets + * The basic structure of this method is as follows: + * + * Identify (from the offset log) the offsets used to run the last batch + * IF last batch exists THEN + * Set the next batch to be executed as the last recovered batch + * Check the commit log to see which batch was committed last + * IF the last batch was committed THEN + * Call getBatch using the last batch start and end offsets + * // ^^^^ above line is needed since some sources assume last batch always re-executes + * Setup for a new batch i.e., start = last batch end, and identify new end + * DONE + * ELSE + * Identify a brand new batch + * DONE */ - private def populateStartOffsets(): Unit = { + private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = { offsetLog.getLatest() match { - case Some((batchId, nextOffsets)) => - logInfo(s"Resuming streaming query, starting with batch $batchId") - currentBatchId = batchId + case Some((latestBatchId, nextOffsets)) => + /* First assume that we are re-executing the latest known batch + * in the offset log */ + currentBatchId = latestBatchId availableOffsets = nextOffsets.toStreamProgress(sources) - offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata()) - logDebug(s"Found possibly unprocessed offsets $availableOffsets " + - s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") - - offsetLog.get(batchId - 1).foreach { - case lastOffsets => - committedOffsets = lastOffsets.toStreamProgress(sources) - logDebug(s"Resuming with committed offsets: $committedOffsets") + /* Initialize committed offsets to a committed batch, which at this + * is the second latest batch id in the offset log. 
*/ + offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId => + committedOffsets = secondLatestBatchId.toStreamProgress(sources) + } + + // update offset metadata + nextOffsets.metadata.foreach { metadata => + val shufflePartitionsSparkSession: Int = + sparkSessionToRunBatches.conf.get(SQLConf.SHUFFLE_PARTITIONS) + val shufflePartitionsToUse = metadata.conf.getOrElse(SQLConf.SHUFFLE_PARTITIONS.key, { + // For backward compatibility, if # partitions was not recorded in the offset log, + // then ensure it is not missing. The new value is picked up from the conf. + logWarning("Number of shuffle partitions from previous run not found in checkpoint. " + + s"Using the value from the conf, $shufflePartitionsSparkSession partitions.") + shufflePartitionsSparkSession + }) + offsetSeqMetadata = OffsetSeqMetadata( + metadata.batchWatermarkMs, metadata.batchTimestampMs, + metadata.conf + (SQLConf.SHUFFLE_PARTITIONS.key -> shufflePartitionsToUse.toString)) + // Update conf with correct number of shuffle partitions + sparkSessionToRunBatches.conf.set( + SQLConf.SHUFFLE_PARTITIONS.key, shufflePartitionsToUse.toString) } + + /* identify the current batch id: if commit log indicates we successfully processed the + * latest batch id in the offset log, then we can safely move to the next batch + * i.e., committedBatchId + 1 */ + batchCommitLog.getLatest() match { + case Some((latestCommittedBatchId, _)) => + if (latestBatchId == latestCommittedBatchId) { + /* The last batch was successfully committed, so we can safely process a + * new next batch but first: + * Make a call to getBatch using the offsets from previous batch. + * because certain sources (e.g., KafkaSource) assume on restart the last + * batch will be executed before getOffset is called again. */ + availableOffsets.foreach { ao: (Source, Offset) => + val (source, end) = ao + if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { + val start = committedOffsets.get(source) + source.getBatch(start, end) + } + } + currentBatchId = latestCommittedBatchId + 1 + committedOffsets ++= availableOffsets + // Construct a new batch be recomputing availableOffsets + constructNextBatch() + } else if (latestCommittedBatchId < latestBatchId - 1) { + logWarning(s"Batch completion log latest batch id is " + + s"${latestCommittedBatchId}, which is not trailing " + + s"batchid $latestBatchId by one") + } + case None => logInfo("no commit log present") + } + logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 @@ -437,8 +531,7 @@ class StreamExecution( } } if (hasNewData) { - // Current batch timestamp in milliseconds - offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis() + var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs // Update the eventTime watermark if we find one in the plan. 
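      // The watermark candidate from each EventTimeWatermarkExec is max(observed event time) - delay,
      // and it only ever moves the watermark forward (see the comparison below). At the user API
      // level this corresponds to something like the following hedged sketch, where `events` and
      // the "eventTime" column are hypothetical:
      //
      //   val withLateDataBound = events.withWatermark("eventTime", "10 minutes")
      //   // if the largest "eventTime" seen in a batch is 12:30, the next batch's watermark
      //   // becomes 12:20 -- unless the watermark is already past that point.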
if (lastExecution != null) { lastExecution.executedPlan.collect { @@ -446,16 +539,19 @@ class StreamExecution( logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") e.eventTimeStats.value.max - e.delayMs }.headOption.foreach { newWatermarkMs => - if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) { + if (newWatermarkMs > batchWatermarkMs) { logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - offsetSeqMetadata.batchWatermarkMs = newWatermarkMs + batchWatermarkMs = newWatermarkMs } else { logDebug( s"Event time didn't move: $newWatermarkMs < " + - s"${offsetSeqMetadata.batchWatermarkMs}") + s"$batchWatermarkMs") } } } + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = batchWatermarkMs, + batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { @@ -483,6 +579,7 @@ class StreamExecution( // Note that purge is exclusive, i.e. it purges everything before the target ID. if (minBatchesToRetain < currentBatchId) { offsetLog.purge(currentBatchId - minBatchesToRetain) + batchCommitLog.purge(currentBatchId - minBatchesToRetain) } } } else { @@ -498,8 +595,9 @@ class StreamExecution( /** * Processes any data available between `availableOffsets` and `committedOffsets`. + * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ - private def runBatch(): Unit = { + private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -544,17 +642,17 @@ class StreamExecution( reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( - sparkSession, + sparkSessionToRunBatch, triggerLogicalPlan, outputMode, checkpointFile("state"), currentBatchId, - offsetSeqMetadata.batchWatermarkMs) + offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } val nextBatch = - new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) reportTimeTaken("addBatch") { sink.addBatch(currentBatchId, nextBatch) @@ -594,8 +692,11 @@ class StreamExecution( // intentionally state.set(TERMINATED) if (microBatchThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) microBatchThread.interrupt() microBatchThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) } logInfo(s"Query $prettyIdString was stopped") } @@ -735,6 +836,11 @@ class StreamExecution( } } + private def getBatchDescriptionString: String = { + val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString + Option(name).map(_ + "
    ").getOrElse("") + + s"id = $id
    runId = $runId
    batch = $batchDescription" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index ac510df209f0a..d188566f822b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -29,6 +29,17 @@ trait TriggerExecutor { def execute(batchRunner: () => Boolean): Unit } +/** + * A trigger executor that runs a single batch only, then terminates. + */ +case class OneTimeExecutor() extends TriggerExecutor { + + /** + * Execute a single batch using `batchRunner`. + */ + override def execute(batchRunner: () => Boolean): Unit = batchRunner() +} + /** * A trigger executor that runs a batch every `intervalMs` milliseconds. */ @@ -36,21 +47,22 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = extends TriggerExecutor with Logging { private val intervalMs = processingTime.intervalMs + require(intervalMs >= 0) - override def execute(batchRunner: () => Boolean): Unit = { + override def execute(triggerHandler: () => Boolean): Unit = { while (true) { - val batchStartTimeMs = clock.getTimeMillis() - val terminated = !batchRunner() + val triggerTimeMs = clock.getTimeMillis + val nextTriggerTimeMs = nextBatchTime(triggerTimeMs) + val terminated = !triggerHandler() if (intervalMs > 0) { - val batchEndTimeMs = clock.getTimeMillis() - val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + val batchElapsedTimeMs = clock.getTimeMillis - triggerTimeMs if (batchElapsedTimeMs > intervalMs) { notifyBatchFallingBehind(batchElapsedTimeMs) } if (terminated) { return } - clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + clock.waitTillTime(nextTriggerTimeMs) } else { if (terminated) { return @@ -59,7 +71,7 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = } } - /** Called when a batch falls behind. Expose for test only */ + /** Called when a batch falls behind */ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { logWarning("Current batch is falling behind. The trigger interval is " + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") @@ -72,6 +84,6 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`). */ def nextBatchTime(now: Long): Long = { - now / intervalMs * intervalMs + intervalMs + if (intervalMs == 0) now else now / intervalMs * intervalMs + intervalMs } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala new file mode 100644 index 0000000000000..271bc4da99c08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.streaming.Trigger + +/** + * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * the query. + */ +@Experimental +@InterfaceStability.Evolving +case object OneTimeTrigger extends Trigger diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6d34d51d31c1e..971ce5afb1778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,11 +25,11 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -230,6 +230,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def computeStats(conf: CatalystConf): Statistics = + override def computeStats(conf: SQLConf): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ab1204a750fac..1426728f9b550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable @@ -73,7 +74,12 @@ private[state] class HDFSBackedStateStoreProvider( hadoopConf: Configuration ) extends StateStoreProvider with Logging { - type MapType = java.util.HashMap[UnsafeRow, UnsafeRow] + // ConcurrentHashMap is used because it generates fail-safe iterators on filtering + // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in + // the map when the iterator was created + // - Any updates to the map while iterating through the filtered iterator does not throw + // java.util.ConcurrentModificationException + type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow] /** Implementation 
of [[StateStore]] API which is backed by a HDFS-compatible file system */ class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType) @@ -99,6 +105,16 @@ private[state] class HDFSBackedStateStoreProvider( Option(mapToUpdate.get(key)) } + override def filter( + condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + mapToUpdate + .entrySet + .asScala + .iterator + .filter { entry => condition(entry.getKey, entry.getValue) } + .map { entry => (entry.getKey, entry.getValue) } + } + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot put after already committed or aborted") @@ -227,7 +243,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } } @@ -584,7 +600,7 @@ private[state] class HDFSBackedStateStoreProvider( val nameParts = path.getName.split("\\.") if (nameParts.size == 2) { val version = nameParts(0).toLong - nameParts(1).toLowerCase match { + nameParts(1).toLowerCase(Locale.ROOT) match { case "delta" => // ignore the file otherwise, snapshot file already exists for that batch id if (!versionToFiles.contains(version)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dcb24b26f78f3..eaa558eb6d0ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,6 +50,15 @@ trait StateStore { /** Get the current value of a key. */ def get(key: UnsafeRow): Option[UnsafeRow] + /** + * Return an iterator of key-value pairs that satisfy a certain condition. + * Note that the iterator must be fail-safe towards modification to the store, that is, + * it must be based on the snapshot of store the time of this call, and any change made to the + * store while iterating through iterator should not cause the iterator to fail or have + * any affect on the values in the iterator. + */ + def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] + /** Put a new value for a key. 
*/ def put(key: UnsafeRow, value: UnsafeRow): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d92529748b6ac..8dbda298c87bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, NullType, StructType} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator @@ -68,7 +69,7 @@ trait StateStoreWriter extends StatefulOperator { } /** An operator that supports watermark. */ -trait WatermarkSupport extends SparkPlan { +trait WatermarkSupport extends UnaryExecNode { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] @@ -76,10 +77,10 @@ trait WatermarkSupport extends SparkPlan { /** The watermark value. */ def eventTimeWatermark: Option[Long] - /** Generate a predicate that matches data older than the watermark */ - lazy val watermarkPredicate: Option[Predicate] = { + /** Generate an expression that matches data older than the watermark */ + lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)) optionalWatermarkAttribute.map { watermarkAttribute => // If we are evicting based on a window, use the end of the window. Otherwise just @@ -96,9 +97,17 @@ trait WatermarkSupport extends SparkPlan { } logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) + evictionExpression } } + + /** Predicate based on keys that matches data older than the watermark */ + lazy val watermarkPredicateForKeys: Option[Predicate] = + watermarkExpression.map(newPredicate(_, keyExpressions)) + + /** Predicate based on the child output that matches data older than the watermark. 
*/ + lazy val watermarkPredicateForData: Option[Predicate] = + watermarkExpression.map(newPredicate(_, child.output)) } /** @@ -192,7 +201,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval _) + store.remove(watermarkPredicateForKeys.get.eval _) store.commit() numTotalStateRows += store.numKeys() @@ -207,7 +216,7 @@ case class StateStoreSaveExec( new Iterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicate match { + private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } @@ -215,7 +224,9 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) + if (watermarkPredicateForKeys.nonEmpty) { + store.remove(watermarkPredicateForKeys.get.eval _) + } store.commit() numTotalStateRows += store.numKeys() false @@ -244,94 +255,6 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } - -/** Physical operator for executing streaming mapGroupsWithState. */ -case class MapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - stateId: Option[OperatorStateId], - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { - - override def outputPartitioning: Partitioning = child.outputPartitioning - - /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - /** Ordering needed for using GroupingIterator */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - getStateId.operatorId, - getStateId.batchId, - groupingAttributes.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numOutputRows = longMetric("numOutputRows") - - // Generate a iterator that returns the rows grouped by the grouping function - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - - // Converters to and from object and rows - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // For every group, get the key, values and corresponding state and call the function, - // and return an iterator of rows - val allRowsIterator = 
groupedIter.flatMap { case (keyRow, valueRowIter) => - - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) // convert key to objects - val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = store.get(key).map(getStateObj) // get existing state if any - val wrappedState = new KeyedStateImpl(stateObjOption) - val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => - numOutputRows += 1 - getOutputRow(obj) // convert back to rows - } - - // Return an iterator of rows generated this key, - // such that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]]( - mappedIterator, { - // When the iterator is consumed, then write changes to state - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get)) - numUpdatedStateRows += 1 - } - }) - } - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumer, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { - store.commit() - numTotalStateRows += store.numKeys() - }) - } - } -} - - /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], @@ -360,8 +283,8 @@ case class StreamingDeduplicateExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val baseIterator = watermarkPredicate match { - case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + val baseIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } @@ -381,7 +304,7 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicate.foreach(f => store.remove(f.eval _)) + watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) store.commit() numTotalStateRows += store.numKeys() }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 730ca27f82bac..d11045fb6ac8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -144,16 +144,13 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { ScalarSubquery( SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan), subquery.exprId) - case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) => - val executedPlan = new QueryExecution(sparkSession, query).executedPlan - InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId) } } } /** - * Find out duplicated exchanges in the spark plan, then use the same exchange for all the + * Find out duplicated subqueries in the spark plan, then use the same subquery result for all the * references. */ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { @@ -162,7 +159,7 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.exchangeReuseEnabled) { return plan } - // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. 
+ // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls. val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]() plan transformAllExpressions { case sub: ExecSubqueryExpression => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 12d3bc9281f35..b4a91230a0012 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,6 +47,13 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent +/** + * A message used to update SQL metric value for driver-side updates (which doesn't get reflected + * automatically). + * + * @param executionId The execution id for a query, so we can find the query plan. + * @param accumUpdates Map from accumulator id to the metric value (metrics are always 64-bit ints). + */ @DeveloperApi case class SparkListenerDriverAccumUpdates( executionId: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala deleted file mode 100644 index ee36c84251519..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.window - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - - -/** - * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the - * row buffer is used to materialize a partition of rows since we need to repeatedly scan these - * rows in window function processing. - */ -private[window] abstract class RowBuffer { - - /** Number of rows. */ - def size: Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited). - */ -private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. */ - override def size: Int = buffer.length - - /** Return next row in the buffer, null if no more left. 
*/ - override def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter. - */ -private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - override def size: Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 80b87d5ffa797..950a6794a74a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -284,6 +282,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. child.execute().mapPartitions { stream => @@ -310,10 +309,12 @@ case class WindowExec( fetchNextRow() // Manage the current partition. 
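      // The partition buffer is now an ExternalAppendOnlyUnsafeRowArray: it holds up to
      // `spillThreshold` rows in memory and spills to disk beyond that, replacing the old
      // ArrayBuffer + UnsafeExternalSorter combination. The per-partition pattern introduced
      // below is roughly:
      //
      //   buffer.clear()                      // reuse the buffer for the next partition
      //   while (nextRowAvailable && nextGroup == currentGroup) {
      //     buffer.add(nextRow)               // the array keeps its own copy of the row
      //     fetchNextRow()
      //   }
      //   frames.foreach(_.prepare(buffer))   // frames iterate via buffer.generateIterator()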
- val rows = ArrayBuffer.empty[UnsafeRow] val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null + + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length @@ -323,78 +324,43 @@ case class WindowExec( val currentGroup = nextGroup.copy() // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } + buffer.clear() while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } + buffer.add(nextRow) fetchNextRow() } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) + frames(i).prepare(buffer) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rowBuffer.size + bufferIterator = buffer.generateIterator() } // Iteration var rowIndex = 0 - var rowsSize = 0L - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable val join = new JoinedRow override final def next(): InternalRow = { // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { fetchNextPartition() } - if (rowIndex < rowsSize) { + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + // Get the results for the window frames. var i = 0 - val current = rowBuffer.next() while (i < numFrames) { frames(i).write(rowIndex, current) i += 1 @@ -406,7 +372,9 @@ case class WindowExec( // Return the projection. 
result(join) - } else throw new NoSuchElementException + } else { + throw new NoSuchElementException + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 70efc0f78ddb0..af2b4fb92062b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -22,6 +22,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray /** @@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: RowBuffer): Unit + def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit /** * Write the current results to the target row. @@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame { def write(index: Int, current: InternalRow): Unit } +object WindowFunctionFrame { + def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = { + if (iterator.hasNext) iterator.next() else null + } +} + /** * The offset window frame calculates frames containing LEAD/LAG statements. * @@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 @@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil).target(target) } - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows + inputIterator = input.generateIterator() // drain the first few rows if offset is larger than zero inputIndex = 0 while (inputIndex < offset) { - input.next() + if (inputIterator.hasNext) inputIterator.next() inputIndex += 1 } inputIndex = offset } override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() + if (inputIndex >= 0 && inputIndex < input.length) { + val r = WindowFunctionFrame.getNextOrNull(inputIterator) projection(r) } else { // Use default values since the offset row does not exist. @@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. 
*/ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() + inputIterator = input.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex = 0 inputLowIndex = 0 buffer.clear() @@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { buffer.add(nextRow.copy()) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex += 1 bufferUpdated = true } @@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) + processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { processor.update(iter.next()) @@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame( extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { + processor.initialize(rows.length) + + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) } } @@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() inputIndex = 0 - processor.initialize(input.size) + inputIterator = input.generateIterator() + if (inputIterator.hasNext) { + nextRow = inputIterator.next() + } + + processor.initialize(input.length) } /** Write the frame columns for the current row to the given target row. */ @@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { processor.update(nextRow) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } @@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. 
*/ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null /** * Index of the first input row with a value equal to or greater than the lower bound of the @@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIndex = 0 } @@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than + // Ignore all the rows from the buffer for which the input row value is smaller than // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() + val iterator = input.generateIterator(startIndex = inputIndex) + + var nextRow = WindowFunctionFrame.getNextOrNull(iterator) while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() inputIndex += 1 bufferUpdated = true + nextRow = WindowFunctionFrame.getNextOrNull(iterator) } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { + processor.initialize(input.length) + if (nextRow != null) { processor.update(nextRow) - nextRow = tmp.next() + } + while (iterator.hasNext) { + processor.update(iterator.next()) } processor.evaluate(target) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index 174378304d4a5..e266ae55cc4d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) extends Aggregator[T, (Boolean, T), T] { - private val encoder = implicitly[Encoder[T]] + @transient private val encoder = implicitly[Encoder[T]] override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index f3cf3052ea3ea..00053485e614c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -113,7 +113,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative positions from the current row. For example, "0" means + * Both `start` and `end` are positions relative to the current row. For example, "0" means * "current row", while "-1" means the row before the current row, and "5" means the fifth row * after the current row. 
* @@ -131,9 +131,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -150,7 +150,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rowsBetween. @@ -162,7 +162,7 @@ object Window { * Creates a [[WindowSpec]] with the frame boundaries defined, * from `start` (inclusive) to `end` (inclusive). * - * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * Both `start` and `end` are relative to the current row. For example, "0" means "current row", * while "-1" means one off before the current row, and "5" means the five off after the * current row. * @@ -183,9 +183,9 @@ object Window { * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -202,7 +202,7 @@ object Window { * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 2.1.0 */ // Note: when updating the doc for this method, also update WindowSpec.rangeBetween. @@ -221,7 +221,8 @@ object Window { * * {{{ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * Window.partitionBy("country").orderBy("date") + * .rowsBetween(Window.unboundedPreceding, Window.currentRow) * * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index de7d7a1772753..6279d48c94de5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -86,7 +86,7 @@ class WindowSpec private[sql]( * after the current row. 
* * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A row based boundary is based on the position of the row within the partition. @@ -99,9 +99,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rowsBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rowsBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -118,7 +118,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rowsBetween. @@ -134,7 +134,7 @@ class WindowSpec private[sql]( * current row. * * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, - * and `[Window.currentRow` to specify special boundary values, rather than using integral + * and `Window.currentRow` to specify special boundary values, rather than using integral * values directly. * * A range based boundary is based on the actual value of the ORDER BY @@ -150,9 +150,9 @@ class WindowSpec private[sql]( * import org.apache.spark.sql.expressions.Window * val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")) * .toDF("id", "category") - * df.withColumn("sum", - * sum('id) over Window.partitionBy('category).orderBy('id).rangeBetween(0,1)) - * .show() + * val byCategoryOrderedById = + * Window.partitionBy('category).orderBy('id).rangeBetween(Window.currentRow, 1) + * df.withColumn("sum", sum('id) over byCategoryOrderedById).show() * * +---+--------+---+ * | id|category|sum| @@ -169,7 +169,7 @@ class WindowSpec private[sql]( * @param start boundary start, inclusive. The frame is unbounded if this is * the minimum long value (`Window.unboundedPreceding`). * @param end boundary end, inclusive. The frame is unbounded if this is the - * maximum long value (`Window.unboundedFollowing`). + * maximum long value (`Window.unboundedFollowing`). * @since 1.4.0 */ // Note: when updating the doc for this method, also update Window.rangeBetween. 
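The Window and WindowSpec doc changes above steer users toward the named boundary constants (`Window.unboundedPreceding`, `Window.unboundedFollowing`, `Window.currentRow`) instead of raw long values. A minimal sketch of that style, assuming an active SparkSession with `spark.implicits._` in scope; the DataFrame and column names are illustrative and not part of this patch:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// Toy data, mirroring the updated scaladoc example.
val df = Seq((1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")).toDF("id", "category")

// Running total per category: every row from the start of the partition up to the current row.
val runningTotal = Window
  .partitionBy("category")
  .orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("running_sum", sum("id").over(runningTotal)).show()

Using the constants keeps the frame definition readable and avoids the Long.MinValue / Long.MaxValue magic numbers the old examples relied on.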
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 201f726db3fad..f07e04368389f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection @@ -1861,7 +1862,7 @@ object functions { def rint(columnName: String): Column = rint(Column(columnName)) /** - * Returns the value of the column `e` rounded to 0 decimal places. + * Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. * * @group math_funcs * @since 1.5.0 @@ -1869,8 +1870,8 @@ object functions { def round(e: Column): Column = round(e, 0) /** - * Round the value of `e` to `scale` decimal places if `scale` is greater than or equal to 0 - * or at integral part when `scale` is less than 0. + * Round the value of `e` to `scale` decimal places with HALF_UP round mode + * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. * * @group math_funcs * @since 1.5.0 @@ -2191,8 +2192,8 @@ object functions { } /** - * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, - * and returns the result as a string column. + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places + * with HALF_EVEN round mode, and returns the result as a string column. * * If d is 0, the result has no decimal point or fractional part. * If d is less than 0, the result will be null. @@ -2896,7 +2897,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contains `value` + * Returns null if the array is null, true if the array contains `value`, and false otherwise. * @group collection_funcs * @since 1.5.0 */ @@ -2967,7 +2968,7 @@ object functions { * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string - * @param options options to control how the json is parsed. accepts the same options and the + * @param options options to control how the json is parsed. Accepts the same options as the * json data source. * * @group collection_funcs @@ -2978,7 +2979,8 @@ object functions { /** * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -2989,7 +2991,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStruct(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3009,7 +3011,8 @@ object functions { /** * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * with the specified schema. Returns `null`, in the case of an unparseable string. + * of `StructType`s with the specified schema. 
Returns `null`, in the case of an unparseable + * string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3036,7 +3039,7 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. @@ -3049,23 +3052,32 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. - * @param schema the schema to use when parsing the json string as a json string + * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, + * the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL + * format is also supported for the schema. * * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema), options) + def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + val dataType = try { + DataType.fromJson(schema) + } catch { + case NonFatal(_) => StructType.fromDDL(schema) + } + from_json(e, dataType, options) + } /** - * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3073,14 +3085,15 @@ object functions { * @since 2.1.0 */ def to_json(e: Column, options: Map[String, String]): Column = withExpr { - StructToJson(options, e.expr) + StructsToJson(options, e.expr) } /** - * (Java-specific) Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * (Java-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s + * into a JSON string with the specified schema. Throws an exception, in the case of an + * unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * @param options options to control how the struct column is converted into a json string. * accepts the same options and the json data source. * @@ -3091,10 +3104,10 @@ object functions { to_json(e, options.asScala.toMap) /** - * Converts a column containing a `StructType` into a JSON string with the - * specified schema. Throws an exception, in the case of an unsupported type. + * Converts a column containing a `StructType` or `ArrayType` of `StructType`s into a JSON string + * with the specified schema. 
Throws an exception, in the case of an unsupported type. * - * @param e a struct column. + * @param e a column containing a struct or array of the structs. * * @group collection_funcs * @since 2.1.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala new file mode 100644 index 0000000000000..2a801d87b12eb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.util.ExecutionListenerManager + +/** + * Builder class that coordinates construction of a new [[SessionState]]. + * + * The builder explicitly defines all components needed by the session state, and creates a session + * state when `build` is called. Components should only be initialized once. This is not a problem + * for most components as they are only used in the `build` function. However some components + * (`conf`, `catalog`, `functionRegistry`, `experimentalMethods` & `sqlParser`) are used as + * dependencies for other components and are shared as a result. These components are defined as + * lazy vals to make sure the component is created only once. + * + * A developer can modify the builder by providing custom versions of components, or by using the + * hooks provided for the analyzer, optimizer & planner. There are some dependencies between the + * components (they are documented per dependency); a developer should respect these when making + * modifications in order to prevent initialization problems. + * + * A parent [[SessionState]] can be used to initialize the new [[SessionState]]. The new session + * state will clone the parent session state's `conf`, `functionRegistry`, `experimentalMethods` + * and `catalog` fields. Note that the state is cloned when `build` is called, and not before.
+ */ +@Experimental +@InterfaceStability.Unstable +abstract class BaseSessionStateBuilder( + val session: SparkSession, + val parentState: Option[SessionState] = None) { + type NewBuilder = (SparkSession, Option[SessionState]) => BaseSessionStateBuilder + + /** + * Function that produces a new instance of the SessionStateBuilder. This is used by the + * [[SessionState]]'s clone functionality. Make sure to override this when implementing your own + * [[SessionStateBuilder]]. + */ + protected def newBuilder: NewBuilder + + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + + /** + * Extract entries from `SparkConf` and put them in the `SQLConf` + */ + protected def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } + } + + /** + * SQL-specific key-value configurations. + * + * These either get cloned from a pre-existing instance or newly created. The conf is always + * merged with its [[SparkConf]]. + */ + protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse(new SQLConf) + mergeSparkConf(conf, session.sparkContext.conf) + conf + } + + /** + * Internal catalog managing functions registered by the user. + * + * This either gets cloned from a pre-existing version or cloned from the built-in registry. + */ + protected lazy val functionRegistry: FunctionRegistry = { + parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone() + } + + /** + * Experimental methods that can be used to define custom optimization rules and custom planning + * strategies. + * + * This either gets cloned from a pre-existing version or newly created. + */ + protected lazy val experimentalMethods: ExperimentalMethods = { + parentState.map(_.experimentalMethods.clone()).getOrElse(new ExperimentalMethods) + } + + /** + * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * + * Note: this depends on the `conf` field. + */ + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } + + /** + * ResourceLoader that is used to load function resources and jars. + */ + protected lazy val resourceLoader: SessionResourceLoader = new SessionResourceLoader(session) + + /** + * Catalog for managing table and database states. If there is a pre-existing catalog, the state + * of that catalog (temp tables & current database) will be copied into the new catalog. + * + * Note: this depends on the `conf`, `functionRegistry` and `sqlParser` fields. + */ + protected lazy val catalog: SessionCatalog = { + val catalog = new SessionCatalog( + session.sharedState.externalCatalog, + session.sharedState.globalTempViewManager, + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + resourceLoader) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog + } + + /** + * Interface exposed to the user for registering user-defined functions. + * + * Note 1: The user-defined functions must be deterministic. + * Note 2: This depends on the `functionRegistry` field. + */ + protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry) + + /** + * Logical query plan analyzer for resolving unresolved attributes and relations. + * + * Note: this depends on the `conf` and `catalog` fields. 
+ */ + protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + HiveOnlyCheck +: + customCheckRules + } + + /** + * Custom resolution rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } + + /** + * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of + * creating your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } + + /** + * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating + * your own Analyzer. + * + * Note that this may NOT depend on the `analyzer` function. + */ + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } + + /** + * Logical query plan optimizer. + * + * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + */ + protected def optimizer: Optimizer = { + new SparkOptimizer(catalog, conf, experimentalMethods) { + override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = + super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules + } + } + + /** + * Custom operator optimization rules to add to the Optimizer. Prefer overriding this instead + * of creating your own Optimizer. + * + * Note that this may NOT depend on the `optimizer` function. + */ + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } + + /** + * Planner that converts optimized logical plans to physical plans. + * + * Note: this depends on the `conf` and `experimentalMethods` fields. + */ + protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) { + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + } + } + + /** + * Custom strategies to add to the planner. Prefer overriding this instead of creating + * your own Planner. + * + * Note that this may NOT depend on the `planner` function. + */ + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } + + /** + * Create a query execution object. + */ + protected def createQueryExecution: LogicalPlan => QueryExecution = { plan => + new QueryExecution(session, plan) + } + + /** + * Interface to start and stop streaming queries. + */ + protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session) + + /** + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + * + * This gets cloned from parent if available, otherwise a new instance is created.
+ */ + protected def listenerManager: ExecutionListenerManager = { + parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + } + + /** + * Function used to make clones of the session state. + */ + protected def createClone: (SparkSession, SessionState) => SessionState = { + val createBuilder = newBuilder + (session, state) => createBuilder(session, Option(state)).build() + } + + /** + * Build the [[SessionState]]. + */ + def build(): SessionState = { + new SessionState( + session.sharedState, + conf, + experimentalMethods, + functionRegistry, + udfRegistration, + catalog, + sqlParser, + analyzer, + optimizer, + planner, + streamingQueryManager, + listenerManager, + resourceLoader, + createQueryExecution, + createClone) + } +} + +/** + * Helper class for using SessionStateBuilders during tests. + */ +private[sql] trait WithTestConf { self: BaseSessionStateBuilder => + def overrideConfs: Map[String, String] + + override protected lazy val conf: SQLConf = { + val conf = parentState.map(_.conf.clone()).getOrElse { + new SQLConf { + clear() + override def clear(): Unit = { + super.clear() + // Make sure we start with the default test configs even after clear + overrideConfs.foreach { case (key, value) => setConfString(key, value) } + } + } + } + mergeSparkConf(conf, session.sparkContext.conf) + conf + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3d9f41832bc73..0b8e53868c999 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -77,7 +78,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { new Database( name = metadata.name, description = metadata.description, - locationUri = metadata.locationUri) + locationUri = CatalogUtils.URIToString(metadata.locationUri)) } /** @@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(tables, sparkSession) } + /** + * Returns a Table for the given table/view or temporary view. + * + * Note that this function requires the table already exists in the Catalog. 
+ * + * If the table metadata retrieval fails for any reason (e.g., table serde class + * is not accessible or the table type is not accepted by Spark SQL), this function + * still returns the corresponding Table without the description and tableType. private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val metadata = try { + Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) + } catch { + case NonFatal(_) => None + } val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = metadata.identifier.database.orNull, - description = metadata.comment.orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull, + description = metadata.map(_.comment.orNull).orNull, + tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, isTemporary = isTemp) } @@ -141,15 +155,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns a list of columns for the given table in the current database. + * Returns a list of columns for the given table/view or temporary view. */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - listColumns(TableIdentifier(tableName, None)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + listColumns(tableIdent) } /** - * Returns a list of columns for the given table in the specified database. + * Returns a list of columns for the given table/view or temporary view in the specified database. */ @throws[AnalysisException]("database or table does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { @@ -175,7 +190,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the database with the specified name. This throws an `AnalysisException` when no + * Gets the database with the specified name. This throws an `AnalysisException` when no * `Database` can be found. */ override def getDatabase(dbName: String): Database = { @@ -183,33 +198,37 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view in the current database. This throws an `AnalysisException` when no `Table` - * can be found. + * Gets the table or view with the specified name. This table can be a temporary view or a + * table/view. This throws an `AnalysisException` when no `Table` can be found. */ override def getTable(tableName: String): Table = { - getTable(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + getTable(tableIdent.database.orNull, tableIdent.table) } /** - * Get the table or view with the specified name in the specified database. This throws an + * Gets the table or view with the specified name in the specified database. This throws an * `AnalysisException` when no `Table` can be found.
*/ override def getTable(dbName: String, tableName: String): Table = { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (tableExists(dbName, tableName)) { + makeTable(TableIdentifier(tableName, Option(dbName))) + } else { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } } /** - * Get the function with the specified name. This function can be a temporary function or a - * function in the current database. This throws an `AnalysisException` when no `Function` - * can be found. + * Gets the function with the specified name. This function can be a temporary function or a + * function. This throws an `AnalysisException` when no `Function` can be found. */ override def getFunction(functionName: String): Function = { - getFunction(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + getFunction(functionIdent.database.orNull, functionIdent.funcName) } /** - * Get the function with the specified name. This returns `None` when no `Function` can be + * Gets the function with the specified name. This returns `None` when no `Function` can be * found. */ override def getFunction(dbName: String, functionName: String): Function = { @@ -217,22 +236,23 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the database with the specified name exists. + * Checks if the database with the specified name exists. */ override def databaseExists(dbName: String): Boolean = { sessionCatalog.databaseExists(dbName) } /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view in the current database. + * Checks if the table or view with the specified name exists. This can either be a temporary + * view or a table/view. */ override def tableExists(tableName: String): Boolean = { - tableExists(null, tableName) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + tableExists(tableIdent.database.orNull, tableIdent.table) } /** - * Check if the table or view with the specified name exists in the specified database. + * Checks if the table or view with the specified name exists in the specified database. */ override def tableExists(dbName: String, tableName: String): Boolean = { val tableIdent = TableIdentifier(tableName, Option(dbName)) @@ -240,15 +260,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function in the current database. + * Checks if the function with the specified name exists. This can either be a temporary function + * or a function. */ override def functionExists(functionName: String): Boolean = { - functionExists(null, functionName) + val functionIdent = sparkSession.sessionState.sqlParser.parseFunctionIdentifier(functionName) + functionExists(functionIdent.database.orNull, functionIdent.funcName) } /** - * Check if the function with the specified name exists in the specified database. + * Checks if the function with the specified name exists in the specified database. 
*/ override def functionExists(dbName: String, functionName: String): Boolean = { sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) @@ -270,7 +291,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: - * Creates a table from the given path based on a data source and returns the corresponding + * Creates a table from the given path and returns the corresponding * DataFrame. * * @group ddl_ops @@ -284,7 +305,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Creates a table from the given path based on a data source and a set of options. + * Creates a table based on the dataset in a data source and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -301,7 +322,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { /** * :: Experimental :: * (Scala-specific) - * Create a table from the given path based on a data source, a schema and a set of options. + * Creates a table based on the dataset in a data source, a schema and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops @@ -336,13 +357,13 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the local temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the temporary view to be dropped. * @group ddl_ops * @since 2.0.0 */ override def dropTempView(viewName: String): Boolean = { - sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -351,21 +372,24 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * Drops the global temporary view with the given view name in the catalog. * If the view has been cached/persisted before, it's also unpersisted. * - * @param viewName the name of the view to be dropped. + * @param viewName the identifier of the global temporary view to be dropped. * @group ddl_ops * @since 2.1.0 */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } /** - * Recover all the partitions in the directory of a table and update the catalog. + * Recovers all the partitions in the directory of a table and updates the catalog. + * Only works with a partitioned table, and not a temporary view. * - * @param tableName the name of the table to be repaired. + * @param tableName is either a qualified or unqualified name that designates a table. + * If no database identifier is provided, it refers to a table in the + * current database. * @group ddl_ops * @since 2.1.1 */ @@ -376,7 +400,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Returns true if the table is currently cached in-memory. + * Returns true if the table or view is currently cached in-memory.
* * @group cachemgmt * @since 2.0.0 @@ -386,7 +410,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Caches the specified table in-memory. + * Caches the specified table or view in-memory. * * @group cachemgmt * @since 2.0.0 @@ -396,17 +420,17 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Removes the specified table from the in-memory cache. + * Removes the specified table or view from the in-memory cache. * * @group cachemgmt * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } /** - * Removes all cached tables from the in-memory cache. + * Removes all cached tables or views from the in-memory cache. * * @group cachemgmt * @since 2.0.0 @@ -426,8 +450,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } /** - * Refresh the cache entry for a table, if any. For Hive metastore table, the metadata - * is refreshed. For data source tables, the schema will not be inferred and refreshed. + * Invalidates and refreshes all the cached data and metadata of the given table or view. + * For a Hive metastore table, the metadata is refreshed. For data source tables, the schema will + * not be inferred and refreshed. + * + * If this table is cached as an InMemoryRelation, drop the original cached version and make the + * new version cached lazily. * * @group cachemgmt * @since 2.0.0 @@ -440,29 +468,25 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sparkSession, logicalPlan) + val table = sparkSession.table(tableIdent) + if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) // Cache it again. - sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table)) + sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } } /** - * Refresh the cache entry and the associated metadata for all dataframes (if any), that contain - * the given data source path. + * Refreshes the cache entry and the associated metadata for all Datasets (if any) that contain + * the given data source path. Path matching is by prefix, i.e. "/" would invalidate + * everything that is cached.
* * @group cachemgmt * @since 2.0.0 */ override def refreshByPath(resourcePath: String): Unit = { - sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath) + sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index ca46a1151e3e1..b9515ec7bca2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.Locale + import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat case class HiveSerDe( @@ -68,7 +70,7 @@ object HiveSerDe { * @return HiveSerDe associated with the specified source */ def sourceToSerDe(source: String): Option[HiveSerDe] = { - val key = source.toLowerCase match { + val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" case s if s.equals("orcfile") => "orc" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 9668106a18f24..2ae108376fc73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -22,38 +22,58 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.AnalyzeTableCommand -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} -import org.apache.spark.sql.util.ExecutionListenerManager - +import org.apache.spark.sql.streaming.StreamingQueryManager +import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener} /** * A class that holds all session-specific state in a given [[SparkSession]]. + * + * @param sharedState The state shared across sessions, e.g. global view manager, external catalog. + * @param conf SQL-specific key-value configurations. + * @param experimentalMethods Interface to add custom planning strategies and optimizers. + * @param functionRegistry Internal catalog for managing functions registered by the user. + * @param udfRegistration Interface exposed to the user for registering user-defined functions. + * @param catalog Internal catalog for managing table and database states. + * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts. + * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations. + * @param optimizer Logical query plan optimizer. 
+ * @param planner Planner that converts optimized logical plans to physical plans. + * @param streamingQueryManager Interface to start and stop streaming queries. + * @param listenerManager Interface to register custom [[QueryExecutionListener]]s. + * @param resourceLoader Session shared resource loader to load JARs, files, etc. + * @param createQueryExecution Function used to create QueryExecution objects. + * @param createClone Function used to create clones of the session state. */ -class SessionState(sparkSession: SparkSession) { - - // Note: These are all lazy vals because they depend on each other (e.g. conf) and we - // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs. - /** - * SQL-specific key-value configurations. - */ - lazy val conf: SQLConf = new SQLConf - - def newHadoopConf(): Configuration = { - val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration) - conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) } - hadoopConf - } +private[sql] class SessionState( + sharedState: SharedState, + val conf: SQLConf, + val experimentalMethods: ExperimentalMethods, + val functionRegistry: FunctionRegistry, + val udfRegistration: UDFRegistration, + val catalog: SessionCatalog, + val sqlParser: ParserInterface, + val analyzer: Analyzer, + val optimizer: Optimizer, + val planner: SparkPlanner, + val streamingQueryManager: StreamingQueryManager, + val listenerManager: ExecutionListenerManager, + val resourceLoader: SessionResourceLoader, + createQueryExecution: LogicalPlan => QueryExecution, + createClone: (SparkSession, SessionState) => SessionState) { + + def newHadoopConf(): Configuration = SessionState.newHadoopConf( + sharedState.sparkContext.hadoopConfiguration, + conf) def newHadoopConfWithOptions(options: Map[String, String]): Configuration = { val hadoopConf = newHadoopConf() @@ -65,118 +85,67 @@ class SessionState(sparkSession: SparkSession) { hadoopConf } - lazy val experimentalMethods = new ExperimentalMethods - /** - * Internal catalog for managing functions registered by the user. + * Get an identical copy of the `SessionState` and associate it with the given `SparkSession` */ - lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - - /** - * A class for loading resources specified by a function. - */ - lazy val functionResourceLoader: FunctionResourceLoader = { - new FunctionResourceLoader { - override def loadResource(resource: FunctionResource): Unit = { - resource.resourceType match { - case JarResource => addJar(resource.uri) - case FileResource => sparkSession.sparkContext.addFile(resource.uri) - case ArchiveResource => - throw new AnalysisException( - "Archive is not allowed to be loaded. If YARN mode is used, " + - "please use --archives options while calling spark-submit.") - } - } - } - } - - /** - * Internal catalog for managing table and database states. - */ - lazy val catalog = new SessionCatalog( - sparkSession.sharedState.externalCatalog, - sparkSession.sharedState.globalTempViewManager, - functionResourceLoader, - functionRegistry, - conf, - newHadoopConf(), - sqlParser) - - /** - * Interface exposed to the user for registering user-defined functions. - * Note that the user-defined functions must be deterministic. - */ - lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) - - /** - * Logical query plan analyzer for resolving unresolved attributes and relations. 
- */ - lazy val analyzer: Analyzer = { - new Analyzer(catalog, conf) { - override val extendedResolutionRules = - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules = - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(conf) :: - DataSourceAnalysis(conf) :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck) - } - } - - /** - * Logical query plan optimizer. - */ - lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) - - /** - * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. - */ - lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) - - /** - * Planner that converts optimized logical plans to physical plans. - */ - def planner: SparkPlanner = - new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) - - /** - * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s - * that listen for execution metrics. - */ - lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager - - /** - * Interface to start and stop [[StreamingQuery]]s. - */ - lazy val streamingQueryManager: StreamingQueryManager = { - new StreamingQueryManager(sparkSession) - } - - private val jarClassLoader: NonClosableMutableURLClassLoader = - sparkSession.sharedState.jarClassLoader - - // Automatically extract all entries and put it in our SQLConf - // We need to call it after all of vals have been initialized. - sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) => - conf.setConfString(k, v) - } + def clone(newSparkSession: SparkSession): SessionState = createClone(newSparkSession, this) // ------------------------------------------------------ // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ - def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) + def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } +} - def addJar(path: String): Unit = { - sparkSession.sparkContext.addJar(path) +private[sql] object SessionState { + def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = { + val newHadoopConf = new Configuration(hadoopConf) + sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) } + newHadoopConf + } +} +/** + * Concrete implementation of a [[SessionStateBuilder]]. + */ +@Experimental +@InterfaceStability.Unstable +class SessionStateBuilder( + session: SparkSession, + parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { + override protected def newBuilder: NewBuilder = new SessionStateBuilder(_, _) +} + +/** + * Session shared [[FunctionResourceLoader]]. + */ +@InterfaceStability.Unstable +class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => addJar(resource.uri) + case FileResource => session.sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. 
If YARN mode is used, " + + "please use --archives options while calling spark-submit.") + } + } + + /** + * Add a jar path to [[SparkContext]] and the classloader. + * + * Note: this method seems not access any session state, but a Hive based `SessionState` needs + * to add the jar to its hive client for the current session. Hence, it still needs to be in + * [[SessionState]]. + */ + def addJar(path: String): Unit = { + session.sparkContext.addJar(path) val uri = new Path(path).toUri val jarURL = if (uri.getScheme == null) { // `path` is a local file path without a URL scheme @@ -185,15 +154,7 @@ class SessionState(sparkSession: SparkSession) { // `path` is a URL with a scheme uri.toURL } - jarClassLoader.addURL(jarURL) - Thread.currentThread().setContextClassLoader(jarClassLoader) - } - - /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. - */ - def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = { - AnalyzeTableCommand(tableIdent, noscan).run(sparkSession) + session.sharedState.jarClassLoader.addURL(jarURL) + Thread.currentThread().setContextClassLoader(session.sharedState.jarClassLoader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index bfe4825307708..a5219191b0399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.internal +import java.net.URL +import java.util.Locale + import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -86,7 +90,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - val externalCatalog: ExternalCatalog = + lazy val externalCatalog: ExternalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -95,7 +99,10 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // Create the default database if it doesn't exist. { val defaultDbDefinition = CatalogDatabase( - SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map()) + SessionCatalog.DEFAULT_DATABASE, + "default database", + CatalogUtils.stringToURI(warehousePath), + Map()) // Initialize default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we @@ -104,6 +111,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + /** * A manager for global temporary views. */ @@ -111,7 +125,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // System preserved database should not exists in metastore. 
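The `SharedState` hunk above registers an `ExternalCatalogEventListener` that reposts catalog events onto the Spark listener bus. Assuming those events reach regular `SparkListener`s through `onOtherEvent` (which is what posting to `listenerBus` implies), an application could observe DDL activity roughly as in this sketch; the table name is a placeholder:
{{{
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogEvent

object CatalogEventSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("catalog-events").getOrCreate()
    // Custom events posted to the listener bus are delivered via onOtherEvent.
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case e: ExternalCatalogEvent => println(s"catalog event: $e")
        case _ => // ignore unrelated events
      }
    })
    // DDL that goes through the external catalog should now surface on the bus.
    spark.sql("CREATE TABLE IF NOT EXISTS sketch_t (id INT) USING parquet")
    spark.sql("DROP TABLE IF EXISTS sketch_t")
    spark.stop()
  }
}
}}}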
However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. - val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT) if (externalCatalog.databaseExists(globalTempDB)) { throw new SparkException( s"$globalTempDB is a system preserved database, please rename your existing database " + @@ -142,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } -object SharedState { +object SharedState extends Logging { + try { + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + } catch { + case e: Error => + logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory") + } private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index aed8074a64d5b..746b2a94f102d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} @@ -61,6 +63,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds an input option for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
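The `timeZone` option documented in the bullet above (and repeated for the other `option`/`options` overloads here and on `DataStreamWriter` below) is set like any other data source option. A minimal usage sketch; the schema and input path are placeholders:
{{{
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType}

object TimeZoneOptionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("tz-option").getOrCreate()
    val schema = StructType(Seq(StructField("ts", TimestampType), StructField("value", IntegerType)))
    val df = spark.readStream
      .format("json")
      .schema(schema)               // streaming file sources need an explicit schema
      .option("timeZone", "UTC")    // parse timestamps in the JSON input as UTC
      .load("/tmp/streaming-json")  // placeholder input directory
    // On the write side the same option controls how timestamps are formatted, e.g.
    // df.writeStream.format("json").option("timeZone", "UTC") ...
    spark.stop()
  }
}
}}}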
    + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamReader = { @@ -92,6 +100,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * (Scala-specific) Adds input options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamReader = { @@ -102,6 +116,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Adds input options for the underlying data source. * + * You can set the following option(s): + *
      + *
    • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to parse timestamps in the JSON/CSV datasources or partition values.
    • + *
    + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamReader = { @@ -117,7 +137,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def load(): DataFrame = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } @@ -183,11 +203,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, * per file
  • * @@ -222,9 +240,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • - *
  • `ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces - * from values being read should be skipped.
  • - *
  • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing + *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading + * whitespaces from values being read should be skipped.
  • + *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing * whitespaces from values being read should be skipped.
  • *
  • `nullValue` (default empty string): sets the string representation of a null value. Since * 2.0.1, this applies to all supported types including the string type.
  • @@ -236,17 +254,15 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to * date type.
  • - *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that + *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps.
  • *
  • `maxColumns` (default `20480`): defines a hard limit of how many columns * a record can have.
  • *
  • `maxCharsPerColumn` (default `-1`): defines the maximum number of characters allowed * for any given value being read. By default, it is -1 meaning unlimited length
  • *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records - * during parsing. + * during parsing. It supports the following case-insensitive modes. *
      *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0f7a33723cccc..0d2611f9bbcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.streaming +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} @@ -69,17 +71,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def outputMode(outputMode: String): DataStreamWriter[T] = { - this.outputMode = outputMode.toLowerCase match { - case "append" => - OutputMode.Append - case "complete" => - OutputMode.Complete - case "update" => - OutputMode.Update - case _ => - throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + - "Accepted output modes are 'append', 'complete', 'update'") - } + this.outputMode = InternalOutputModes(outputMode) this } @@ -155,6 +147,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds an output option for the underlying data source. * + * You can set the following option(s): + *
        + *
      • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
      • + *
      + * * @since 2.0.0 */ def option(key: String, value: String): DataStreamWriter[T] = { @@ -186,6 +184,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * (Scala-specific) Adds output options for the underlying data source. * + * You can set the following option(s): + *
        + *
      • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
      • + *
      + * * @since 2.0.0 */ def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { @@ -196,6 +200,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { /** * Adds output options for the underlying data source. * + * You can set the following option(s): + *
        + *
      • `timeZone` (default session local timezone): sets the string that indicates a timezone + * to be used to format timestamps in the JSON/CSV datasources or partition values.
      • + *
      + * * @since 2.0.0 */ def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { @@ -222,7 +232,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * @since 2.0.0 */ def start(): StreamingQuery = { - if (source.toLowerCase == DDLUtils.HIVE_PROVIDER) { + if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "write files of Hive data source directly.") } @@ -369,7 +379,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var outputMode: OutputMode = OutputMode.Append - private var trigger: Trigger = ProcessingTime(0L) + private var trigger: Trigger = Trigger.ProcessingTime(0L) private var extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala new file mode 100644 index 0000000000000..c659ac7fcf3d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.KeyValueGroupedDataset +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. + * + * Detail description on `[map/flatMap]GroupsWithState` operation + * -------------------------------------------------------------- + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in `KeyValueGroupedDataset` + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the `StreamingQuery`, + * the function will be invoked once for each group that has data in the trigger. Furthermore, + * if timeout is set, then the function will invoked on timed out groups (more detail below). + * + * The function is invoked with following parameters. + * - The key of the group. + * - An iterator containing all the values for this group. + * - A user-defined state object set by previous invocations of the given function. 
+ * + * In case of a batch Dataset, there is only one invocation and state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have + * no effect. + * + * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the + * former allows the function to return one and only one record, whereas the latter + * allows the function to return any number of records (including no records). Furthermore, the + * `flatMapGroupsWithState` is associated with an operation output mode, which can be either + * `Append` or `Update`. Semantically, this defines whether the output records of one trigger + * are effectively replacing the previously output records (from previous triggers) or are appending + * to the list of previously output records. Essentially, this defines how the Result Table (refer + * to the semantics in the programming guide) is updated, and allows us to reason about the + * semantics of later operations. + * + * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState). + * - In a trigger, the function will be called only for the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * - If timeout is set, then the function will also be called with no values. + * See more details on `GroupStateTimeout` below. + * + * Important points to note about using `GroupState`. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`. + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()` will return the updated value. + * + * Important points to note about using `GroupStateTimeout`. + * - The timeout type is a global param across all the groups (set as the `timeout` param in + * `[map|flatMap]GroupsWithState`), but the exact timeout duration/timestamp is configurable per + * group by calling `setTimeout...()` in `GroupState`. + * - Timeouts can be either based on processing time (i.e. + * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. + * `GroupStateTimeout.EventTimeTimeout`). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query + * (i.e. after D ms). So there is no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout actually occurs. + * If there is no data in the stream (for any group) for a while, then there will not be + * any trigger and the timeout function call will not occur until there is data.
+ * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). + * - With `EventTimeTimeout`, the user also has to specify the event time watermark in + * the query using `Dataset.withWatermark()`. With this setting, data that is older than the + * watermark is filtered out. The timeout can be set for a group by setting a timeout timestamp + * using `GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark + * advances beyond the set timestamp. You can control the timeout delay by two parameters - + * (i) the watermark delay and (ii) an additional duration beyond the timestamp in the event (which + * is guaranteed to be newer than the watermark due to the filtering). Guarantees provided by this + * timeout are as follows: + * - Timeout will never occur before the watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is no strict upper bound on the delay when + * the timeout actually occurs. The watermark can advance only when there is data in the + * stream, and the event time of the data has actually advanced. + * - When the timeout occurs for a group, the function is called for that group with no values, and + * `GroupState.hasTimedOut()` set to true. + * - The timeout is reset every time the function is called on a group, that is, + * when the group has new data, or the group has timed out. So the user has to set the timeout + * duration every time the function is called, otherwise there will not be any timeout set. + * + * Scala example of using GroupState in `mapGroupsWithState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = { + * + * if (state.hasTimedOut) { // If called when timing out, remove the state + * state.remove() + * + * } else if (state.exists) { // If state exists, use it for processing + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * ... + * // return something + * } + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * }}} + * + * Java example of using `GroupState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * MapGroupsWithStateFunction mappingFunction = + * new MapGroupsWithStateFunction() { + * + * @Override + * public String call(String key, Iterator value, GroupState state) { + * if (state.hasTimedOut()) { // If called when timing out, remove the state + * state.remove(); + * + * } else if (state.exists()) { // If state exists, use it for processing + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * ... +* // return something + * } + * }; + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState( + * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); + * }}} + * + * @tparam S User-defined type of the state to be stored for each group. Must be encodable into + * Spark SQL types (see `Encoder` for more details). + * @since 2.2.0 + */ +@Experimental +@InterfaceStability.Evolving +trait GroupState[S] extends LogicalGroupState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. + */ + @throws[IllegalArgumentException]("when updating with null") + def update(newState: S): Unit + + /** Remove this state. Note that this resets any timeout configuration as well. */ + def remove(): Unit + + /** + * Whether the function has been called because the key has timed out. + * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + */ + def hasTimedOut: Boolean + + /** + * Set the timeout duration in ms for this key. + * + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + @throws[IllegalArgumentException]("if 'durationMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(durationMs: Long): Unit + + /** + * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. + * + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. 
+ */ + @throws[IllegalArgumentException]("if 'duration' is not a valid duration") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + def setTimeoutDuration(duration: String): Unit + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time. + * This timestamp cannot be older than the current watermark. + * + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date. + * This timestamp cannot be older than the current watermark. + * + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. 
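The scaladoc examples above cover processing-time timeouts only. A hedged sketch of the event-time path described in the notes, using the `setTimeoutTimestamp` variants declared here; the rate source, names, and thresholds are illustrative:
{{{
import java.sql.Timestamp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

// Top-level case classes so implicit product encoders can be derived for them.
case class UserEvent(user: String, eventTime: Timestamp)
case class UserSession(user: String, numEvents: Long, expired: Boolean)

object EventTimeTimeoutSketch {
  // Count events per user; expire the state 30 minutes of event time after the newest event.
  def track(user: String, events: Iterator[UserEvent], state: GroupState[Long]): UserSession = {
    if (state.hasTimedOut) {
      val count = state.getOption.getOrElse(0L)
      state.remove()                                    // also clears the timeout configuration
      UserSession(user, count, expired = true)
    } else {
      val evs = events.toSeq                            // non-empty when not timed out
      val count = state.getOption.getOrElse(0L) + evs.size
      state.update(count)
      val newestMs = evs.map(_.eventTime.getTime).max
      state.setTimeoutTimestamp(newestMs, "30 minutes") // fires once the watermark passes this
      UserSession(user, count, expired = false)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("event-time-timeout").getOrCreate()
    import spark.implicits._

    val events = spark.readStream.format("rate").load() // placeholder source
      .selectExpr("CAST(value % 10 AS STRING) AS user", "timestamp AS eventTime")
      .as[UserEvent]

    val sessions = events
      .withWatermark("eventTime", "10 minutes")         // required for EventTimeTimeout
      .groupByKey(_.user)
      .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout)(track _)

    sessions.writeStream.outputMode(OutputMode.Update()).format("console").start().awaitTermination()
  }
}
}}}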
+ */ + def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 68f2eab9d45fc..9ba1fc01cbd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -26,16 +26,6 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.unsafe.types.CalendarInterval -/** - * :: Experimental :: - * Used to indicate how often results should be produced by a [[StreamingQuery]]. - * - * @since 2.0.0 - */ -@Experimental -@InterfaceStability.Evolving -sealed trait Trigger - /** * :: Experimental :: * A trigger that runs a query periodically based on the processing time. If `interval` is 0, @@ -43,24 +33,25 @@ sealed trait Trigger * * Scala Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { require(intervalMs >= 0, "the interval of trigger should not be negative") } @@ -73,6 +64,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { */ @Experimental @InterfaceStability.Evolving +@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { /** @@ -80,11 +72,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) + * df.writeStream.trigger(ProcessingTime("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTime(interval) */ + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { throw new IllegalArgumentException( @@ -110,11 +104,13 @@ object ProcessingTime { * Example: * {{{ * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(ProcessingTime(10.seconds)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTime(interval) */ + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) } @@ -124,11 +120,13 @@ object ProcessingTime { * * Example: * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) + * df.writeStream.trigger(ProcessingTime.create("10 seconds")) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTime(interval) */ + @deprecated("use Trigger.ProcessingTime(interval)", "2.2.0") def create(interval: String): ProcessingTime = { apply(interval) 
} @@ -139,11 +137,13 @@ object ProcessingTime { * Example: * {{{ * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 + * @deprecated use Trigger.ProcessingTime(interval, unit) */ + @deprecated("use Trigger.ProcessingTime(interval, unit)", "2.2.0") def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 38edb40dfb781..7810d9f6e9642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ @@ -40,7 +41,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} */ @Experimental @InterfaceStability.Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -234,9 +235,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { } if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - throw new AnalysisException( - s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + - "is not supported in streaming DataFrames/Datasets") + logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " + + "is not supported in streaming DataFrames/Datasets and will be disabled.") } new StreamingQueryWrapper(new StreamExecution( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java new file mode 100644 index 0000000000000..3e3997fa9bfec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming; + +import java.util.concurrent.TimeUnit; + +import scala.concurrent.duration.Duration; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; + +/** + * :: Experimental :: + * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. + * + * @since 2.0.0 + */ +@Experimental +@InterfaceStability.Evolving +public class Trigger { + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long intervalMs) { + return ProcessingTime.create(intervalMs, TimeUnit.MILLISECONDS); + } + + /** + * :: Experimental :: + * (Java-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is 0, the query will run as fast as possible. + * + * {{{ + * import java.util.concurrent.TimeUnit + * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.2.0 + */ + public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { + return ProcessingTime.create(interval, timeUnit); + } + + /** + * :: Experimental :: + * (Scala-friendly) + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `duration` is 0, the query will run as fast as possible. + * + * {{{ + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(Duration interval) { + return ProcessingTime.apply(interval); + } + + /** + * :: Experimental :: + * A trigger policy that runs a query periodically based on an interval in processing time. + * If `interval` is effectively 0, the query will run as fast as possible. + * + * {{{ + * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) + * }}} + * @since 2.2.0 + */ + public static Trigger ProcessingTime(String interval) { + return ProcessingTime.apply(interval); + } + + /** + * A trigger that process only one batch of data in a streaming query then terminates + * the query. + * + * @since 2.2.0 + */ + public static Trigger Once() { + return OneTimeTrigger$.MODULE$; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 26ad0eadd9d4c..f6240d85fba6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging { listeners.clear() } + /** + * Get an identical copy of this listener manager. 
+ */ + @DeveloperApi + override def clone(): ExecutionListenerManager = writeLock { + val newListenerManager = new ExecutionListenerManager + listeners.foreach(newListenerManager.register) + newListenerManager + } + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { readLock { withErrorHandling { listener => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index be8d95d0d9124..b007093dad84b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -423,4 +423,36 @@ public void testJsonRDDToDataFrame() { Assert.assertEquals(1L, df.count()); Assert.assertEquals(2L, df.collectAsList().get(0).getLong(0)); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + // Checks a simple case for DataFrame here and put exhaustive tests for the issue + // of circular references in `JavaDatasetSuite`. + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index e3b0e37ccab05..3ba37addfc8b4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,8 @@ import java.sql.Timestamp; import java.util.*; +import org.apache.spark.sql.streaming.GroupStateTimeout; +import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -117,7 +119,7 @@ public void testCommonOperation() { Dataset parMapped = ds.mapPartitions((MapPartitionsFunction) it -> { List ls = new LinkedList<>(); while (it.hasNext()) { - ls.add(it.next().toUpperCase(Locale.ENGLISH)); + ls.add(it.next().toUpperCase(Locale.ROOT)); } return ls.iterator(); }, Encoders.STRING()); @@ -205,8 +207,10 @@ public void testGroupBy() { } return Collections.singletonList(sb.toString()).iterator(); }, + OutputMode.Append(), Encoders.LONG(), - Encoders.STRING()); + Encoders.STRING(), + GroupStateTimeout.NoTimeout()); Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); @@ -1289,4 +1293,110 @@ public void testEmptyBean() { Assert.assertEquals(df.schema().length(), 0); Assert.assertEquals(df.collectAsList().size(), 1); } + + public class CircularReference1Bean implements Serializable { + private CircularReference2Bean child; + + public CircularReference2Bean getChild() { + return child; + } + + public void setChild(CircularReference2Bean child) { + this.child = child; + } + } + + public class CircularReference2Bean implements Serializable { + private CircularReference1Bean child; + + public CircularReference1Bean 
getChild() { + return child; + } + + public void setChild(CircularReference1Bean child) { + this.child = child; + } + } + + public class CircularReference3Bean implements Serializable { + private CircularReference3Bean[] child; + + public CircularReference3Bean[] getChild() { + return child; + } + + public void setChild(CircularReference3Bean[] child) { + this.child = child; + } + } + + public class CircularReference4Bean implements Serializable { + private Map child; + + public Map getChild() { + return child; + } + + public void setChild(Map child) { + this.child = child; + } + } + + public class CircularReference5Bean implements Serializable { + private String id; + private List child; + + public String getId() { + return id; + } + + public List getChild() { + return child; + } + + public void setId(String id) { + this.id = id; + } + + public void setChild(List child) { + this.child = child; + } + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean1() { + CircularReference1Bean bean = new CircularReference1Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference1Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean2() { + CircularReference3Bean bean = new CircularReference3Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference3Bean.class)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testCircularReferenceBean3() { + CircularReference4Bean bean = new CircularReference4Bean(); + spark.createDataset(Arrays.asList(bean), Encoders.bean(CircularReference4Bean.class)); + } + + @Test(expected = RuntimeException.class) + public void testNullInTopLevelBean() { + NestedSmallBean bean = new NestedSmallBean(); + // We cannot set null in top-level bean + spark.createDataset(Arrays.asList(bean, null), Encoders.bean(NestedSmallBean.class)); + } + + @Test + public void testSerializeNull() { + NestedSmallBean bean = new NestedSmallBean(); + Encoder encoder = Encoders.bean(NestedSmallBean.class); + List beans = Arrays.asList(bean); + Dataset ds1 = spark.createDataset(beans, encoder); + Assert.assertEquals(beans, ds1.collectAsList()); + Dataset ds2 = + ds1.map((MapFunction) b -> b, encoder); + Assert.assertEquals(beans, ds2.collectAsList()); + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index ff327f5e82b13..6de4cf0d5afa1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,22 +1,35 @@ -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d); +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment'; + +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t; + +CREATE TEMPORARY VIEW temp_Data_Source_View + USING org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1'); + +CREATE VIEW v AS SELECT * FROM t; ALTER TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; -DESC t; +DESC default.t; DESC TABLE t; --- Ignore these because there exist timestamp results, e.g., `Create Table`. 
--- DESC EXTENDED t; --- DESC FORMATTED t; +DESC FORMATTED t; + +DESC EXTENDED t; DESC t PARTITION (c='Us', d=1); --- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. --- DESC EXTENDED t PARTITION (c='Us', d=1); --- DESC FORMATTED t PARTITION (c='Us', d=1); +DESC EXTENDED t PARTITION (c='Us', d=1); + +DESC FORMATTED t PARTITION (c='Us', d=1); -- NoSuchPartitionException: Partition not found in table DESC t PARTITION (c='Us', d=2); @@ -27,5 +40,39 @@ DESC t PARTITION (c='Us'); -- ParseException: PARTITION specification is incomplete DESC t PARTITION (c='Us', d); --- DROP TEST TABLE +-- DESC Temp View + +DESC temp_v; + +DESC TABLE temp_v; + +DESC FORMATTED temp_v; + +DESC EXTENDED temp_v; + +DESC temp_Data_Source_View; + +-- AnalysisException DESC PARTITION is not allowed on a temporary view +DESC temp_v PARTITION (c='Us', d=1); + +-- DESC Persistent View + +DESC v; + +DESC TABLE v; + +DESC FORMATTED v; + +DESC EXTENDED v; + +-- AnalysisException DESC PARTITION is not allowed on a view +DESC v PARTITION (c='Us', d=1); + +-- DROP TEST TABLES/VIEWS DROP TABLE t; + +DROP VIEW temp_v; + +DROP VIEW temp_Data_Source_View; + +DROP VIEW v; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9b..6566338f3d4a9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; +-- mixed cases: group-by ordinals and aliases +select a, a AS k, count(b) from data group by k, 1; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed43153004..a7994f3beaff3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,21 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; + +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + +-- Test data. 
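The new group-by-ordinal.sql and group-by.sql cases above exercise referencing SELECT aliases in GROUP BY (and mixing them with ordinals), gated by `spark.sql.groupByAliases` and `spark.sql.groupByOrdinal`. A sketch of the same behaviour through `spark.sql`; the view name and rows are illustrative:
{{{
import org.apache.spark.sql.SparkSession

object GroupByAliasSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("group-by-aliases").getOrCreate()
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW data AS SELECT * FROM VALUES (1, 1), (1, 2), (2, 1) AS data(a, b)")

    spark.sql("SELECT a AS k, COUNT(b) FROM data GROUP BY k").show()       // alias in GROUP BY
    spark.sql("SELECT a, a AS k, COUNT(b) FROM data GROUP BY k, 1").show() // alias mixed with an ordinal

    // With the flag disabled, resolving the alias in GROUP BY is expected to fail analysis:
    spark.sql("SET spark.sql.groupByAliases=false")
    // spark.sql("SELECT a AS k, COUNT(b) FROM data GROUP BY k")           // AnalysisException
    spark.stop()
  }
}
}}}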
+CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + +-- turn off group by aliases +set spark.sql.groupByAliases=false; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 364c022d959dc..868a911e787f6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); + +-- SPARK-20329: make sure we handle timezones correctly +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql new file mode 100644 index 0000000000000..b3cc2cea51d43 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -0,0 +1,22 @@ +-- to_json +describe function to_json; +describe function extended to_json; +select to_json(named_struct('a', 1, 'b', 2)); +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +select to_json(array(named_struct('a', 1, 'b', 2))); +-- Check if errors handled +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)); +select to_json(); + +-- from_json +describe function from_json; +describe function extended from_json; +select from_json('{"a":1}', 'a INT'); +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select from_json('{"a":1}', 1); +select from_json('{"a":1}', 'a InvalidType'); +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')); +select from_json('{"a":1}', 'a INT', map('mode', 1)); +select from_json(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql index 10c379dfa014e..3c77c9977d80f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql @@ -17,10 +17,21 @@ SHOW TABLES LIKE 'show_t1*|show_t2*'; SHOW TABLES IN showdb 'show_t*'; -- SHOW TABLE EXTENDED --- Ignore these because there exist timestamp results, e.g. `Created`. --- SHOW TABLE EXTENDED LIKE 'show_t*'; +SHOW TABLE EXTENDED LIKE 'show_t*'; SHOW TABLE EXTENDED; + +-- SHOW TABLE EXTENDED ... PARTITION +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1); +-- Throw a ParseException if table name is not specified. +SHOW TABLE EXTENDED PARTITION(c='Us', d=1); +-- Don't support regular expression for table name if a partition specification is present. +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1); +-- Partition specification is not complete. SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); +-- Partition specification is invalid. +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1); +-- Partition specification doesn't exist. 
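The new json-functions.sql input above calls `to_json`/`from_json` from SQL, including the options-map form with a custom `timestampFormat`. The same calls issued through `spark.sql` (statements copied from the test input):
{{{
import org.apache.spark.sql.SparkSession

object JsonFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("json-functions").getOrCreate()
    // Struct-to-JSON, with and without an options map.
    spark.sql("SELECT to_json(named_struct('a', 1, 'b', 2))").show(false)
    spark.sql("SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), " +
      "map('timestampFormat', 'dd/MM/yyyy'))").show(false)
    // JSON-to-struct using a DDL-formatted schema string.
    spark.sql("SELECT from_json('{\"a\":1}', 'a INT')").show(false)
    spark.sql("SELECT from_json('{\"time\":\"26/08/2015\"}', 'time Timestamp', " +
      "map('timestampFormat', 'dd/MM/yyyy'))").show(false)
    spark.stop()
  }
}
}}}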
+SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1); -- Clean Up DROP TABLE show_t1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql new file mode 100644 index 0000000000000..e56344dc4de80 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -0,0 +1,20 @@ +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST); + +-- Create a struct +SELECT STRUCT('alpha', 'beta') ST; + +-- Create a struct with aliases +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST; + +-- Star expansion in a struct. +SELECT ID, STRUCT(ST.*) NST FROM tbl_x; + +-- Append a column to a struct +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; + +-- Prepend a column to a struct +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql index 20370b045e803..f19567d2fac20 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/simple-in.sql @@ -109,4 +109,28 @@ FROM t1 WHERE t1a NOT IN (SELECT t2a FROM t2); +-- DDLs +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2); +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3); + +-- TC 02.01 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +; + +-- TC 02.02 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql index cf93c5a835971..e22cade936792 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -1,42 +1,72 @@ -- The test file contains negative test cases -- of invalid queries where error messages are expected. -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c); +AS t1(t1a, t1b, t1c); -create temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c); +AS t2(t2a, t2b, t2c); -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c); +AS t3(t3a, t3b, t3c); -- TC 01.01 -- The column t2b in the SELECT of the subquery is invalid -- because it is neither an aggregate function nor a GROUP BY column. -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) ; -- TC 01.02 -- Invalid due to the column t2b not part of the output from table t2. 
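The new struct.sql input above exercises star expansion inside `STRUCT(...)`, including appending and prepending extra fields. The same statements run through `spark.sql` (copied from the test input, trimmed to two rows):
{{{
import org.apache.spark.sql.SparkSession

object StructStarExpansionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("struct-star").getOrCreate()
    spark.sql("""CREATE OR REPLACE TEMPORARY VIEW tbl_x AS VALUES
                 (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')),
                 (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')) AS T(ID, ST)""")
    // Expand all fields of the struct column into a new struct.
    spark.sql("SELECT ID, STRUCT(ST.*) NST FROM tbl_x").show(false)
    // Append an extra field after the expanded ones.
    spark.sql("SELECT ID, STRUCT(ST.*, CAST(ID AS STRING) AS E) NST FROM tbl_x").show(false)
    spark.stop()
  }
}
}}}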
-select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) ; +-- TC 01.03 +-- Invalid due to mixture of outer and local references under an AggregateExpression +-- in a correlated predicate +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)); + +-- TC 01.04 +-- Invalid due to mixture of outer and local references under an AggregateExpression +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)); + +-- TC 01.05 +-- Invalid due to outer reference appearing in projection list +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)); + diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 2e6dcd538b7ac..d0d2df7b243d5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -18,3 +18,9 @@ select * from range(1, 1, 1, 1, 1); -- range call with null select * from range(1, null); + +-- range call with a mixed-case function name +select * from RaNgE(2); + +-- Explain +EXPLAIN select * from RaNgE(2); diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..678a3f0f0a3c6 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,6 +15,7 @@ DESC test_change -- !query 1 schema struct -- !query 1 output +# col_name data_type comment a int b string c int @@ -34,6 +35,7 @@ DESC test_change -- !query 3 schema struct -- !query 3 output +# col_name data_type comment a int b string c int @@ -53,6 +55,7 @@ DESC test_change -- !query 5 schema struct -- !query 5 output +# col_name data_type comment a int b string c int @@ -91,6 +94,7 @@ DESC test_change -- !query 8 schema struct -- !query 8 output +# col_name data_type comment a int b string c int @@ -125,6 +129,7 @@ DESC test_change -- !query 12 schema struct -- !query 12 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -143,6 +148,7 @@ DESC test_change -- !query 14 schema struct -- !query 14 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -162,6 +168,7 @@ DESC test_change -- !query 16 schema struct -- !query 16 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -186,6 +193,7 @@ DESC test_change -- !query 18 schema struct -- !query 18 output +# col_name data_type comment a int this is column a b string #*02?` c int @@ -229,6 +237,7 @@ DESC test_change -- !query 23 schema struct -- !query 23 output +# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 0a11c1cde2b45..de10b29f3c65b 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,11 @@ --
Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 31 -- !query 0 -CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet PARTITIONED BY (c, d) +CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS + COMMENT 'table_comment' -- !query 0 schema struct<> -- !query 0 output @@ -11,7 +13,7 @@ struct<> -- !query 1 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t -- !query 1 schema struct<> -- !query 1 output @@ -19,90 +21,239 @@ struct<> -- !query 2 -DESCRIBE t +CREATE TEMPORARY VIEW temp_Data_Source_View + USING org.apache.spark.sql.sources.DDLScanSource + OPTIONS ( + From '1', + To '10', + Table 'test1') -- !query 2 schema -struct +struct<> -- !query 2 output -# Partition Information + + + +-- !query 3 +CREATE VIEW v AS SELECT * FROM t +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +DESCRIBE t +-- !query 5 schema +struct +-- !query 5 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 3 -DESC t --- !query 3 schema +-- !query 6 +DESC default.t +-- !query 6 schema struct --- !query 3 output -# Partition Information +-- !query 6 output # col_name data_type comment a string b int c string -c string d string +# Partition Information +# col_name data_type comment +c string d string --- !query 4 +-- !query 7 DESC TABLE t --- !query 4 schema +-- !query 7 schema struct --- !query 4 output +-- !query 7 output +# col_name data_type comment +a string +b int +c string +d string # Partition Information # col_name data_type comment +c string +d string + + +-- !query 8 +DESC FORMATTED t +-- !query 8 schema +struct +-- !query 8 output +# col_name data_type comment a string b int c string +d string +# Partition Information +# col_name data_type comment c string d string -d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog --- !query 5 +-- !query 9 +DESC EXTENDED t +-- !query 9 schema +struct +-- !query 9 output +# col_name data_type comment +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Table Information +Database default +Table t +Created [not included in comparison] +Last Access [not included in comparison] +Type MANAGED +Provider parquet +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Comment table_comment +Location [not included in comparison]sql/core/spark-warehouse/t +Partition Provider Catalog + + +-- !query 10 DESC t PARTITION (c='Us', d=1) --- !query 5 schema +-- !query 10 schema struct --- !query 5 output +-- !query 10 output +# col_name data_type comment +a string +b int +c string +d string # Partition Information # col_name data_type comment +c string +d string + + +-- !query 11 +DESC EXTENDED t PARTITION (c='Us', d=1) +-- !query 11 schema +struct +-- !query 11 output +# col_name data_type comment a string b int c string +d string +# Partition Information +# col_name data_type comment c string 
d string -d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t --- !query 6 +-- !query 12 +DESC FORMATTED t PARTITION (c='Us', d=1) +-- !query 12 schema +struct +-- !query 12 output +# col_name data_type comment +a string +b int +c string +d string +# Partition Information +# col_name data_type comment +c string +d string + +# Detailed Partition Information +Database default +Table t +Partition Values [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 + +# Storage Information +Num Buckets 2 +Bucket Columns [`a`] +Sort Columns [`b`] +Location [not included in comparison]sql/core/spark-warehouse/t + + +-- !query 13 DESC t PARTITION (c='Us', d=2) --- !query 6 schema +-- !query 13 schema struct<> --- !query 6 output +-- !query 13 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 7 +-- !query 14 DESC t PARTITION (c='Us') --- !query 7 schema +-- !query 14 schema struct<> --- !query 7 output +-- !query 14 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 8 +-- !query 15 DESC t PARTITION (c='Us', d) --- !query 8 schema +-- !query 15 schema struct<> --- !query 8 output +-- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -112,9 +263,193 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 9 +-- !query 16 +DESC temp_v +-- !query 16 schema +struct +-- !query 16 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 17 +DESC TABLE temp_v +-- !query 17 schema +struct +-- !query 17 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 18 +DESC FORMATTED temp_v +-- !query 18 schema +struct +-- !query 18 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 19 +DESC EXTENDED temp_v +-- !query 19 schema +struct +-- !query 19 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 20 +DESC temp_Data_Source_View +-- !query 20 schema +struct +-- !query 20 output +# col_name data_type comment +intType int test comment test1 +stringType string +dateType date +timestampType timestamp +doubleType double +bigintType bigint +tinyintType tinyint +decimalType decimal(10,0) +fixedDecimalType decimal(5,1) +binaryType binary +booleanType boolean +smallIntType smallint +floatType float +mapType map +arrayType array +structType struct + + +-- !query 21 +DESC temp_v PARTITION (c='Us', d=1) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a temporary view: temp_v; + + +-- !query 22 +DESC v +-- !query 22 schema +struct +-- !query 22 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 23 +DESC TABLE v +-- !query 23 schema +struct +-- !query 23 output +# col_name data_type comment +a string +b int +c string +d string + + +-- !query 24 +DESC FORMATTED v +-- !query 24 schema +struct +-- !query 24 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table 
Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 25 +DESC EXTENDED v +-- !query 25 schema +struct +-- !query 25 output +# col_name data_type comment +a string +b int +c string +d string + +# Detailed Table Information +Database default +Table v +Created [not included in comparison] +Last Access [not included in comparison] +Type VIEW +View Text SELECT * FROM t +View Default Database default +View Query Output Columns [a, b, c, d] +Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] + + +-- !query 26 +DESC v PARTITION (c='Us', d=1) +-- !query 26 schema +struct<> +-- !query 26 output +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a view: v; + + +-- !query 27 DROP TABLE t --- !query 9 schema +-- !query 27 schema struct<> --- !query 9 output +-- !query 27 output + + + +-- !query 28 +DROP VIEW temp_v +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +DROP VIEW temp_Data_Source_View +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +DROP VIEW v +-- !query 30 schema +struct<> +-- !query 30 output diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde69a4..9ecbe19078dd6 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 20 -- !query 0 @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 @@ -173,16 +173,26 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +1 1 2 +2 2 2 +3 3 2 -- !query 18 -select sum(b) from data group by -1 +set spark.sql.groupByOrdinal=false -- !query 18 schema -struct +struct -- !query 18 output +spark.sql.groupByOrdinal false + + +-- !query 19 +select sum(b) from data group by -1 +-- !query 19 schema +struct +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0e..6bf9dff883c1e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 22 -- !query 0 @@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 +-- !query 16 schema +struct +-- !query 16 output +2 2 +3 2 + + +-- !query 17 +SELECT COUNT(b) AS k FROM testData GROUP BY k +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`); + + +-- !query 18 +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 20 +set spark.sql.groupByAliases=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.groupByAliases false + + +-- !query 21 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index e0923832673cb..d87ee5221647f 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query 0 @@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) struct -- !query 3 output 1 + + +-- !query 4 +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1 +-- !query 4 schema +struct<(a + CAST(b AS BIGINT)):bigint> +-- !query 4 output +3 +7 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out new file mode 100644 index 0000000000000..fedabaee2237f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -0,0 +1,176 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +describe function to_json +-- !query 0 schema +struct +-- !query 0 output +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 1 +describe function extended to_json +-- !query 1 schema +struct +-- !query 1 output +Class: org.apache.spark.sql.catalyst.expressions.StructsToJson +Extended Usage: + Examples: + > SELECT to_json(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 
'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + [{"a":1,"b":2}] + +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 2 +select to_json(named_struct('a', 1, 'b', 2)) +-- !query 2 schema +struct +-- !query 2 output +{"a":1,"b":2} + + +-- !query 3 +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) +-- !query 3 schema +struct +-- !query 3 output +{"time":"26/08/2015"} + + +-- !query 4 +select to_json(array(named_struct('a', 1, 'b', 2))) +-- !query 4 schema +struct +-- !query 4 output +[{"a":1,"b":2}] + + +-- !query 5 +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 6 +select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 7 +select to_json() +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function to_json; line 1 pos 7 + + +-- !query 8 +describe function from_json +-- !query 8 schema +struct +-- !query 8 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. + + +-- !query 9 +describe function extended from_json +-- !query 9 schema +struct +-- !query 9 output +Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs +Extended Usage: + Examples: + > SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + > SELECT from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); + {"time":"2015-08-26 00:00:00.0"} + +Function: from_json +Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`. 
+ + +-- !query 10 +select from_json('{"a":1}', 'a INT') +-- !query 10 schema +struct> +-- !query 10 output +{"a":1} + + +-- !query 11 +select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) +-- !query 11 schema +struct> +-- !query 11 output +{"time":2015-08-26 00:00:00.0} + + +-- !query 12 +select from_json('{"a":1}', 1) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Expected a string literal instead of 1;; line 1 pos 7 + + +-- !query 13 +select from_json('{"a":1}', 'a InvalidType') +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException + +DataType invalidtype is not supported.(line 1, pos 2) + +== SQL == +a InvalidType +--^^^ +; line 1 pos 7 + + +-- !query 14 +select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE')) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 15 +select from_json('{"a":1}', 'a INT', map('mode', 1)) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 + + +-- !query 16 +select from_json() +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function from_json; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 3d287f43accc9..8f2a54f7c24e2 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 26 -- !query 0 @@ -114,76 +114,164 @@ show_t3 -- !query 12 -SHOW TABLE EXTENDED +SHOW TABLE EXTENDED LIKE 'show_t*' -- !query 12 schema -struct<> +struct -- !query 12 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '' expecting 'LIKE'(line 1, pos 19) - -== SQL == -SHOW TABLE EXTENDED --------------------^^^ +show_t3 true Table: show_t3 +Created [not included in comparison] +Last Access [not included in comparison] +Type: VIEW +Schema: root + |-- e: integer (nullable = true) + + +showdb show_t1 false Database: showdb +Table: show_t1 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1 +Partition Provider: Catalog +Partition Columns: [`c`, `d`] +Schema: root + |-- a: string (nullable = true) + |-- b: integer (nullable = true) + |-- c: string (nullable = true) + |-- d: string (nullable = true) + + +showdb show_t2 false Database: showdb +Table: show_t2 +Created [not included in comparison] +Last Access [not included in comparison] +Type: MANAGED +Provider: parquet +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t2 +Schema: root + |-- b: string (nullable = true) + |-- d: integer (nullable = true) -- !query 13 -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') +SHOW TABLE EXTENDED -- !query 13 schema struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: SHOW TABLE EXTENDED ... 
PARTITION(line 1, pos 0) +mismatched input '' expecting 'LIKE'(line 1, pos 19) == SQL == -SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -^^^ +SHOW TABLE EXTENDED +-------------------^^^ -- !query 14 -DROP TABLE show_t1 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us', d=1) -- !query 14 schema -struct<> +struct -- !query 14 output - +showdb show_t1 false Partition Values: [c=Us, d=1] +Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 -- !query 15 -DROP TABLE show_t2 +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) -- !query 15 schema struct<> -- !query 15 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'PARTITION' expecting 'LIKE'(line 1, pos 20) +== SQL == +SHOW TABLE EXTENDED PARTITION(c='Us', d=1) +--------------------^^^ -- !query 16 -DROP VIEW show_t3 +SHOW TABLE EXTENDED LIKE 'show_t*' PARTITION(c='Us', d=1) -- !query 16 schema struct<> -- !query 16 output - +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'show_t*' not found in database 'showdb'; -- !query 17 -DROP VIEW global_temp.show_t4 +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -- !query 17 schema struct<> -- !query 17 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 18 -USE default +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(a='Us', d=1) -- !query 18 schema struct<> -- !query 18 output - +org.apache.spark.sql.AnalysisException +Partition spec is invalid. The spec (a, d) must match the partition spec (c, d) defined in table '`showdb`.`show_t1`'; -- !query 19 -DROP DATABASE showdb +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Ch', d=1) -- !query 19 schema struct<> -- !query 19 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 'show_t1' database 'showdb': +c -> Ch +d -> 1; + + +-- !query 20 +DROP TABLE show_t1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +DROP TABLE show_t2 +-- !query 21 schema +struct<> +-- !query 21 output + + + +-- !query 22 +DROP VIEW show_t3 +-- !query 22 schema +struct<> +-- !query 22 output + + + +-- !query 23 +DROP VIEW global_temp.show_t4 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +USE default +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +DROP DATABASE showdb +-- !query 25 schema +struct<> +-- !query 25 output diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 9f0b95994be53..732b11050f461 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul == Physical Plan == *Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=None) ++- *Range (0, 2, step=1, splits=2) -- !query 9 diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out new file mode 100644 index 0000000000000..3e32f46195464 --- /dev/null +++ 
b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -0,0 +1,60 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE TEMPORARY VIEW tbl_x AS VALUES + (1, NAMED_STRUCT('C', 'gamma', 'D', 'delta')), + (2, NAMED_STRUCT('C', 'epsilon', 'D', 'eta')), + (3, NAMED_STRUCT('C', 'theta', 'D', 'iota')) + AS T(ID, ST) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT STRUCT('alpha', 'beta') ST +-- !query 1 schema +struct> +-- !query 1 output +{"col1":"alpha","col2":"beta"} + + +-- !query 2 +SELECT STRUCT('alpha' AS A, 'beta' AS B) ST +-- !query 2 schema +struct> +-- !query 2 output +{"A":"alpha","B":"beta"} + + +-- !query 3 +SELECT ID, STRUCT(ST.*) NST FROM tbl_x +-- !query 3 schema +struct> +-- !query 3 output +1 {"C":"gamma","D":"delta"} +2 {"C":"epsilon","D":"eta"} +3 {"C":"theta","D":"iota"} + + +-- !query 4 +SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x +-- !query 4 schema +struct> +-- !query 4 output +1 {"C":"gamma","D":"delta","E":"1"} +2 {"C":"epsilon","D":"eta","E":"2"} +3 {"C":"theta","D":"iota","E":"3"} + + +-- !query 5 +SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x +-- !query 5 schema +struct> +-- !query 5 output +1 {"AA":"1","C":"gamma","D":"delta"} +2 {"AA":"2","C":"epsilon","D":"eta"} +3 {"AA":"3","C":"theta","D":"iota"} diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out index 66493d7fcc92d..d69b4bcf185c3 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/simple-in.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 14 -- !query 0 @@ -174,3 +174,51 @@ t1a 6 2014-04-04 01:02:00.001 t1d 10 2015-05-04 01:01:00 t1d NULL 2014-06-04 01:01:00 t1d NULL 2014-07-04 01:02:00.001 + + +-- !query 10 +create temporary view a as select * from values + (1, 1), (2, 1), (null, 1), (1, 3), (null, 3), (1, null), (null, 2) + as a(a1, a2) +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +create temporary view b as select * from values + (1, 1, 2), (null, 3, 2), (1, null, 2), (1, 2, null) + as b(b1, b2, b3) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2) +-- !query 12 schema +struct +-- !query 12 output +1 NULL +2 1 + + +-- !query 13 +SELECT a1, a2 +FROM a +WHERE a1 NOT IN (SELECT b.b1 + FROM b + WHERE a.a2 = b.b2 + AND b.b3 > 1) +-- !query 13 schema +struct +-- !query 13 output +1 NULL +2 1 +NULL 2 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index 50ae01e181bcf..e4b1a2dbc675c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -1,11 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 5 +-- Number of queries: 8 -- !query 0 -create temporary view t1 as select * from values +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 2, 3) -as t1(t1a, t1b, t1c) +AS t1(t1a, t1b, t1c) -- !query 0 
schema struct<> -- !query 0 output @@ -13,9 +13,9 @@ struct<> -- !query 1 -create temporary view t2 as select * from values +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 0, 1) -as t2(t2a, t2b, t2c) +AS t2(t2a, t2b, t2c) -- !query 1 schema struct<> -- !query 1 output @@ -23,9 +23,9 @@ struct<> -- !query 2 -create temporary view t3 as select * from values +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (3, 1, 2) -as t3(t3a, t3b, t3c) +AS t3(t3a, t3b, t3c) -- !query 2 schema struct<> -- !query 2 output @@ -33,34 +33,84 @@ struct<> -- !query 3 -select t1a, t2b -from t1, t2 -where t1b = t2c -and t2b = (select max(avg) - from (select t2b, avg(t2b) avg - from t2 - where t2a = t1.t1b +SELECT t1a, t2b +FROM t1, t2 +WHERE t1b = t2c +AND t2b = (SELECT max(avg) + FROM (SELECT t2b, avg(t2b) avg + FROM t2 + WHERE t2a = t1.t1b ) ) -- !query 3 schema struct<> -- !query 3 output org.apache.spark.sql.AnalysisException -expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; +grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.; -- !query 4 -select * -from t1 -where t1a in (select min(t2a) - from t2 - group by t2c - having t2c in (select max(t3c) - from t3 - group by t3b - having t3b > t2b )) +SELECT * +FROM t1 +WHERE t1a IN (SELECT min(t2a) + FROM t2 + GROUP BY t2c + HAVING t2c IN (SELECT max(t3c) + FROM t3 + GROUP BY t3b + HAVING t3b > t2b )) -- !query 4 schema struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)]; +resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); + + +-- !query 5 +SELECT t1a +FROM t1 +GROUP BY 1 +HAVING EXISTS (SELECT 1 + FROM t2 + WHERE t2a < min(t1a + t2a)) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. Aggregate expression: min((t1.`t1a` + t2.`t2a`)), Outer references: t1.`t1a`, Local references: t2.`t2a`.; + + +-- !query 6 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT 1 + FROM t3 + GROUP BY 1 + HAVING min(t2a + t3a) > 1)) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +Found an aggregate expression in a correlated predicate that has both outer and local references, which is not supported yet. 
Aggregate expression: min((t2.`t2a` + t3.`t3a`)), Outer references: t2.`t2a`, Local references: t3.`t3a`.; + + +-- !query 7 +SELECT t1a +FROM t1 +WHERE t1a IN (SELECT t2a + FROM t2 + WHERE EXISTS (SELECT min(t2a) + FROM t3)) +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses: +Aggregate [min(outer(t2a#x)) AS min(outer())#x] ++- SubqueryAlias t3 + +- Project [t3a#x, t3b#x, t3c#x] + +- SubqueryAlias t3 + +- LocalRelation [t3a#x, t3b#x, t3c#x] +; diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index d769bcef0aca7..e2ee970d35f60 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 9 -- !query 0 @@ -85,3 +85,21 @@ struct<> -- !query 6 output java.lang.IllegalArgumentException Invalid arguments for resolved function: 1, null + + +-- !query 7 +select * from RaNgE(2) +-- !query 7 schema +struct +-- !query 7 output +0 +1 + + +-- !query 8 +EXPLAIN select * from RaNgE(2) +-- !query 8 schema +struct +-- !query 8 output +== Physical Plan == +*Range (0, 2, step=1, splits=2) diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata new file mode 100644 index 0000000000000..3492220e36b8d --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/metadata @@ -0,0 +1 @@ +{"id":"dddc5e7f-1e71-454c-8362-de184444fb5a"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 new file mode 100644 index 0000000000000..cbde042e79af1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180207737} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 new file mode 100644 index 0000000000000..10b5774746de9 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1489180209261} +2 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/0/2.delta differ 
diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta new file mode 100644 index 0000000000000..7dc49cb3e47fd Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta new file mode 100644 index 0000000000000..8b566e81f4866 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta new file mode 100644 index 0000000000000..ca2a7ed033f3b Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/2.delta new file mode 100644 index 0000000000000..361f2db605020 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/1.delta new file mode 100644 index 0000000000000..4c8804c61ad7f Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/2.delta new file mode 100644 index 0000000000000..7d3e07fe03306 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta new file mode 100644 index 0000000000000..fe521b8c07504 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/4/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/2.delta 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/5/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/2.delta new file mode 100644 index 0000000000000..e69925cabaa94 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/6/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/2.delta new file mode 100644 index 0000000000000..36397a3dda248 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/7/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/8/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta new file mode 100644 index 0000000000000..0c9b6ac5c863d Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.1.0/state/0/9/2.delta differ diff --git a/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet b/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet new file mode 100644 index 0000000000000..d3c39e2c26eec Binary files 
/dev/null and b/sql/core/src/test/resources/test-data/timemillis-in-i64.parquet differ diff --git a/sql/core/src/test/resources/tpcds/q77.sql b/sql/core/src/test/resources/tpcds/q77.sql index 7830f96e76515..a69df9fbcd366 100755 --- a/sql/core/src/test/resources/tpcds/q77.sql +++ b/sql/core/src/test/resources/tpcds/q77.sql @@ -36,7 +36,7 @@ WITH ss AS sum(cr_net_loss) AS profit_loss FROM catalog_returns, date_dim WHERE cr_returned_date_sk = d_date_sk - AND d_date BETWEEN cast('2000-08-03]' AS DATE) AND + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), ws AS (SELECT diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 3e85d95523125..7e61a68025158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { +import org.apache.spark.SparkConf - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") - super.beforeAll() - } +class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { - - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA } } -class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with -BeforeAndAfter { +class TwoLevelAggregateHashMapWithVectorizedMapSuite + extends DataFrameAggregateSuite + with BeforeAndAfter { - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 2a0e088437fda..e66fe97afad45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -24,15 +24,15 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} private case class BigData(s: String) @@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext maybeBlock.nonEmpty } - private def getNumInMemoryRelations(plan: LogicalPlan): Int = { + private def getNumInMemoryRelations(ds: Dataset[_]): Int = { + val plan = ds.queryExecution.withCachedData var sum = plan.collect { case _: InMemoryRelation => 1 }.sum plan.transformAllExpressions { case e: SubqueryExpression => @@ -75,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sum } + private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { + plan.collect { + case InMemoryTableScanExec(_, _, relation) => + getNumInMemoryTablesRecursively(relation.child) + 1 + }.sum + } + test("withColumn doesn't invalidate cached dataframe") { var evalCount = 0 val myUDF = udf((x: String) => { evalCount += 1; "result" }) @@ -187,7 +195,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData) + getNumInMemoryRelations(spark.table("testData")) } spark.catalog.cacheTable("testData") @@ -580,21 +588,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext localRelation.createOrReplaceTempView("localRelation") spark.catalog.cacheTable("localRelation") - assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1) + assert(getNumInMemoryRelations(localRelation) == 1) } test("SPARK-19093 Caching in side subquery") { withTempView("t1") { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 2) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) } } @@ -610,17 +618,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t4") // Nested predicate subquery - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - 
""".stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 3) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 3) // Scalar subquery and predicate subquery - val cachedPlan2 = + val ds2 = sql( """ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) @@ -630,8 +638,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext |EXISTS (SELECT c1 FROM t3) |OR |c1 IN (SELECT c1 FROM t4) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan2) == 4) + """.stripMargin) + assert(getNumInMemoryRelations(ds2) == 4) + } + } + + test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { + withTable("t") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + spark.catalog.cacheTable("t") + spark.table("t").select($"i").cache() + checkAnswer(spark.table("t").select($"i"), Row(1)) + assertCached(spark.table("t").select($"i")) + + Utils.deleteRecursively(path) + spark.sessionState.catalog.refreshTable(TableIdentifier("t")) + spark.catalog.uncacheTable("t") + assert(spark.table("t").select($"i").count() == 0) + assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0) + } } } @@ -650,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) } } + + test("SPARK-19993 simple subquery caching") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Additional predicate in the subquery plan should cause a cache miss + val cachedMissDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t2 where c1 = 0) + """.stripMargin) + assert(getNumInMemoryRelations(cachedMissDs) == 0) + } + } + + test("SPARK-19993 subquery caching with correlated predicates") { + withTempView("t1", "t2") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(1).toDF("c1").createOrReplaceTempView("t2") + + // Simple correlated predicate in subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + } + } + + test("SPARK-19993 subquery with cached underlying relation") { + withTempView("t1") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") + + // underlying table t1 is cached as well as the query that refers to it. 
+ val ds = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |NOT EXISTS (SELECT * FROM t1) + """.stripMargin).cache() + assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3) + } + } + + test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF("c1").createOrReplaceTempView("t1") + Seq(2).toDF("c1").createOrReplaceTempView("t2") + Seq(1).toDF("c1").createOrReplaceTempView("t3") + Seq(1).toDF("c1").createOrReplaceTempView("t4") + + // Nested predicate subquery + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin).cache() + + val cachedDs = + sql( + """ + |SELECT * FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs) == 1) + + // Scalar subquery and predicate subquery + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin).cache() + + val cachedDs2 = + sql( + """ + |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) + |WHERE + |c1 = (SELECT max(c1) FROM t2 GROUP BY c1) + |OR + |EXISTS (SELECT c1 FROM t3) + |OR + |c1 IN (SELECT c1 FROM t4) + """.stripMargin) + assert(getNumInMemoryRelations(cachedDs2) == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab7455..bc708ca88d7e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
index e7079120bb7df..8569c2d76b694 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 094efbaeadcd5..63094d1b6122b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -51,4 +51,15 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } + + test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") { + checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF, + Seq(Row(0), Row(null), Row(2))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF, + Seq(Row(0L), Row(null), Row(2L))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF, + Seq(Row(0.0F), Row(null), Row(2.0F))) + checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF, + Seq(Row(0.0D), Row(null), Row(2.0D))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727f..4a52af6c32c37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac332..aa237d0619ac3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -145,6 +145,20 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil ) + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 3.14), (9123146099426677101L, null), + (9123146560113991650L, 1.6), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14) :: Row(9123146099426677101L, 0.2) :: Row(9123146560113991650L, 1.6) + :: Row(0, 0.2) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Float)]((null, 3.14f), (9123146099426677101L, null), + (9123146560113991650L, 1.6f), (null, null)).toDF("a", "b").na.fill(0.2), + Row(0, 3.14f) :: Row(9123146099426677101L, 0.2f) :: Row(9123146560113991650L, 1.6f) + :: Row(0, 0.2f) :: Nil + ) + checkAnswer( Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) .toDF("a", "b").na.fill(2.34), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 51ffe34172714..6ca9ee57e8f49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class DataFramePivotSuite extends QueryTest with SharedSQLContext{ +class DataFramePivotSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("pivot courses") { @@ -216,4 +216,34 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil ) } + + test("pivot with null should not throw NPE") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(), + Row(null, 1, null) :: Row(1, null, 1) :: Nil) + } + + test("pivot with null and aggregate type not supported by PivotFirst returns correct result") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a") + .withColumn("b", expr("array(a, 7)")) + .groupBy($"a").pivot("a").agg(min($"b")), + Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) + } + + test("pivot with timestamp and count should not print internal representation") { + val ts = "2012-12-31 16:00:10.011" + val tsWithZone = "2013-01-01 00:00:10.011" + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + val df = Seq(java.sql.Timestamp.valueOf(ts)).toDF("a").groupBy("a").pivot("a").count() + val expected = StructType( + StructField("a", TimestampType) :: + StructField(tsWithZone, LongType) :: Nil) + assert(df.schema == expected) + // String representation of timestamp with timezone should take the time difference + // into account. 
+ checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index acf393a9b0faf..7b495656b93d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000 val res13 = spark.range(-n, n, n / 9).select("id") assert(res13.count == 18) + + // range with non aggregation operation + val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id") + val len14 = res14.collect.length + assert(len14 == 25) + + val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0") + val len15 = res15.collect.length + assert(len15 == 50) + + val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id") + val len16 = res16.collect.length + assert(len16 == 500) + + val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id") + assert(res17.collect === (1 to 10).map(i => Row(i)).toArray) } test("Range with randomized parameters") { @@ -169,6 +185,12 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + + test("SPARK-20430 Initialize Range parameters in a driver side") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index d0910e618a040..dd118f88e3bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. 
- val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } + + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { @@ -171,15 +184,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0) } assert(e2.getMessage.contains("Relative Error must be non-negative")) - - // return null if the dataset is empty - val res1 = df.selectExpr("*").limit(0) - .stat.approxQuantile("singles", Array(q1, q2), epsilons.head) - assert(res1 === null) - - val res2 = df.selectExpr("*").limit(0) - .stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head) - assert(res2 === null) } test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") { @@ -214,20 +218,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val q1 = 0.5 val q2 = 0.8 val epsilon = 0.1 - val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0), - Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0), - Row(-1.0, null), Row(Double.NaN, null))) + val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN), + Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null), + Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN), + Row(Double.NaN, null, null))) val schema = StructType(Seq(StructField("input1", 
DoubleType, nullable = true), - StructField("input2", DoubleType, nullable = true))) + StructField("input2", DoubleType, nullable = true), + StructField("input3", DoubleType, nullable = true))) val dfNaN = spark.createDataFrame(rows, schema) - val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) - assert(resNaN.count(_.isNaN) === 0) - assert(resNaN.count(_ == null) === 0) - val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"), + val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(resNaN1.count(_.isNaN) === 0) + assert(resNaN1.count(_ == null) === 0) + + val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon) + assert(resNaN2.count(_.isNaN) === 0) + assert(resNaN2.count(_ == null) === 0) + + val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon) + assert(resNaN3.isEmpty) + + val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"), Array(q1, q2), epsilon) - assert(resNaN2.flatten.count(_.isNaN) === 0) - assert(resNaN2.flatten.count(_ == null) === 0) + assert(resNaNAll.flatten.count(_.isNaN) === 0) + assert(resNaNAll.flatten.count(_ == null) === 0) + + assert(resNaN1(0) === resNaNAll(0)(0)) + assert(resNaN1(1) === resNaNAll(0)(1)) + assert(resNaN2(0) === resNaNAll(1)(0)) + assert(resNaN2(1) === resNaNAll(1)(1)) + + // return empty array for columns only containing null or NaN values + assert(resNaNAll(2).isEmpty) + + // return empty array if the dataset is empty + val res1 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile("input1", Array(q1, q2), epsilon) + assert(res1.isEmpty) + + val res2 = dfNaN.selectExpr("*").limit(0) + .stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon) + assert(res2(0).isEmpty) + assert(res2(1).isEmpty) } test("crosstab") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 46919d5b6556c..5518304fe87f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -765,6 +765,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 20) === expectedAnswerForTrue) } + test("showString: truncate = [0, 20], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----------------------\n" + + " value | 1 \n" + + "-RECORD 1----------------------\n" + + " value | 111111111111111111111 \n" + assert(df.showString(10, truncate = 0, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0---------------------\n" + + " value | 1 \n" + + "-RECORD 1---------------------\n" + + " value | 11111111111111111... 
\n" + assert(df.showString(10, truncate = 20, vertical = true) === expectedAnswerForTrue) + } + test("showString: truncate = [3, 17]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -786,6 +801,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10, truncate = 17) === expectedAnswerForTrue) } + test("showString: truncate = [3, 17], vertical = true") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = "-RECORD 0----\n" + + " value | 1 \n" + + "-RECORD 1----\n" + + " value | 111 \n" + assert(df.showString(10, truncate = 3, vertical = true) === expectedAnswerForFalse) + val expectedAnswerForTrue = "-RECORD 0------------------\n" + + " value | 1 \n" + + "-RECORD 1------------------\n" + + " value | 11111111111111... \n" + assert(df.showString(10, truncate = 17, vertical = true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -796,6 +826,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(-1) === expectedAnswer) } + test("showString(negative), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(-1, vertical = true) === expectedAnswer) + } + test("showString(0)") { val expectedAnswer = """+---+-----+ ||key|value| @@ -806,6 +841,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(0), vertical = true") { + val expectedAnswer = "(0 rows)\n" + assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) + } + test("showString: array") { val df = Seq( (Array(1, 2, 3), Array(1, 2, 3)), @@ -821,6 +861,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: array, vertical = true") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = "-RECORD 0--------\n" + + " _1 | [1, 2, 3] \n" + + " _2 | [1, 2, 3] \n" + + "-RECORD 1--------\n" + + " _1 | [2, 3, 4] \n" + + " _2 | [2, 3, 4] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: binary") { val df = Seq( ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), @@ -836,6 +890,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary, vertical = true") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = "-RECORD 0---------------\n" + + " _1 | [31 32] \n" + + " _2 | [41 42 43 2E] \n" + + "-RECORD 1---------------\n" + + " _1 | [33 34] \n" + + " _2 | [31 32 33 34 36] \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), @@ -851,6 +919,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: minimum column width, vertical = true") { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + val expectedAnswer = "-RECORD 0--\n" + + " _1 | 1 \n" + + " _2 | 1 \n" + + "-RECORD 
1--\n" + + " _1 | 2 \n" + + " _2 | 2 \n" + assert(df.showString(10, vertical = true) === expectedAnswer) + } + test("SPARK-7319 showString") { val expectedAnswer = """+---+-----+ ||key|value| @@ -862,6 +944,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1) === expectedAnswer) } + test("SPARK-7319 showString, vertical = true") { + val expectedAnswer = "-RECORD 0----\n" + + " key | 1 \n" + + " value | 1 \n" + + "only showing top 1 row\n" + assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| @@ -871,6 +961,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer) } + test("SPARK-7327 show with empty dataFrame, vertical = true") { + assert(testData.select($"*").filter($"key" < 0).showString(1, vertical = true) === "(0 rows)\n") + } + test("SPARK-18350 show with session local timezone") { val d = Date.valueOf("2016-12-01") val ts = Timestamp.valueOf("2016-12-01 00:00:00") @@ -895,6 +989,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-18350 show with session local timezone, vertical = true") { + val d = Date.valueOf("2016-12-01") + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((d, ts)).toDF("d", "ts") + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 00:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + + val expectedAnswer = "-RECORD 0------------------\n" + + " d | 2016-12-01 \n" + + " ts | 2016-12-01 08:00:00 \n" + assert(df.showString(1, truncate = 0, vertical = true) === expectedAnswer) + } + } + test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) @@ -1715,4 +1827,33 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) df.filter(filter).count } + + test("SPARK-19893: cannot run set operations with map type") { + val df = spark.range(1).select(map(lit("key"), $"id").as("m")) + val e = intercept[AnalysisException](df.intersect(df)) + assert(e.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e2 = intercept[AnalysisException](df.except(df)) + assert(e2.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + val e3 = intercept[AnalysisException](df.distinct()) + assert(e3.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + withTempView("v") { + df.createOrReplaceTempView("v") + val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) + assert(e4.message.contains( + "Cannot have map type columns in DataFrame which calls set operations")) + } + } + + test("SPARK-20359: catalyst outer join optimization should not throw npe") { + val df1 = Seq("a", "b", "c").toDF("x") + .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" 
}.apply($"x")) + val df2 = Seq("a", "b").toDF("x1") + df1 + .join(df2, df1("x") === df2("x1"), "left_outer") + .filter($"x1".isNotNull || !$"y".isin("a!")) + .count + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 66d94d6016050..1a0672b8876da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -31,6 +31,49 @@ object DatasetBenchmark { case class Data(l: Long, s: String) + def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(0, numRows) + val ds = spark.range(0, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l + 1 + + val benchmark = new Benchmark("back-to-back map long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -72,6 +115,49 @@ object DatasetBenchmark { benchmark } + def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(1, numRows) + val ds = spark.range(1, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l % 2L == 0L + + val benchmark = new Benchmark("back-to-back filter Long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % 2L === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -165,9 +251,22 @@ object DatasetBenchmark { val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(spark, numRows, numChains) - val benchmark2 = backToBackFilter(spark, numRows, numChains) - val benchmark3 = aggregate(spark, numRows) + val benchmark0 = backToBackMapLong(spark, numRows, numChains) + val benchmark1 = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilterLong(spark, numRows, numChains) + val benchmark3 = backToBackFilter(spark, numRows, numChains) + val benchmark4 = aggregate(spark, numRows) + + /* + OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + 
RDD 1883 / 1892 53.1 18.8 1.0X + DataFrame 502 / 642 199.1 5.0 3.7X + Dataset 657 / 784 152.2 6.6 2.9X + */ + benchmark0.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -178,7 +277,18 @@ object DatasetBenchmark { DataFrame 2647 / 3116 37.8 26.5 1.3X Dataset 4781 / 5155 20.9 47.8 0.7X */ - benchmark.run() + benchmark1.run() + + /* + OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 846 / 1120 118.1 8.5 1.0X + DataFrame 270 / 329 370.9 2.7 3.1X + Dataset 545 / 789 183.5 5.4 1.6X + */ + benchmark2.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -189,7 +299,7 @@ object DatasetBenchmark { DataFrame 59 / 72 1695.4 0.6 22.8X Dataset 2777 / 2805 36.0 27.8 0.5X */ - benchmark2.run() + benchmark3.run() /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 @@ -201,6 +311,6 @@ object DatasetBenchmark { Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X */ - benchmark3.run() + benchmark4.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b50cb3e48c76..541565344f758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,6 +62,50 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } + test("mapPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.map(_ > 1), false, true, true) + checkDataset(dsInt.map(_ + 1), 2, 3, 4) + checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.map(_ > 1), false, true, true) + checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.map(_ > 1), false, true, true) + checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L) + checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.map(_ > 1), false, true, true) + checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsDouble.map(e => (e + 8589934592L).toLong), + 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F) + checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.map(e => !e), false, true) + } + + test("mapPrimitiveArray") { + val dsInt = Seq(Array(1, 2), Array(3, 4)).toDS() + checkDataset(dsInt.map(e => e), Array(1, 2), Array(3, 4)) + checkDataset(dsInt.map(e => null: Array[Int]), null, null) + + val dsDouble = Seq(Array(1D, 2D), 
Array(3D, 4D)).toDS() + checkDataset(dsDouble.map(e => e), Array(1D, 2D), Array(3D, 4D)) + checkDataset(dsDouble.map(e => null: Array[Double]), null, null) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( @@ -69,6 +113,23 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 4) } + test("filterPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.filter(_ > 1), 2, 3) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.filter(_ > 1), 2L, 3L) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.filter(_ > 1), 2F, 3F) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.filter(_ > 1), 2D, 3D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.filter(e => !e), false) + } + test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.longAccumulator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 0f3d0cefe3bb5..68f7de047b392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import com.esotericsoftware.kryo.{Kryo, Serializer} import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test.TestSparkSession /** * Test suite to test Kryo custom registrators. @@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - /** - * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. - */ - protected override def beforeAll(): Unit = { - sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) - super.beforeAll() + + override protected def sparkConf: SparkConf = { + // Make sure we use the KryoRegistrator + super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) } test("Kryo registrator") { @@ -56,7 +54,9 @@ object TestRegistrator { def apply(): TestRegistrator = new TestRegistrator() } -/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */ +/** + * A `Serializer` that takes a [[KryoData]] and serializes it as KryoData(0). 
+ */ class ZeroKryoDataSerializer extends Serializer[KryoData] { override def write(kryo: Kryo, output: Output, t: KryoData): Unit = { output.writeInt(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b37bf131e8dce..5b5cd28ad0c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) } + test("as seq of case class - reorder fields by name") { + val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a")))) + val ds = df.as[Seq[ClassData]] + assert(ds.collect() === Array( + Seq(ClassData("a", 0)), + Seq(ClassData("a", 1)), + Seq(ClassData("a", 2)))) + } + test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( @@ -1136,10 +1145,34 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == new java.sql.Timestamp(100000)) } + + test("SPARK-19896: cannot have circular references in in case class") { + val errMsg1 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassA(null)).toDS + } + assert(errMsg1.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg2 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassC(null)).toDS + } + assert(errMsg2.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + val errMsg3 = intercept[UnsupportedOperationException] { + Seq(CircularReferenceClassD(null)).toDS + } + assert(errMsg3.getMessage.startsWith("cannot have circular references in class, but got the " + + "circular reference of class")) + } + + test("SPARK-20125: option of map") { + val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() + checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) +case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) case class Generic[T](id: T, value: Double) @@ -1214,3 +1247,9 @@ object DatasetTransform { case class Route(src: String, dest: String, cost: Int) case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) + +case class CircularReferenceClassA(cls: CircularReferenceClassB) +case class CircularReferenceClassB(cls: CircularReferenceClassA) +case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) +case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) +case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2e006735d123e..1a66aa85f5a02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.mutable.ListBuffer import scala.language.existentials import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -24,7 +25,7 @@ import 
org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext { cartesianQueries.foreach(checkCartesianDetection) } + + test("test SortMergeJoin (without spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) { + + assertNotSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertNotSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertNotSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } + + test("test SortMergeJoin (with spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + + assertSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]] + // so should not cause any spill + assertNotSpilled(sparkContext, "full outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |FULL OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 953d161ec2a1d..69a500c845a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -156,7 +156,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) } - test("to_json") { + test("from_json uses DDL strings for defining a schema") { + val df = Seq("""{"a": 1, "b": "haa"}""").toDS() + checkAnswer( + df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())), + Row(Row(1, "haa")) :: Nil) + } + + test("to_json - struct") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") checkAnswer( @@ -164,6 +171,14 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("""{"_1":1}""") :: Nil) } + test("to_json - array") { + val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""[{"_1":1}]""") :: Nil) + } + test("to_json with option") { val df = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm") @@ -184,7 +199,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { "Unable to convert column a of type calendarinterval to JSON.")) } - test("roundtrip in to_json and from_json") { + test("roundtrip in to_json and from_json - struct") { val dfOne = Seq(Tuple1(Tuple1(1)), Tuple1(null)).toDF("struct") val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType] val readBackOne = dfOne.select(to_json($"struct").as("json")) @@ -197,4 +212,77 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { .select(to_json($"struct").as("json")) checkAnswer(dfTwo, readBackTwo) } + + test("roundtrip in to_json and from_json - array") { + val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array") + val schemaOne = dfOne.schema(0).dataType + val readBackOne = dfOne.select(to_json($"array").as("json")) + .select(from_json($"json", schemaOne).as("array")) + checkAnswer(dfOne, readBackOne) + + val dfTwo = Seq(Some("""[{"a":1}]"""), None).toDF("json") + val schemaTwo = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("array")) + .select(to_json($"array").as("json")) + checkAnswer(dfTwo, readBackTwo) + } + + test("SPARK-19637 Support to_json in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer( + df1.selectExpr("to_json(a)"), + Row("""{"_1":1}""") :: Nil) + + val df2 = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") + checkAnswer( + df2.selectExpr("to_json(a, map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row("""{"_1":"26/08/2015 18:00"}""") :: Nil) + + val errMsg1 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, named_struct('a', 1))") + } + assert(errMsg1.getMessage.startsWith("Must use a map() function for options")) + + val errMsg2 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, map('a', 1))") + } + assert(errMsg2.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } + + test("SPARK-19967 Support from_json in SQL") { + val df1 = Seq("""{"a": 1}""").toDS() + checkAnswer( + df1.selectExpr("from_json(value, 'a INT')"), + Row(Row(1)) :: Nil) + + val df2 = Seq("""{"c0": "a", "c1": 1, "c2": {"c20": 3.8, "c21": 8}}""").toDS() + checkAnswer( + df2.selectExpr("from_json(value, 'c0 STRING, c1 INT, c2 STRUCT')"), + Row(Row("a", 1, Row(3.8, 8))) :: Nil) + + val df3 = Seq("""{"time": "26/08/2015 18:00"}""").toDS() + checkAnswer( + df3.selectExpr( + 
"from_json(value, 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) + + val errMsg1 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 1)") + } + assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + val errMsg2 = intercept[AnalysisException] { + df3.selectExpr("""from_json(value, 'time InvalidType')""") + } + assert(errMsg2.getMessage.contains("DataType invalidtype is not supported")) + val errMsg3 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") + } + assert(errMsg3.getMessage.startsWith("Must use a map() function for options")) + val errMsg4 = intercept[AnalysisException] { + df3.selectExpr("from_json(value, 'time Timestamp', map('a', 1))") + } + assert(errMsg4.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 37443d0342980..328c5395ec91e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with data frame from a local Seq of Product") { + val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") + checkAnswer( + df.withColumn("value_rounded", round('value)), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) + ) + checkAnswer( + df.withColumn("value_brounded", bround('value)), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 468ea0551298e..3ecbf96b41961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext +import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -1019,6 +1020,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { spark.sessionState.conf.clear() } + test("SET mapreduce.job.reduces automatically converted to spark.sql.shuffle.partitions") { + spark.sessionState.conf.clear() + val before = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + val newConf = before + 1 + sql(s"SET mapreduce.job.reduces=${newConf.toString}") + val after = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + assert(before != after) + assert(newConf === after) + intercept[IllegalArgumentException](sql(s"SET mapreduce.job.reduces=-1")) + spark.sessionState.conf.clear() + } + test("apply schema") { val schema1 = StructType( StructField("f1", IntegerType, false) :: @@ -2586,4 +2599,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(!jobStarted.get(), "Command should not trigger a Spark job.") } + + test("SPARK-20164: AnalysisException should be tolerant to null query plan") { + try { + throw new AnalysisException("", None, None, plan = null) + } catch { + case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) + } + } + + test("SPARK-12868: 
Allow adding jars from hdfs ") { + val jarFromHdfs = "hdfs://doesnotmatter/test.jar" + val jarFromInvalidFs = "fffs://doesnotmatter/test.jar" + + // if 'hdfs' is not supported, MalformedURLException will be thrown + new URL(jarFromHdfs) + + intercept[MalformedURLException] { + new URL(jarFromInvalidFs) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 68ababcd11027..d9130fdcfaea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.execution.command.DescribeTableCommand import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -123,7 +124,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } private def createScalaTestCase(testCase: TestCase): Unit = { - if (blackList.exists(t => testCase.name.toLowerCase.contains(t.toLowerCase))) { + if (blackList.exists(t => + testCase.name.toLowerCase(Locale.ROOT).contains(t.toLowerCase(Locale.ROOT)))) { // Create a test case to ignore this case. ignore(testCase.name) { /* Do nothing */ } } else { @@ -165,8 +167,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile); - val parent = resultFile.getParentFile(); + val resultFile = new File(testCase.resultFile) + val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) } @@ -212,20 +214,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { /** Executes a query and returns the result as (schema of the output, normalized output). */ private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { // Returns true if the plan is supposed to be sorted. - def isSorted(plan: LogicalPlan): Boolean = plan match { + def needSort(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case _: DescribeTableCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true - case _ => plan.children.iterator.exists(isSorted) + case _ => plan.children.iterator.exists(needSort) } try { val df = session.sql(sql) val schema = df.schema + val notIncludedMsg = "[not included in comparison]" // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x")) + val answer = df.queryExecution.hiveResultString().map(_.replaceAll("#\\d+", "#x") + .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") + .replaceAll("Created.*", s"Created $notIncludedMsg") + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) // If the output is not pre-sorted, sort it. 
- if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + if (needSort(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { case a: AnalysisException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala new file mode 100644 index 0000000000000..5638c8eeda842 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + +class SessionStateSuite extends SparkFunSuite + with BeforeAndAfterEach with BeforeAndAfterAll { + + /** + * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this + * session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared + * with all Hive test suites. 
+ */ + protected var activeSession: SparkSession = _ + + override def beforeAll(): Unit = { + activeSession = SparkSession.builder().master("local").getOrCreate() + } + + override def afterAll(): Unit = { + if (activeSession != null) { + activeSession.stop() + activeSession = null + } + super.afterAll() + } + + test("fork new session and inherit RuntimeConfig options") { + val key = "spark-config-clone" + try { + activeSession.conf.set(key, "active") + + // inheritance + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.conf ne activeSession.conf) + assert(forkedSession.conf.get(key) == "active") + + // independence + forkedSession.conf.set(key, "forked") + assert(activeSession.conf.get(key) == "active") + activeSession.conf.set(key, "dontcopyme") + assert(forkedSession.conf.get(key) == "forked") + } finally { + activeSession.conf.unset(key) + } + } + + test("fork new session and inherit function registry and udf") { + val testFuncName1 = "strlenScala" + val testFuncName2 = "addone" + try { + activeSession.udf.register(testFuncName1, (_: String).length + (_: Int)) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState.functionRegistry ne + activeSession.sessionState.functionRegistry) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + + // independence + forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1) + assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty) + activeSession.udf.register(testFuncName2, (_: Int) + 1) + assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty) + } finally { + activeSession.sessionState.functionRegistry.dropFunction(testFuncName1) + activeSession.sessionState.functionRegistry.dropFunction(testFuncName2) + } + } + + test("fork new session and inherit experimental methods") { + val originalExtraOptimizations = activeSession.experimental.extraOptimizations + val originalExtraStrategies = activeSession.experimental.extraStrategies + try { + object DummyRule1 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + object DummyRule2 extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + val optimizations = List(DummyRule1, DummyRule2) + activeSession.experimental.extraOptimizations = optimizations + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.experimental ne activeSession.experimental) + assert(forkedSession.experimental.extraOptimizations.toSet == + activeSession.experimental.extraOptimizations.toSet) + + // independence + forkedSession.experimental.extraOptimizations = List(DummyRule2) + assert(activeSession.experimental.extraOptimizations == optimizations) + activeSession.experimental.extraOptimizations = List(DummyRule1) + assert(forkedSession.experimental.extraOptimizations == List(DummyRule2)) + } finally { + activeSession.experimental.extraOptimizations = originalExtraOptimizations + activeSession.experimental.extraStrategies = originalExtraStrategies + } + } + + test("fork new session and inherit listener manager") { + class CommandCollector extends QueryExecutionListener { + val commands: ArrayBuffer[String] = ArrayBuffer.empty[String] + override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {} + override def 
onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + commands += funcName + } + } + val collectorA = new CommandCollector + val collectorB = new CommandCollector + val collectorC = new CommandCollector + + try { + def runCollectQueryOn(sparkSession: SparkSession): Unit = { + val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING) + val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j") + df.select("i").collect() + } + + activeSession.listenerManager.register(collectorA) + val forkedSession = activeSession.cloneSession() + + // inheritance + assert(forkedSession ne activeSession) + assert(forkedSession.listenerManager ne activeSession.listenerManager) + runCollectQueryOn(forkedSession) + assert(collectorA.commands.length == 1) // forked should callback to A + assert(collectorA.commands(0) == "collect") + + // independence + // => changes to forked do not affect original + forkedSession.listenerManager.register(collectorB) + runCollectQueryOn(activeSession) + assert(collectorB.commands.isEmpty) // original should not callback to B + assert(collectorA.commands.length == 2) // original should still callback to A + assert(collectorA.commands(1) == "collect") + // <= changes to original do not affect forked + activeSession.listenerManager.register(collectorC) + runCollectQueryOn(forkedSession) + assert(collectorC.commands.isEmpty) // forked should not callback to C + assert(collectorA.commands.length == 3) // forked should still callback to A + assert(collectorB.commands.length == 1) // forked should still callback to B + assert(collectorA.commands(2) == "collect") + assert(collectorB.commands(0) == "collect") + } finally { + activeSession.listenerManager.unregister(collectorA) + activeSession.listenerManager.unregister(collectorC) + } + } + + test("fork new sessions and run query on inherited table") { + def checkTableExists(sparkSession: SparkSession): Unit = { + QueryTest.checkAnswer(sparkSession.sql( + """ + |SELECT x.str, COUNT(*) + |FROM df x JOIN df y ON x.str = y.str + |GROUP BY x.str + """.stripMargin), + Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) + } + + val spark = activeSession + // Cannot use `import activeSession.implicits._` due to the compiler limitation. + import spark.implicits._ + + try { + activeSession + .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString))) + .toDF("int", "str") + .createOrReplaceTempView("df") + checkTableExists(activeSession) + + val forkedSession = activeSession.cloneSession() + assert(forkedSession ne activeSession) + assert(forkedSession.sessionState ne activeSession.sessionState) + checkTableExists(forkedSession) + checkTableExists(activeSession.cloneSession()) // ability to clone multiple times + checkTableExists(forkedSession.cloneSession()) // clone of clone + } finally { + activeSession.sql("drop table df") + } + } + + test("fork new session and inherit reference to SharedState") { + val forkedSession = activeSession.cloneSession() + assert(activeSession.sharedState eq forkedSession.sharedState) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 0000000000000..43db79663322a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) 
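+      // Illustrative note: loading MyExtensions via the "spark.sql.extensions" config is roughly
+      // equivalent to registering it programmatically, e.g.
+      //   SparkSession.builder().master("local[1]").withExtensions(new MyExtensions).getOrCreate()
+      // since MyExtensions is itself a SparkSessionExtensions => Unit function (defined below).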
+ assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan = + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bbb31dbc8f3de..ddc393c8da053 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -112,36 +113,12 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared spark.sessionState.conf.autoBroadcastJoinThreshold) } - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.stats(conf).sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.stats(conf).sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will 
get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) stats.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) assert(roundtrip == Some(v)) } } @@ -225,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) ) private val randomName = new Random(31) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 25dbecb5894e4..131abf7c1e5d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -622,7 +622,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { checkAnswer( - sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + sql( + """ + |select l.b, (select (r.c + count(*)) is null + |from r + |where l.a = r.c group by r.c) from l + """.stripMargin), Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) } @@ -817,12 +822,49 @@ class SubquerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - | select c2 - | from t1 - | where exists (select * - | from t2 lateral view explode(arr_c2) q as c2 - where t1.c1 = t2.c1)""".stripMargin), + | SELECT c2 + | FROM t1 + | WHERE EXISTS (SELECT * + | FROM t2 LATERAL VIEW explode(arr_c2) q AS c2 + WHERE t1.c1 = t2.c1)""".stripMargin), Row(1) :: Row(0) :: Nil) + + val msg1 = intercept[AnalysisException] { + sql( + """ + | SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT * + | FROM t1 LATERAL VIEW explode(t2.arr_c2) q AS c2 + | WHERE t1.c1 = t2.c1) + """.stripMargin) + } + assert(msg1.getMessage.contains( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING")) + } + } + + test("SPARK-19933 Do not 
eliminate top-level aliases in sub-queries") { + withTempView("t1", "t2") { + spark.range(4).createOrReplaceTempView("t1") + checkAnswer( + sql("select * from t1 where id in (select id as id from t1)"), + Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil) + + spark.range(2).createOrReplaceTempView("t2") + checkAnswer( + sql("select * from t1 where id in (select id as id from t2)"), + Row(0) :: Row(1) :: Nil) } } + + test("ListQuery and Exists should work even no correlated references") { + checkAnswer( + sql("select * from l, r where l.a = r.c AND (r.d in (select d from r) OR l.a >= 1)"), + Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: Row(2, 1.0, 2, 3.0) :: + Row(2, 1.0, 2, 3.0) :: Row(3.0, 3.0, 3, 2.0) :: Row(6, null, 6, null) :: Nil) + checkAnswer( + sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), + Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala new file mode 100644 index 0000000000000..f7f1ccea281c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkConf +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Suite that tests the redaction of DataSourceScanExec + */ +class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { + + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") + + test("treeString is redacted") { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + val df = spark.read.parquet(basePath) + + val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head + assert(rootPath.toString.contains(basePath.toString)) + + assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName)) + assert(!df.queryExecution.toString.contains(rootPath.getName)) + assert(!df.queryExecution.simpleString.contains(rootPath.getName)) + + val replacement = "*********" + assert(df.queryExecution.sparkPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.executedPlan.treeString(verbose = true).contains(replacement)) + assert(df.queryExecution.toString.contains(replacement)) + assert(df.queryExecution.simpleString.contains(replacement)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 36cde3233dce8..59eaf4d1c29b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { ) } - test("compatible BroadcastMode") { + test("BroadcastMode.canonicalized") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) - assert(mode1.compatibleWith(mode1)) - assert(!mode1.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode1)) - assert(mode2.compatibleWith(mode2)) - assert(!mode2.compatibleWith(mode3)) - assert(mode3.compatibleWith(mode3)) + assert(mode1.canonicalized == mode1.canonicalized) + assert(mode1.canonicalized != mode2.canonicalized) + assert(mode2.canonicalized != mode1.canonicalized) + assert(mode2.canonicalized == mode2.canonicalized) + assert(mode2.canonicalized != mode3.canonicalized) + assert(mode3.canonicalized == mode3.canonicalized) } test("BroadcastExchange same result") { @@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(!exchange1.sameResult(exchange2)) assert(!exchange2.sameResult(exchange3)) - assert(!exchange3.sameResult(exchange4)) + assert(exchange3.sameResult(exchange4)) assert(exchange4 sameResult exchange3) } @@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange1 sameResult exchange2) assert(!exchange2.sameResult(exchange3)) assert(!exchange3.sameResult(exchange4)) - assert(!exchange4.sameResult(exchange5)) + assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala new file mode 100644 index 0000000000000..00c5f2550cbb1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +object ExternalAppendOnlyUnsafeRowArrayBenchmark { + + def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = { + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an + // in-memory buffer of size `numSpillThreshold`. This will mimic that + val initialSize = + Math.min( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + + benchmark.addCase("ArrayBuffer") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ArrayBuffer[UnsafeRow](initialSize) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a + // copy of the row. 
This will mimic that + rows.foreach(x => array += x.copy()) + + var i = 0 + val n = array.length + while (i < n) { + sum = sum + array(i).getLong(0) + i += 1 + } + array.clear() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def testAgainstRawUnsafeExternalSorter( + numSpillThreshold: Int, + numRows: Int, + iterations: Int): Unit = { + + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows) + + benchmark.addCase("UnsafeExternalSorter") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numSpillThreshold, + false) + + rows.foreach(x => + array.insertRecord( + x.getBaseObject, + x.getBaseOffset, + x.getSizeInBytes, + 0, + false)) + + val unsafeRow = new UnsafeRow(1) + val iter = array.getIterator + while (iter.hasNext) { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + sum = sum + unsafeRow.getLong(0) + } + array.cleanupResources() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def main(args: Array[String]): Unit = { + + // ========================================================================================= // + // WITHOUT SPILL + // ========================================================================================= // + + val spillThreshold = 100 * 1000 + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 1000 rows: Best/Avg Time(ms) 
Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 7821 / 7941 33.5 29.8 1.0X + ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X + */ + testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 19200 / 19206 25.6 39.1 1.0X + ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 5949 / 6028 17.2 58.1 1.0X + ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10) + + // ========================================================================================= // + // WITH SPILL + // ========================================================================================= // + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X + ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X + */ + testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X + ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X + */ + testAgainstRawUnsafeExternalSorter( + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala new file mode 100644 index 0000000000000..53c41639942b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext { + private val random = new java.util.Random() + private var taskContext: TaskContext = _ + + override def afterAll(): Unit = TaskContext.unset() + + private def withExternalArray(spillThreshold: Int) + (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { + sc = new SparkContext("local", "test", new SparkConf(false)) + + taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + + val array = new ExternalAppendOnlyUnsafeRowArray( + taskContext.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + taskContext, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + spillThreshold) + try f(array) finally { + array.clear() + } + } + + private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = { + val valueInserted = random.nextLong() + + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, valueInserted) + array.add(row) + valueInserted + } + + private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = { + assert(iterator.hasNext) + val actualRow = iterator.next() + assert(actualRow.getLong(0) == expectedValue) + assert(actualRow.getSizeInBytes == 16) + } + + private def validateData( + array: ExternalAppendOnlyUnsafeRowArray, + expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = { + val iterator = array.generateIterator() + for (value <- expectedValues) { + checkIfValueExists(iterator, value) + } + + assert(!iterator.hasNext) + iterator + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int): ArrayBuffer[Long] = { + val populatedValues = new ArrayBuffer[Long] + populateRows(array, numRowsToBePopulated, populatedValues) + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int, + populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = { + for (_ <- 0 until numRowsToBePopulated) { + populatedValues.append(insertRow(array)) + } + populatedValues + } + + private def getNumBytesSpilled: Long = { + TaskContext.get().taskMetrics().memoryBytesSpilled + } + + private def assertNoSpill(): Unit = { + assert(getNumBytesSpilled == 0) + } + + private def assertSpill(): Unit = { + assert(getNumBytesSpilled > 0) + } + + test("insert rows less than the spillThreshold") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + assert(array.isEmpty) + + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, spillThreshold - 1, expectedValues) + assert(array.length == spillThreshold) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + } + + test("insert rows more than the spillThreshold to force spill") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + val numValuesInserted = 20 * spillThreshold + + 
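+      // Inserting many times the spill threshold should force the array to fall back from its
+      // in-memory buffer to the spilling sorter, so the task metrics are expected to report a
+      // non-zero number of spilled bytes (checked via assertSpill() below).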
assert(array.isEmpty) + val expectedValues = populateRows(array, 1) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Populate more rows to trigger spill. Verify that spill has happened + populateRows(array, numValuesInserted - 1, expectedValues) + assert(array.length == numValuesInserted) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("iterator on an empty array should be empty") { + withExternalArray(spillThreshold = 10) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + } + + test("generate iterator with negative start index") { + withExternalArray(spillThreshold = 2) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + } + + test("generate iterator with start index exceeding array's size (without spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with start index exceeding array's size (with spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold * 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with custom start index (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold) + val startIndex = spillThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("generate iterator with custom start index (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("test iterator invalidation (without spill)") { + withExternalArray(spillThreshold = 10) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + 
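+      // next() on the stale iterator should now fail fast with ConcurrentModificationException.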
intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("test iterator invalidation (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows so that spill has happens + populateRows(array, spillThreshold * 2) + assertSpill() + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear on an empty the array") { + withExternalArray(spillThreshold = 2) { array => + val iterator = array.generateIterator() + assert(!iterator.hasNext) + + // multiple clear'ing should not have an side-effect + array.clear() + array.clear() + array.clear() + assert(array.isEmpty) + assert(array.length == 0) + + // Clearing an empty array should also invalidate any old iterators + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear array (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate rows ... but not enough to trigger spill + populateRows(array, spillThreshold / 2) + assertNoSpill() + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate few rows so that there is no spill + // Verify the data. Verify that there was no spill + val expectedValues = populateRows(array, spillThreshold / 3) + validateData(array, expectedValues) + assertNoSpill() + + // Populate more rows .. enough to not trigger a spill. + // Verify the data. Verify that there was no spill + populateRows(array, spillThreshold / 3, expectedValues) + validateData(array, expectedValues) + assertNoSpill() + } + } + + test("clear array (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows to trigger spill + populateRows(array, spillThreshold * 2) + val bytesSpilled = getNumBytesSpilled + assert(bytesSpilled > 0) + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate the array ... but NOT upto the point that there is spill. + // Verify data. Verify that there was NO "extra" spill + val expectedValues = populateRows(array, spillThreshold / 2) + validateData(array, expectedValues) + assert(getNumBytesSpilled == bytesSpilled) + + // Populate more rows to trigger spill + // Verify the data. 
Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0bfc92fdb6218..4d155d538d637 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -242,15 +242,18 @@ class PlannerSuite extends SharedSQLContext { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) - assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { - case r: Repartition => - assert(r.numPartitions === 5) - assert(r.shuffle === false) + case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => + assert(numPartitions === 5) + assert(shuffle === false) + assert(shuffleChild === true) } } - // --- Unit tests of EnsureRequirements --------------------------------------------------------- + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Exchange + /////////////////////////////////////////////////////////////////////////// // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, // there two dimensions that need to be considered: are the child partitionings compatible and @@ -383,93 +386,6 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements adds sort when there is no existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, - requiredChildOrdering = Seq(Seq(orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA)), - 
requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is semantically equal to " + - "existing ordering") { - val exprId: ExprId = NamedExpression.newExprId - val attribute1 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId, - qualifier = Some("col1_qualifier") - ) - - val attribute2 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId) - - val orderingA1 = SortOrder(attribute1, Ascending) - val orderingA2 = SortOrder(attribute2, Ascending) - - assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") - assert(orderingA1.semanticEquals(orderingA2), - s"$orderingA1 should be semantically equal to $orderingA2") - - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA2)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - // This is a regression test for SPARK-11135 - test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA, orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -480,7 +396,7 @@ class PlannerSuite extends SharedSQLContext { children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), requiredChildOrdering = Seq(Seq.empty)), - None) + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) @@ -509,8 +425,6 @@ class PlannerSuite extends SharedSQLContext { } } - // --------------------------------------------------------------------------------------------- - test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -524,12 +438,12 @@ class PlannerSuite extends SharedSQLContext { None) val inputPlan = SortMergeJoinExec( - Literal(1) :: Nil, - Literal(1) :: Nil, - Inner, - None, - shuffle, - shuffle) + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + shuffle, + shuffle) val outputPlan = 
ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { @@ -556,6 +470,158 @@ class PlannerSuite extends SharedSQLContext { fail(s"Should have only two shuffles:\n$outputPlan") } } + + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Sort + /////////////////////////////////////////////////////////////////////////// + + private val exprA = Literal(1) + private val exprB = Literal(2) + private val exprC = Literal(3) + private val orderingA = SortOrder(exprA, Ascending) + private val orderingB = SortOrder(exprB, Ascending) + private val orderingC = SortOrder(exprC, Ascending) + private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: Nil, 5)) + private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + private val planC = DummySparkPlan(outputOrdering = Seq(orderingC), + outputPartitioning = HashPartitioning(exprC :: Nil, 5)) + + assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC) + + private def assertSortRequirementsAreSatisfied( + childPlan: SparkPlan, + requiredOrdering: Seq[SortOrder], + shouldHaveSort: Boolean): Unit = { + val inputPlan = DummySparkPlan( + children = childPlan :: Nil, + requiredChildOrdering = Seq(requiredOrdering), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (shouldHaveSort) { + if (outputPlan.collect { case s: SortExec => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } else { + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + } + + test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + + test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + + "child SMJ") { + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + + test("EnsureRequirements for sort operator after left outer sort merge join") { + // Only left key is sorted after left outer SMJ (thus doesn't need a sort). 
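+    // (Unmatched rows on the right side are null-filled by the outer join, so ordering on the
+    // right key is not preserved in the output and still requires an explicit sort.)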
+ val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = leftSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements for sort operator after right outer sort merge join") { + // Only right key is sorted after right outer SMJ (thus doesn't need a sort). + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = rightSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements adds sort after full outer sort merge join") { + // Neither keys is sorted after full outer SMJ, so they both need sorts. + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = fullSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = true) + } + } + + test("EnsureRequirements adds sort when there is no existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq.empty), + requiredOrdering = Seq(orderingB), + shouldHaveSort = true) + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)), + requiredOrdering = Seq(orderingA), + shouldHaveSort = false) + } + + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA1)), + requiredOrdering = Seq(orderingA2), + shouldHaveSort = false) + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA)), + requiredOrdering = Seq(orderingA, orderingB), + shouldHaveSort = true) + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 8bceab39f71d5..05637821f71f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql.execution +import java.util.Locale + +import scala.language.reflectiveCalls + 
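+// scala.language.reflectiveCalls is needed because the anonymous SparkStrategy in the test below
+// exposes a `mode` field that is accessed through a structural type.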
import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext @@ -24,11 +28,12 @@ class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { val badRule = new SparkStrategy { var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + mode.toLowerCase(Locale.ROOT) match { + case "exception" => throw new AnalysisException(mode) + case "error" => throw new Error(mode) + case _ => Nil + } } spark.experimental.extraStrategies = badRule :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2d95cb6d64a87..d32716c18ddfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -172,7 +172,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { var e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }.getMessage - assert(e.contains("Inserting into an RDD-based table is not allowed")) + assert(e.contains("Inserting into a view is not allowed. View: `default`.`testview`")) val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/employee.dat") @@ -609,12 +609,64 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } - // TODO: Check for cyclic view references on ALTER VIEW. - ignore("correctly handle a cyclic view reference") { - withView("view1", "view2") { + test("correctly handle a cyclic view reference") { + withView("view1", "view2", "view3") { sql("CREATE VIEW view1 AS SELECT * FROM jt") sql("CREATE VIEW view2 AS SELECT * FROM view1") - intercept[AnalysisException](sql("ALTER VIEW view1 AS SELECT * FROM view2")) + sql("CREATE VIEW view3 AS SELECT * FROM view2") + + // Detect cyclic view reference on ALTER VIEW. + val e1 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e1.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect the most left cycle when there exists multiple cyclic view references. + val e2 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2") + }.getMessage + assert(e2.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view3` -> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val e3 = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e3.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference from subqueries. 
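+      // (View references are resolved inside subquery expressions as well, so referencing view2
+      // from the EXISTS predicate should report the same recursive-view cycle.)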
+ val e4 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)") + }.getMessage + assert(e4.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + } + } + + test("restrict the nested level of a view") { + val viewNames = Array.range(0, 11).map(idx => s"view$idx") + withView(viewNames: _*) { + sql("CREATE VIEW view0 AS SELECT * FROM jt") + Array.range(0, 10).foreach { idx => + sql(s"CREATE VIEW view${idx + 1} AS SELECT * FROM view$idx") + } + + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "10") { + val e = intercept[AnalysisException] { + sql("SELECT * FROM view10") + }.getMessage + assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + + "resolution depth (10). Analysis is aborted to avoid errors. Increase the value " + + "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf("spark.sql.view.maxNestedViewDepth" -> "0") {} + }.getMessage + assert(e.contains("The maximum depth of a view reference in a nested view must be " + + "positive.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index afd47897ed4b2..52e4f047225de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.TestUtils.assertSpilled case class WindowData(month: Int, area: String, product: Int) @@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(1, 3, null) :: Row(2, null, 4) :: Nil) } + + test("test with low buffer spill threshold") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + + val expected = + Row(1, 1, 1) :: + Row(0, 2, 3) :: + Row(1, 3, 6) :: + Row(0, 4, 10) :: + Row(1, 5, 15) :: + Row(0, 6, 21) :: + Row(1, 7, 28) :: + Row(0, 8, 36) :: + Row(1, 9, 45) :: + Row(0, 10, 55) :: Nil + + val actual = sql( + """ + |SELECT y, x, sum(x) OVER w1 AS running_sum + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) + """.stripMargin) + + withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + assertSpilled(sparkContext, "test with low buffer spill threshold") { + checkAnswer(actual, expected) + } + } + + spark.catalog.dropTempView("nums") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bb6c486e880a0..908b955abbf07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -45,7 +45,7 @@ class SparkSqlParserSuite extends PlanTest { * Normalizes plans: * - CreateTable the createTime in tableDesc will replaced by -1L. 
*/ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + override def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan match { case CreateTable(tableDesc, mode, query) => val newTableDesc = tableDesc.copy(createTime = -1L) @@ -210,16 +210,27 @@ class SparkSqlParserSuite extends PlanTest { "no viable alternative at input") } + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table extended t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true, isFormatted = false)) + TableIdentifier("t"), Map.empty, isExtended = true)) assertEqual("describe table formatted t", DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false, isFormatted = true)) + TableIdentifier("t"), Map.empty, isExtended = true)) intercept("explain describe tables x", "Unsupported SQL statement") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4d9203556d49e..a4b30a2f8cec1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -116,34 +116,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } - test("generate should be included in WholeStageCodegen") { - import org.apache.spark.sql.functions._ - val ds = spark.range(2).select( - col("id"), - explode(array(col("id") + 1, col("id") + 2)).as("value")) - val plan = ds.queryExecution.executedPlan - assert(plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined) - assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3))) - } - - test("large stack generator should not use WholeStageCodegen") { - def createStackGenerator(rows: Int): SparkPlan = { - val id = UnresolvedAttribute("id") - val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i)))) - spark.range(500).select(Column(stack)).queryExecution.executedPlan - } - val isCodeGenerated: SparkPlan => Boolean = { - case WholeStageCodegenExec(_: GenerateExec) => true - case _ => false - } - - // Only 'stack' generators that produce 50 rows or less are code generated. 
- assert(createStackGenerator(50).find(isCodeGenerated).isDefined) - assert(createStackGenerator(100).find(isCodeGenerated).isEmpty) - } - test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a2993bdf4b28..8a798fb444696 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index f355a5200ce2f..109b1d9db60d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -234,8 +237,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with 
SharedSQLContext { Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct) + DateType, TimestampType, ArrayType(IntegerType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, true) } @@ -244,10 +246,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create an RDD for the schema val rdd = - sparkContext.parallelize((1 to 10000), 10).map { i => + sparkContext.parallelize(1 to 10000, 10).map { i => Row( - s"str${i}: test cache.", - s"binary${i}: test cache.".getBytes(StandardCharsets.UTF_8), + s"str$i: test cache.", + s"binary$i: test cache.".getBytes(StandardCharsets.UTF_8), null, i % 2 == 0, i.toByte, @@ -255,13 +257,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { i, Long.MaxValue - i.toLong, (i + 0.25).toFloat, - (i + 0.75), + i + 0.75, BigDecimal(Long.MaxValue.toString + ".12345"), new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), new Timestamp(i * 1000000L), - (i to i + 10).toSeq, - (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + i to i + 10, Row((i - 0.25).toFloat, Seq(true, false, null))) } spark.createDataFrame(rdd, schema).createOrReplaceTempView("InMemoryCache_different_data_types") @@ -390,4 +391,42 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } } + test("InMemoryTableScanExec should return correct output ordering and partitioning") { + val df1 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + val df2 = Seq((0, 0), (1, 1)).toDF + .repartition(col("_1")).sortWithinPartitions(col("_1")).persist + + // Because two cached dataframes have the same logical plan, this is a self-join actually. + // So we force one of in-memory relation to alias its output. Then we can test if original and + // aliased in-memory relations have correct ordering and partitioning. 
+ val joined = df1.joinWith(df2, df1("_1") === df2("_1")) + + val inMemoryScans = joined.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + inMemoryScans.foreach { inMemoryScan => + val sortedAttrs = AttributeSet(inMemoryScan.outputOrdering.flatMap(_.references)) + assert(sortedAttrs.subsetOf(inMemoryScan.outputSet)) + + val partitionedAttrs = + inMemoryScan.outputPartitioning.asInstanceOf[HashPartitioning].references + assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) + } + } + + test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { + withSQLConf("spark.sql.shuffle.partitions" -> "200") { + val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") + val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") + val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() + + df3.unpersist() + val agg_without_cache = df3.groupBy($"item").count() + + df3.cache() + val agg_with_cache = df3.groupBy($"item").count() + checkAnswer(agg_without_cache, agg_with_cache) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 76bb9e5929a71..8a6bc62fec96c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.command +import java.net.URI +import java.util.Locale + import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.catalyst.TableIdentifier @@ -38,8 +41,10 @@ class DDLCommandSuite extends PlanTest { val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) - containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) } + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } } private def parseAs[T: ClassTag](query: String): T = { @@ -317,7 +322,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE EXTERNAL TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create hive table - property values must be set") { @@ -334,7 +339,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create table - with partitioned by") { @@ -409,7 +414,7 @@ class DDLCommandSuite extends PlanTest { val expectedTableDesc = CatalogTable( identifier = TableIdentifier("my_tab"), tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy(locationUri = Some("/tmp/file")), + storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), schema = new StructType().add("a", IntegerType).add("b", 
StringType), provider = Some("parquet")) @@ -525,13 +530,13 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val sql4 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', |'field.delim' = ',') """.stripMargin val sql5 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed1 = parser.parsePlan(sql1) @@ -553,12 +558,12 @@ class DDLCommandSuite extends PlanTest { tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -778,13 +783,7 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES") } - test("alter table: add/replace columns (not allowed)") { - assertUnsupported( - """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |ADD COLUMNS (new_col1 INT COMMENT 'test_comment', new_col2 LONG - |COMMENT 'test_comment2') CASCADE - """.stripMargin) + test("alter table: replace columns (not allowed)") { assertUnsupported( """ |ALTER TABLE table_name REPLACE COLUMNS (new_col1 INT @@ -833,6 +832,14 @@ class DDLCommandSuite extends PlanTest { assert(e.contains("Found duplicate keys 'a'")) } + test("empty values in non-optional partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)") + }.getMessage + assert(e.contains("Found an empty partition key 'b'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 8b8cd0fdf4db2..0abcff76060f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,36 +19,128 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogStorageFormat} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchPartitionException, 
NoSuchTableException, TempTableAlreadyExistsException} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { - private val escapedIdentifier = "`(.+)`".r +class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test spark.sessionState.catalog.reset() } finally { - Utils.deleteRecursively(new File("spark-warehouse")) + Utils.deleteRecursively(new File(spark.sessionState.conf.warehousePath)) super.afterEach() } } + protected override def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { + val storage = + CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL, + storage = storage, + schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") + .add("a", "int") + .add("b", "int"), + provider = Some("parquet"), + partitionColumnNames = Seq("a", "b"), + createTime = 0L, + tracksPartitionsInCatalog = true) + } + + test("create a managed Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql(s"CREATE TABLE $tabName (i INT, j STRING)") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE")) + } + } + + test("create an external Hive source table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + withTempDir { tempDir => + val tabName = "tbl" + withTable(tabName) { + val e = intercept[AnalysisException] { + sql( + s""" + |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |LOCATION '${tempDir.toURI}' + """.stripMargin) + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE")) + } + } + } + + test("Create Hive Table As Select") { + import testImplicits._ + withTable("t", "t1") { + var e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT 1 as a, 1 as b") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + + spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + e = intercept[AnalysisException] { + sql("CREATE TABLE t SELECT a, b from t1") + }.getMessage + assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) + } + } + +} + +abstract class DDLSuite extends QueryTest with SQLTestUtils { + + protected def isUsingHiveMetastore: Boolean = { + spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" + } + + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable + + private val escapedIdentifier = "`(.+)`".r + + protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = 
table + + private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { + props.filterNot(p => Seq("serialization.format", "path").contains(p._1)) + } + + private def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { + assert(normalizeCatalogTable(actual) == normalizeCatalogTable(expected)) + } + /** * Strip backticks, if any, from the string. */ @@ -63,7 +155,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql(query) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { @@ -72,39 +164,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + CatalogDatabase( + name, "", CatalogUtils.stringToURI(spark.sessionState.conf.warehousePath), Map()), ignoreIfExists = false) } - private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { - val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = None, - outputFormat = None, - serde = None, - compressed = false, - properties = Map()) - val metadata = new MetadataBuilder() - .putString("key", "value") - .build() - CatalogTable( - identifier = name, - tableType = CatalogTableType.EXTERNAL, - storage = storage, - schema = new StructType() - .add("col1", "int", nullable = true, metadata = metadata) - .add("col2", "string") - .add("a", "int") - .add("b", "int"), - provider = Some("parquet"), - partitionColumnNames = Seq("a", "b"), - createTime = 0L, - tracksPartitionsInCatalog = true) - } - - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -116,6 +185,51 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } + private def getDBPath(dbName: String): URI = { + val warehousePath = makeQualifiedPath(spark.sessionState.conf.warehousePath) + new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource 
table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -133,24 +247,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def makeQualifiedPath(path: String): String = { - // copy-paste from SessionCatalog - val hadoopPath = new Path(path) - val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) - fs.makeQualified(hadoopPath).toString - } - test("Create Database using Default Warehouse Path") { val catalog = spark.sessionState.catalog val dbName = "db1" try { sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbName) - val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db") assert(db1 == CatalogDatabase( dbName, "", - expectedLocation, + getDBPath(dbName), Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbName)) @@ -193,16 +299,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db") assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", - expectedLocation, + getDBPath(dbNameWithoutBackTicks), Map.empty)) - intercept[DatabaseAlreadyExistsException] { + // TODO: HiveExternalCatalog should throw DatabaseAlreadyExistsException + val e = intercept[AnalysisException] { sql(s"CREATE DATABASE $dbName") - } + }.getMessage + assert(e.contains(s"already exists")) } finally { catalog.reset() } @@ -421,19 +528,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("desc table for parquet data source table using in-memory catalog") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tab1" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") - - checkAnswer( - sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") - ) - } - } - test("Alter/Describe Database") { val catalog = spark.sessionState.catalog val databaseNames = Seq("db1", "`database`") @@ -441,7 +535,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db") + val location = getDBPath(dbNameWithoutBackTicks) sql(s"CREATE DATABASE $dbName") @@ -449,7 +543,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") @@ -458,7 +552,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - 
Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -467,7 +561,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) } finally { catalog.reset() @@ -485,7 +579,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + // TODO: Unify the exception. + if (isUsingHiveMetastore) { + assert(message.contains(s"NoSuchObjectException: $dbNameWithoutBackTicks")) + } else { + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) + } message = intercept[AnalysisException] { sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -514,7 +613,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName RESTRICT") }.getMessage - assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) + assert(message.contains(s"Database $dbName is not empty. One or more tables exist")) + catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false) @@ -545,7 +645,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) val expectedTable = generateTable(catalog, expectedTableIdent) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table in a specific db") { @@ -554,7 +654,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val tableIdent1 = TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) val expectedTable = generateTable(catalog, tableIdent1) - assert(catalog.getTableMetadata(tableIdent1) === expectedTable) + checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1)) } test("create table using") { @@ -595,21 +695,28 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("create temporary view using") { - val csvFile = - Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString - withView("testview") { - sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + - "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$csvFile')") + // when we test the HiveCatalogedDDLSuite, it will failed because the csvFile path above + // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it + // to a temp file by withResourceTempPath + withResourceTempPath("test-data/cars.csv") { tmpFile => + withView("testview") { + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + + s"OPTIONS (PATH '$tmpFile')") - checkAnswer( - sql("select c1, c2 from testview order by c1 limit 1"), + checkAnswer( + sql("select c1, c2 
from testview order by c1 limit 1"), Row("1997", "Ford") :: Nil) - // Fails if creating a new view with the same name - intercept[TempTableAlreadyExistsException] { - sql(s"CREATE TEMPORARY VIEW testview USING " + - s"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat OPTIONS (PATH '$csvFile')") + // Fails if creating a new view with the same name + intercept[TempTableAlreadyExistsException] { + sql( + s""" + |CREATE TEMPORARY VIEW testview + |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat + |OPTIONS (PATH '$tmpFile') + """.stripMargin) + } } } } @@ -735,56 +842,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -809,14 +866,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -829,7 +878,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testRecoverPartitions() { + protected def testRecoverPartitions() { val catalog = spark.sessionState.catalog // table to alter does not exist intercept[AnalysisException] { @@ -868,8 +917,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("ALTER TABLE tab1 RECOVER PARTITIONS") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) - assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") - assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + if (!isUsingHiveMetastore) { + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } else { + // 
After ALTER TABLE, the statistics of the first partition are removed by the Hive metastore + assert(catalog.getPartition(tableIdent, part1).parameters.get("numFiles").isEmpty) + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } } finally { fs.delete(root, true) } @@ -879,59 +934,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("show table extended") { - withTempView("show1a", "show2b") { - sql( - """ - |CREATE TEMPORARY VIEW show1a - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - | - |) - """.stripMargin) - sql( - """ - |CREATE TEMPORARY VIEW show2b - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").count() >= 2) - assert( - sql("SHOW TABLE EXTENDED LIKE 'show*'").schema == - StructType(StructField("database", StringType, false) :: - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: - StructField("information", StringType, false) :: Nil)) - } - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -975,22 +981,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - - private def testDropTable(isDatasourceTable: Boolean): Unit = { + protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1014,23 +1012,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { e.getMessage.contains("Cannot drop a table with DROP VIEW.
Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - } - - private def testSetProperties(isDatasourceTable: Boolean): Unit = { + protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } assert(getProps.isEmpty) // set table properties @@ -1046,16 +1041,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { - catalog.getTableMetadata(tableIdent).properties + if (isUsingHiveMetastore) { + normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties + } else { + catalog.getTableMetadata(tableIdent).properties + } } // unset table properties sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')") @@ -1079,49 +1078,46 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(getProps == Map("x" -> "y")) } - private def testSetLocation(isDatasourceTable: Boolean): Unit = { + protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) + assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) + assert( + normalizeSerdeProp(catalog.getPartition(tableIdent, partSpec).storage.properties).isEmpty) + // Verify that the location is set to the expected string - def verifyLocation(expected: String, spec: 
Option[TablePartitionSpec] = None): Unit = { + def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { catalog.getTableMetadata(tableIdent).storage } - if (isDatasourceTable) { - if (spec.isDefined) { - assert(storageFormat.properties.isEmpty) - assert(storageFormat.locationUri === Some(expected)) - } else { - assert(storageFormat.locationUri === Some(expected)) - } - } else { - assert(storageFormat.locationUri === Some(expected)) - } + // TODO(gatorsmile): fix the bug in alter table set location. + // if (isUsingHiveMetastore) { + // assert(storageFormat.properties.get("path") === expected) + // } + assert(storageFormat.locationUri === Some(expected)) } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") - verifyLocation("/path/to/your/lovely/heart") + verifyLocation(new URI("/path/to/your/lovely/heart")) // set table partition location sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - verifyLocation("/path/to/part/ways", Some(partSpec)) + verifyLocation(new URI("/path/to/part/ways"), Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") - verifyLocation("/swanky/steak/place") + verifyLocation(new URI("/swanky/steak/place")) // set table partition location without explicitly specifying database sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - verifyLocation("vienna", Some(partSpec)) + verifyLocation(new URI("vienna"), Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") @@ -1132,16 +1128,33 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testSetSerde(isDatasourceTable: Boolean): Unit = { + protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) + def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } } - assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) + if (isUsingHiveMetastore) { + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) + } else { + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) + } + checkSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -1154,45 +1167,61 @@ class DDLSuite extends QueryTest with 
SharedSQLContext with BeforeAndAfterEach { assert(e1.getMessage.contains("datasource")) assert(e2.getMessage.contains("datasource")) } else { - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty) - sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + val newSerde = "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$newSerde'") + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(newSerde)) + checkSerdeProps(Map.empty[String, String]) + val serde2 = "org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe" + sql(s"ALTER TABLE dbx.tab1 SET SERDE '$serde2' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(serde2)) + checkSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getTableMetadata(tableIdent).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") } } - private def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) + def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { + val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties + if (isUsingHiveMetastore) { + assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps) + } else { + assert(serdeProp == expectedSerdeProps) + } } - assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + if (isUsingHiveMetastore) { + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) + } else { + 
assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) + } + checkPartitionSerdeProps(Map.empty[String, String]) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -1207,26 +1236,23 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } else { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.jadoop'") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.jadoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty) + checkPartitionSerdeProps(Map.empty[String, String]) sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.madoop")) - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "v", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "v", "kay" -> "vee")) } // set serde properties only maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) " + "SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "vee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "vee")) } // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { sql("ALTER TABLE tab1 PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getPartition(tableIdent, spec).storage.properties == - Map("k" -> "vvv", "kay" -> "veee")) + checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "veee")) } // table to alter does not exist intercept[AnalysisException] { @@ -1234,7 +1260,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1243,11 +1272,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1255,7 +1281,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) + val partitionLocation = if (isUsingHiveMetastore) { + val tableLocation = catalog.getTableMetadata(tableIdent).storage.locationUri + assert(tableLocation.isDefined) + makeQualifiedPath(new Path(tableLocation.get.toString, 
"paris").toString) + } else { + new URI("paris") + } + + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(partitionLocation)) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database @@ -1285,7 +1319,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(part1, part2, part3, part4, part5)) } - private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1294,7 +1331,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1302,9 +1339,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1338,21 +1372,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).isEmpty) } - private def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1382,15 +1416,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } - private def testChangeColumn(isDatasourceTable: Boolean): Unit = { + protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - 
createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1482,35 +1516,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { ) } - test("create a managed Hive source table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - val tabName = "tbl" - withTable(tabName) { - val e = intercept[AnalysisException] { - sql(s"CREATE TABLE $tabName (i INT, j STRING)") - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE")) - } - } - - test("create an external Hive source table") { - assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") - withTempDir { tempDir => - val tabName = "tbl" - withTable(tabName) { - val e = intercept[AnalysisException] { - sql( - s""" - |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |LOCATION '${tempDir.toURI}' - """.stripMargin) - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE")) - } - } - } - test("create a data source table without schema") { import testImplicits._ withTempPath { tempDir => @@ -1549,22 +1554,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("Create Hive Table As Select") { - import testImplicits._ - withTable("t", "t1") { - var e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT 1 as a, 1 as b") - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) - - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") - e = intercept[AnalysisException] { - sql("CREATE TABLE t SELECT a, b from t1") - }.getMessage - assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) - } - } - test("Create Data Source Table As Select") { import testImplicits._ withTable("t", "t1", "t2") { @@ -1578,17 +1567,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { - Seq("true", "false").foreach { caseSensitive => + val caseSensitiveOptions = if (isUsingHiveMetastore) Seq("false") else Seq("true", "false") + caseSensitiveOptions.foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { var message = intercept[AnalysisException] { sql("DROP DATABASE default") @@ -1790,7 +1782,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { withTable(tabName) { sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") val message = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase}") + sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") }.getMessage assert(message.contains("SHOW COLUMNS with conflicting databases")) } @@ -1813,27 +1805,30 @@ class 
DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(new URI(defaultTablePath)).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. + assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } - test("insert data to a data source table which has a not existed location should succeed") { + test("insert data to a data source table which has a non-existing location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1843,28 +1838,27 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) dir.delete - val tableLocFile = new File(table.location) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) Utils.deleteRecursively(dir) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT OVERWRITE TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.getAbsolutePath spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table1.location == newDir) + assert(table1.location == new URI(newDir)) assert(!newDirFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") @@ -1874,7 +1868,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("insert into a data source table with no existed partition location should succeed") { + test("insert into a data source table with a non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1885,7 +1879,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |LOCATION "$dir" """.stripMargin) val table = 
spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1901,7 +1895,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("read data from a data source table which has a not existed location should succeed") { + test("read data from a data source table which has a non-existing location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1911,13 +1905,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) dir.delete() checkAnswer(spark.table("t"), Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.toURI spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1928,7 +1923,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - test("read data from a data source table with no existed partition location should succeed") { + test("read data from a data source table with non-existing partition location should succeed") { withTable("t") { withTempDir { dir => spark.sql( @@ -1950,48 +1945,341 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + test("create datasource table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" + val tcName = if (shouldDelete) "non-existing" else "existed" test(s"CTAS for external data source table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING parquet - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) + withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + 
withTempDir { dir => + if (shouldDelete) dir.delete() + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"data source table:partition column name containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING parquet + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for datasource table") { + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + withDatabase ("tmpdb") { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + spark.sql("USE tmpdb") + + import testImplicits._ + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == 
makeQualifiedPath(tblloc.getAbsolutePath)) + assert(tblloc.listFiles().nonEmpty) + } + } + } + } + } + + test("the qualified path of a datasource table is stored in the catalog") { + withTable("t", "t1") { + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$dir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location.toString.startsWith("file:/")) + } + + withTempDir { dir => + assert(!dir.getAbsolutePath.startsWith("file:/")) + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$dir' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location.toString.startsWith("file:/")) + } + } + } + + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + } + + supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => + test(s"alter datasource table add columns - partitioned - $provider") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + } + + test("alter datasource table add columns - text format not supported") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING text") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + }.getMessage + assert(e.contains("ALTER ADD COLUMNS does not support datasource table with type")) + } + } - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + test("alter table add columns -- not support temp view") { + withTempView("tmp_v") { + sql("CREATE TEMPORARY VIEW tmp_v AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE tmp_v ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns -- not support view") { + withView("v1") { + sql("CREATE VIEW v1 AS SELECT 1 AS c1, 2 AS c2") + val e = intercept[AnalysisException] { + sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") + } + assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + } + } + + test("alter table add columns with existing column name") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (c1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } + 
} + + Seq(true, false).foreach { caseSensitive => + test(s"alter table add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("t1") { + sql("CREATE TABLE t1 (c1 int) USING PARQUET") + if (!caseSensitive) { + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("Found duplicate column(s)")) + } else { + if (isUsingHiveMetastore) { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e = intercept[AnalysisException] { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + }.getMessage + assert(e.contains("HiveException")) + } else { + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema + .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) + } + } } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() + } + } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) } - spark.sql( - s""" - |CREATE TABLE t1 - |USING parquet - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == dir.getAbsolutePath) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 7ea4064927576..b4616826e40b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,8 +27,10 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} class FileIndexSuite extends SharedSQLContext { @@ -135,15 +137,15 @@ class FileIndexSuite extends SharedSQLContext { } } - test("PartitioningAwareFileIndex - file filtering") { - assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd")) - assert(PartitioningAwareFileIndex.shouldFilterOut(".ab")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata")) - assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata")) - assert(PartitioningAwareFileIndex.shouldFilterOut("a._COPYING_")) + test("InMemoryFileIndex - file filtering") { + 
assert(!InMemoryFileIndex.shouldFilterOut("abcd")) + assert(InMemoryFileIndex.shouldFilterOut(".ab")) + assert(InMemoryFileIndex.shouldFilterOut("_cd")) + assert(!InMemoryFileIndex.shouldFilterOut("_metadata")) + assert(!InMemoryFileIndex.shouldFilterOut("_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_ab_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("_cd_common_metadata")) + assert(InMemoryFileIndex.shouldFilterOut("a._COPYING_")) } test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") { @@ -220,6 +222,32 @@ class FileIndexSuite extends SharedSQLContext { assert(catalog.leafDirPaths.head == fs.makeQualified(dirPath)) } } + + test("SPARK-20280 - FileStatusCache with a partition with very many files") { + /* fake the size, otherwise we need to allocate 2GB of data to trigger this bug */ + class MyFileStatus extends FileStatus with KnownSizeEstimation { + override def estimatedSize: Long = 1000 * 1000 * 1000 + } + /* files * MyFileStatus.estimatedSize should overflow to negative integer + * so, make it between 2bn and 4bn + */ + val files = (1 to 3).map { i => + new MyFileStatus() + } + val fileStatusCache = FileStatusCache.getOrCreate(spark) + fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) + } + + test("SPARK-20367 - properly unescape column names in inferPartitioning") { + withTempPath { path => + val colToUnescape = "Column/#%'?" + spark + .range(1) + .select(col("id").as(colToUnescape), col("id")) + .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index ba3f62c562c3e..9dcf9437a7349 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ - protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") + protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") test("unpartitioned table, single partition") { val table = @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 56071803f685f..352dba79a4c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -129,6 +130,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkTypes = true) } + test("simple csv test with string dataset") { + val csvDataset = spark.read.text(testFile(carsFile)).as[String] + val cars = spark.read + .option("header", "true") + .option("inferSchema", "true") + .csv(csvDataset) + + verifyCars(cars, withHeader = true, checkTypes = true) + + val carsWithoutHeader = spark.read + .option("header", "false") + .csv(csvDataset) + + verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false) + } + test("test inferring booleans") { val result = spark.read .format("csv") @@ -276,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .load(testFile(carsFile)).collect() } - assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt")) + assert(exception.getMessage.contains("Malformed CSV record")) } } @@ -749,7 +766,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .load(iso8601timestampsPath) - val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSXXX", Locale.US) val expectedTimestamps = timestamps.collect().map { r => // This should be ISO8601 formatted string. Row(iso8501.format(r.toSeq.head)) @@ -896,7 +913,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .format("csv") .option("header", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. @@ -918,7 +935,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .option("inferSchema", "true") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .load(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) @@ -975,9 +992,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { Seq(false, true).foreach { wholeFile => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) + // We use `PERMISSIVE` mode by default if invalid string is given. 
val df1 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "abcd") .option("wholeFile", wholeFile) .schema(schema) .csv(testFile(valueMalformedFile)) @@ -991,7 +1009,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val schemaWithCorrField1 = schema.add(columnNameOfCorruptRecord, StringType) val df2 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField1) @@ -1008,7 +1026,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .add("b", TimestampType) val df3 = spark .read - .option("mode", "PERMISSIVE") + .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .option("wholeFile", wholeFile) .schema(schemaWithCorrField2) @@ -1077,17 +1095,83 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe with empty schema - wholeFile option") { - withTempPath { path => - path.createNewFile() - + test("Empty file produces empty dataframe with empty schema") { + Seq(false, true).foreach { wholeFile => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", true) - .load(path.getAbsolutePath) + .option("wholeFile", wholeFile) + .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) checkAnswer(df, spark.emptyDataFrame) } } + + test("Empty string dataset produces empty dataframe and keep user-defined schema") { + val df1 = spark.read.csv(spark.emptyDataset[String]) + assert(df1.schema === spark.emptyDataFrame.schema) + checkAnswer(df1, spark.emptyDataFrame) + + val schema = StructType(StructField("a", StringType) :: Nil) + val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String]) + assert(df2.schema === schema) + } + + test("ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - read") { + val input = " a,b , c " + + // For reading, default of both `ignoreLeadingWhiteSpace` and`ignoreTrailingWhiteSpace` + // are `false`. So, these are excluded. + val combinations = Seq( + (true, true), + (false, true), + (true, false)) + + // Check if read rows ignore whitespaces as configured. + val expectedRows = Seq( + Row("a", "b", "c"), + Row(" a", "b", " c"), + Row("a", "b ", "c ")) + + combinations.zip(expectedRows) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + val df = spark.read + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(Seq(input).toDS()) + + checkAnswer(df, expected) + } + } + + test("SPARK-18579: ignoreLeadingWhiteSpace and ignoreTrailingWhiteSpace options - write") { + val df = Seq((" a", "b ", " c ")).toDF() + + // For writing, default of both `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` + // are `true`. So, these are excluded. + val combinations = Seq( + (false, false), + (false, true), + (true, false)) + + // Check if written lines ignore each whitespaces as configured. + val expectedLines = Seq( + " a,b , c ", + " a,b, c", + "a,b ,c ") + + combinations.zip(expectedLines) + .foreach { case ((ignoreLeadingWhiteSpace, ignoreTrailingWhiteSpace), expected) => + withTempPath { path => + df.write + .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace) + .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace) + .csv(path.getAbsolutePath) + + // Read back the written lines. 
+ val readBack = spark.read.text(path.getAbsolutePath) + checkAnswer(readBack, Row(expected)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 5961bd0a733df..782c34d40fb82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1042,9 +1042,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read .option("mode", "FAILFAST") .json(corruptRecords) - .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionOne.getMessage.contains("JsonParseException")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1053,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(corruptRecords) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + assert(exceptionTwo.getMessage.contains("JsonParseException")) } test("Corrupt records: DROPMALFORMED mode") { @@ -1083,84 +1082,72 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(jsonDFTwo.schema === schemaTwo) } + test("SPARK-19641: Additional corrupt records: DROPMALFORMED mode") { + val schema = new StructType().add("dummy", StringType) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDF = spark.read + .option("mode", "DROPMALFORMED") + .json(additionalCorruptRecords) + checkAnswer( + jsonDF, + Row("test")) + assert(jsonDF.schema === schema) + } + test("Corrupt records: PERMISSIVE mode, without designated column for malformed records") { - withTempView("jsonTable") { - val schema = StructType( - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) + val schema = StructType( + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) - val jsonDF = spark.read.schema(schema).json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") + val jsonDF = spark.read.schema(schema).json(corruptRecords) - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - """.stripMargin), - Seq( - // Corrupted records are replaced with null - Row(null, null, null), - Row(null, null, null), - Row(null, null, null), - Row("str_a_4", "str_b_4", "str_c_4"), - Row(null, null, null)) - ) - } + checkAnswer( + jsonDF.select($"a", $"b", $"c"), + Seq( + // Corrupted records are replaced with null + Row(null, null, null), + Row(null, null, null), + Row(null, null, null), + Row("str_a_4", "str_b_4", "str_c_4"), + Row(null, null, null)) + ) } test("Corrupt records: PERMISSIVE mode, with designated column for malformed records") { // Test if we can query corrupt records. 
withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { - withTempView("jsonTable") { - val jsonDF = spark.read.json(corruptRecords) - jsonDF.createOrReplaceTempView("jsonTable") - val schema = StructType( - StructField("_unparsed", StringType, true) :: + val jsonDF = spark.read.json(corruptRecords) + val schema = StructType( + StructField("_unparsed", StringType, true) :: StructField("a", StringType, true) :: StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - } + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + jsonDF.select($"a", $"b", $"c", $"_unparsed"), + Row(null, null, null, "{") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNull).select($"a", $"b", $"c"), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + jsonDF.filter($"_unparsed".isNotNull).select($"_unparsed"), + Row("{") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) } } @@ -1779,7 +1766,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { timestampsWithFormat.write .format("json") .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .save(timestampsWithFormatPath) // This will load back the timestamps as string. 
@@ -1797,7 +1784,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val readBack = spark.read .schema(customSchema) .option("timestampFormat", "yyyy/MM/dd HH:mm") - .option("timeZone", "GMT") + .option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .json(timestampsWithFormatPath) checkAnswer(readBack, timestampsWithFormat) @@ -1918,6 +1905,24 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-19641: Handle multi-line corrupt documents (DROPMALFORMED)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + checkAnswer(jsonDF, Seq(Row("test"))) + } + } + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { withTempPath { dir => val path = dir.getCanonicalPath @@ -1939,9 +1944,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .option("wholeFile", true) .option("mode", "FAILFAST") .json(path) - .collect() } - assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionOne.getMessage.contains("Failed to infer a common schema")) val exceptionTwo = intercept[SparkException] { spark.read @@ -1951,7 +1955,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .json(path) .collect() } - assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + assert(exceptionTwo.getMessage.contains("Failed to parse a value")) } } @@ -1964,19 +1968,20 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("c", StringType, true) :: Nil) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(corruptRecords) }.getMessage assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) + // We use `PERMISSIVE` mode by default if invalid string is given. withTempPath { dir => val path = dir.getCanonicalPath corruptRecords.toDF("value").write.text(path) val errMsg = intercept[AnalysisException] { spark.read - .option("mode", "PERMISSIVE") + .option("mode", "permm") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schema) .json(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index e9dfb143735af..efa319b436b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.{AccumulatorContext, LongAccumulator} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. 
@@ -505,19 +505,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val path = s"${dir.getCanonicalPath}/table" (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) - Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)) + Seq(("true", (x: Integer) => x == 0), ("false", (x: Integer) => x > 0)) .foreach { case (push, func) => withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { - val accu = new LongAccumulator - accu.register(sparkContext, Some("numRowGroups")) + val accu = new NumRowGroupsAcc + sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition(_.foreach(v => accu.add(0))) df.collect - val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") - assert(numRowGroups.isDefined) - assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + assert(func(accu.value)) AccumulatorContext.remove(accu.id) } } @@ -593,3 +591,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + +class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { + private var _sum = 0 + + override def isZero: Boolean = _sum == 0 + + override def copy(): AccumulatorV2[Integer, Integer] = { + val acc = new NumRowGroupsAcc() + acc._sum = _sum + acc + } + + override def reset(): Unit = _sum = 0 + + override def add(v: Integer): Unit = _sum += v + + override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match { + case a: NumRowGroupsAcc => _sum += a._sum + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + override def value: Integer = _sum +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index b492e30a72acf..6c814d2d6875d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.util.Locale + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -107,11 +109,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { | required binary g(ENUM); | required binary h(DECIMAL(32,0)); | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + | required int64 j(TIMESTAMP_MILLIS); |} """.stripMargin) val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), - DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0), + TimestampType) withTempPath { location => val path = new Path(location.getCanonicalPath) @@ -298,7 +302,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase(Locale.ROOT)) { compressionCodecFor(path, codec.name()) } } @@ -607,6 +611,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest with 
SharedSQLContext { } } + test("read dictionary and plain encoded timestamp_millis written as INT64") { + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + checkAnswer( + // timestamp column in this file is encoded using combination of plain + // and dictionary encodings. + readResourceParquetFile("test-data/timemillis-in-i64.parquet"), + (1 to 3).map(i => Row(new java.sql.Timestamp(10)))) + } + } + } + test("SPARK-12589 copy() on rows returned from reader works for strings") { withTempPath { dir => val data = (1, "abc") ::(2, "helloabcde") :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 2e36f2c9528f2..624d4d234d59d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -32,7 +32,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition, _} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -473,7 +475,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha assert(partDf.schema.map(_.name) === Seq("intField", "stringField")) path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file val df = spark.read.parquet(f.getCanonicalPath) assert(df.schema.map(_.name) === Seq("intField", "stringField")) @@ -481,7 +484,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } path.listFiles().foreach { f => - if (!f.getName.startsWith("_") && f.getName.toLowerCase().endsWith(".parquet")) { + if (!f.getName.startsWith("_") && + f.getName.toLowerCase(Locale.ROOT).endsWith(".parquet")) { // when the input is a path to a parquet file but `basePath` is overridden to // the base path containing partitioning directories val df = spark @@ -706,10 +710,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + 
checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } @@ -747,10 +752,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } withTempPath { dir => - df.write.option("timeZone", "GMT") + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name)) - checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .load(dir.toString).select(fields: _*), row) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index fd10356056482..d1046afbddfac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -25,7 +25,7 @@ import com.google.common.collect.{HashMultiset, Multiset} import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path, RawLocalFileSystem} import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkException +import org.apache.spark.{DebugFilesystem, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -154,7 +154,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("SPARK-6917 DecimalType should work with non-native types") { - val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new Timestamp(i))) val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => @@ -165,6 +165,78 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + test("SPARK-10634 timestamp written and read as INT64 - TIMESTAMP_MILLIS") { + val data = (1 to 10).map(i => Row(i, new Timestamp(i))) + val schema = StructType(List(StructField("d", IntegerType, false), + StructField("time", TimestampType, false)).toArray) + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + ("true" :: "false" :: Nil).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + val df2 = spark.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } + } + } + } + + test("SPARK-10634 timestamp written and read as INT64 - truncation") { + withTable("ts") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.123456')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123456")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.123456")))) + } + + // The microsecond portion is truncated when written as TIMESTAMP_MILLIS. 
+ withTable("ts") { + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "true") { + sql("create table ts (c1 int, c2 timestamp) using parquet") + sql("insert into ts values (1, '2016-01-01 10:11:12.123456')") + sql("insert into ts values (2, null)") + sql("insert into ts values (3, '1965-01-01 10:11:12.125456')") + sql("insert into ts values (4, '1965-01-01 10:11:12.125')") + sql("insert into ts values (5, '1965-01-01 10:11:12.1')") + sql("insert into ts values (6, '1965-01-01 10:11:12.123456789')") + sql("insert into ts values (7, '0001-01-01 00:00:00.000000')") + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + + // Read timestamps that were encoded as TIMESTAMP_MILLIS annotated as INT64 + // with PARQUET_INT64_AS_TIMESTAMP_MILLIS set to false. + withSQLConf(SQLConf.PARQUET_INT64_AS_TIMESTAMP_MILLIS.key -> "false") { + checkAnswer( + sql("select * from ts"), + Seq( + Row(1, Timestamp.valueOf("2016-01-01 10:11:12.123")), + Row(2, null), + Row(3, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(4, Timestamp.valueOf("1965-01-01 10:11:12.125")), + Row(5, Timestamp.valueOf("1965-01-01 10:11:12.1")), + Row(6, Timestamp.valueOf("1965-01-01 10:11:12.123")), + Row(7, Timestamp.valueOf("0001-01-01 00:00:00.000")))) + } + } + } + } + test("Enabling/disabling merging partfiles when merging parquet schema") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => @@ -246,6 +318,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + /** + * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop + * to increase the chance of failure + */ + ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + for (i <- 1 to 100) { + DebugFilesystem.clearOpenStreams() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + } + DebugFilesystem.assertNoOpenStreams() + } + } + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538f..ce992674d719f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -53,11 +53,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -77,11 +79,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - writeLegacyParquetFormat = writeLegacyParquetFormat) + writeLegacyParquetFormat = writeLegacyParquetFormat, + writeTimestampInMillis = int64AsTimestampMillis) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -97,7 +101,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema: String, binaryAsString: Boolean, int96AsTimestamp: Boolean, - writeLegacyParquetFormat: Boolean): Unit = { + writeLegacyParquetFormat: Boolean, + int64AsTimestampMillis: Boolean = false): Unit = { testCatalystToParquet( testName, @@ -105,7 +110,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) testParquetToCatalyst( testName, @@ -113,7 +119,8 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - writeLegacyParquetFormat) + writeLegacyParquetFormat, + int64AsTimestampMillis) } } @@ -368,88 +375,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } - test("merge with metastore schema") { - // Field type conflict resolution - assertResult( - StructType(Seq( - StructField("lowerCase", StringType), - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("lowercase", StringType), - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // MetaStore schema is subset of parquet schema - assertResult( - StructType(Seq( - StructField("UPPERCase", DoubleType, nullable = false)))) { - - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false))), - - StructType(Seq( - StructField("lowerCase", BinaryType), - StructField("UPPERCase", IntegerType, nullable = true)))) - } - - // Metastore schema contains additional non-nullable fields. 
- assert(intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("uppercase", DoubleType, nullable = false), - StructField("lowerCase", BinaryType, nullable = false))), - - StructType(Seq( - StructField("UPPERCase", IntegerType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - - // Conflicting non-nullable field names - intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq(StructField("lower", StringType, nullable = false))), - StructType(Seq(StructField("lowerCase", BinaryType)))) - } - } - - test("merge missing nullable fields from Metastore schema") { - // Standard case: Metastore schema contains additional nullable fields not present - // in the Parquet file schema. - assertResult( - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true)))) { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = true))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - } - - // Merge should fail if the Metastore contains any additional fields that are not - // nullable. - assert(intercept[Throwable] { - ParquetFileFormat.mergeMetastoreParquetSchema( - StructType(Seq( - StructField("firstfield", StringType, nullable = true), - StructField("secondfield", StringType, nullable = true), - StructField("thirdfield", StringType, nullable = false))), - StructType(Seq( - StructField("firstField", StringType, nullable = true), - StructField("secondField", StringType, nullable = true)))) - }.getMessage.contains("detected conflicting schemas")) - } - test("schema merging failure error message") { import testImplicits._ @@ -1047,6 +972,18 @@ class ParquetSchemaSuite extends ParquetSchemaTest { int96AsTimestamp = true, writeLegacyParquetFormat = true) + testSchema( + "Timestamp written and read as INT64 with TIMESTAMP_MILLIS", + StructType(Seq(StructField("f1", TimestampType))), + """message root { + | optional INT64 f1 (TIMESTAMP_MILLIS); + |} + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = false, + writeLegacyParquetFormat = true, + int64AsTimestampMillis = true) + private def testSchemaClipping( testName: String, parquetSchema: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 9c55357ab9bc1..26c45e092dc65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,15 +22,12 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{LongType, ShortType} -import org.apache.spark.util.Utils /** * Test various broadcast join operators. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 24d92a96237e3..3d480b148db55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ @@ -122,7 +122,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultMinBatchesToRetain = 1, compactibleLog => { val logs = Array("entry_1", "entry_2", "entry_3") - val expected = s"""${FakeCompactibleFileStreamLog.VERSION} + val expected = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -132,7 +132,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext baos.reset() compactibleLog.serialize(Array(), baos) - assert(FakeCompactibleFileStreamLog.VERSION === baos.toString(UTF_8.name())) + assert(s"v${FakeCompactibleFileStreamLog.VERSION}" === baos.toString(UTF_8.name())) }) } @@ -142,7 +142,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext defaultCompactInterval = 3, defaultMinBatchesToRetain = 1, compactibleLog => { - val logs = s"""${FakeCompactibleFileStreamLog.VERSION} + val logs = s"""v${FakeCompactibleFileStreamLog.VERSION} |"entry_1" |"entry_2" |"entry_3"""".stripMargin @@ -152,10 +152,36 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext assert(Nil === compactibleLog.deserialize( - new ByteArrayInputStream(FakeCompactibleFileStreamLog.VERSION.getBytes(UTF_8)))) + new ByteArrayInputStream(s"v${FakeCompactibleFileStreamLog.VERSION}".getBytes(UTF_8)))) }) } + test("deserialization log written by future version") { + withTempDir { dir => + def newFakeCompactibleFileStreamLog(version: Int): FakeCompactibleFileStreamLog = + new FakeCompactibleFileStreamLog( + version, + _fileCleanupDelayMs = Long.MaxValue, // this param does not matter here in this test case + _defaultCompactInterval = 3, // this param does not matter here in this test case + _defaultMinBatchesToRetain = 1, // this param does not matter here in this test case + spark, + dir.getCanonicalPath) + + val writer = newFakeCompactibleFileStreamLog(version = 2) + val reader = newFakeCompactibleFileStreamLog(version = 1) + writer.add(0, Array("entry")) + val e = intercept[IllegalStateException] { + reader.get(0) + } + Seq( + "maximum supported log version is v1, but encountered v2", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("compact") { withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, @@ -219,6 
+245,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext ): Unit = { withTempDir { file => val compactibleLog = new FakeCompactibleFileStreamLog( + FakeCompactibleFileStreamLog.VERSION, fileCleanupDelayMs, defaultCompactInterval, defaultMinBatchesToRetain, @@ -230,17 +257,18 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext } object FakeCompactibleFileStreamLog { - val VERSION = "test_version" + val VERSION = 1 } class FakeCompactibleFileStreamLog( + metadataLogVersion: Int, _fileCleanupDelayMs: Long, _defaultCompactInterval: Int, _defaultMinBatchesToRetain: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[String]( - FakeCompactibleFileStreamLog.VERSION, + metadataLogVersion, sparkSession, path ) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 340d2945acd4a..dd3a414659c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -74,7 +74,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { action = FileStreamSinkLog.ADD_ACTION)) // scalastyle:off - val expected = s"""$VERSION + val expected = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -84,14 +84,14 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === baos.toString(UTF_8.name())) baos.reset() sinkLog.serialize(Array(), baos) - assert(VERSION === baos.toString(UTF_8.name())) + assert(s"v$VERSION" === baos.toString(UTF_8.name())) } } test("deserialize") { withFileStreamSinkLog { sinkLog => // scalastyle:off - val logs = s"""$VERSION + val logs = s"""v$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -125,7 +125,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(s"v$VERSION".getBytes(UTF_8)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 55750b9202982..7689bc03a4ccf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread class 
HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") private implicit def toOption[A](a: A): Option[A] = Option(a) @@ -127,6 +127,33 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("HDFSMetadataLog: parseVersion") { + withTempDir { dir => + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) + def assertLogFileMalformed(func: => Int): Unit = { + val e = intercept[IllegalStateException] { func } + assert(e.getMessage.contains(s"Log file was malformed: failed to read correct log version")) + } + assertLogFileMalformed { metadataLog.parseVersion("", 100) } + assertLogFileMalformed { metadataLog.parseVersion("xyz", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v10.x", 100) } + assertLogFileMalformed { metadataLog.parseVersion("10", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v0", 100) } + assertLogFileMalformed { metadataLog.parseVersion("v-10", 100) } + + assert(metadataLog.parseVersion("v10", 10) === 10) + assert(metadataLog.parseVersion("v10", 100) === 10) + + val e = intercept[IllegalStateException] { metadataLog.parseVersion("v200", 100) } + Seq( + "maximum supported log version is v100, but encountered v200", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("HDFSMetadataLog: restart") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 5ae8b2484d2ef..dc556322beddb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming import java.io.File import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { @@ -28,12 +30,37 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { case class StringOffset(override val json: String) extends Offset test("OffsetSeqMetadata - deserialization") { - assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}""")) - assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) - assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) - assert( - OffsetSeqMetadata(1, 2) === - OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + val key = SQLConf.SHUFFLE_PARTITIONS.key + + def getConfWith(shufflePartitions: Int): Map[String, String] = { + Map(key -> shufflePartitions.toString) + } + + // None set + assert(OffsetSeqMetadata(0, 0, Map.empty) === OffsetSeqMetadata("""{}""")) + + // One set + assert(OffsetSeqMetadata(1, 0, Map.empty) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(OffsetSeqMetadata(0, 2, Map.empty) === 
OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(0, 0, getConfWith(shufflePartitions = 2)) === + OffsetSeqMetadata(s"""{"conf": {"$key":2}}""")) + + // Two set + assert(OffsetSeqMetadata(1, 2, Map.empty) === + OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + assert(OffsetSeqMetadata(1, 0, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"conf": {"$key":3}}""")) + assert(OffsetSeqMetadata(0, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // All set + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata(s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}}""")) + + // Drop unknown fields + assert(OffsetSeqMetadata(1, 2, getConfWith(shufflePartitions = 3)) === + OffsetSeqMetadata( + s"""{"batchWatermarkMs":1,"batchTimestampMs":2,"conf": {"$key":3}},"unknown":1""")) } test("OffsetSeqLog - serialization - deserialization") { @@ -70,6 +97,22 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { } } + test("deserialization log written by future version") { + withTempDir { dir => + stringToFile(new File(dir, "0"), "v99999") + val log = new OffsetSeqLog(spark, dir.getCanonicalPath) + val e = intercept[IllegalStateException] { + log.get(0) + } + Seq( + s"maximum supported log version is v${OffsetSeqLog.VERSION}, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.getMessage.contains(message)) + } + } + } + test("read Spark 2.1.0 log format") { val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0") assert(batchId === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index 00d5e051de357..007554a83f548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -17,14 +17,24 @@ package org.apache.spark.sql.execution.streaming -import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.mutable + +import org.eclipse.jetty.util.ConcurrentHashSet +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{Clock, ManualClock, SystemClock} +import org.apache.spark.sql.streaming.util.StreamManualClock class ProcessingTimeExecutorSuite extends SparkFunSuite { + val timeout = 10.seconds + test("nextBatchTime") { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) assert(processingTimeExecutor.nextBatchTime(0) === 100) @@ -35,6 +45,57 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { assert(processingTimeExecutor.nextBatchTime(150) === 200) } + test("trigger timing") { + val triggerTimes = new ConcurrentHashSet[Int] + val clock = new StreamManualClock() + @volatile var continueExecuting = true + @volatile var clockIncrementInTrigger = 0L + val executor = ProcessingTimeExecutor(ProcessingTime("1000 milliseconds"), clock) + 
val executorThread = new Thread() { + override def run(): Unit = { + executor.execute(() => { + // Record the trigger time, increment clock if needed and + triggerTimes.add(clock.getTimeMillis.toInt) + clock.advance(clockIncrementInTrigger) + clockIncrementInTrigger = 0 // reset this so that there are no runaway triggers + continueExecuting + }) + } + } + executorThread.start() + // First batch should execute immediately, then executor should wait for next one + eventually { + assert(triggerTimes.contains(0)) + assert(clock.isStreamWaitingAt(0)) + assert(clock.isStreamWaitingFor(1000)) + } + + // Second batch should execute when clock reaches the next trigger time. + // If next trigger takes less than the trigger interval, executor should wait for next one + clockIncrementInTrigger = 500 + clock.setTime(1000) + eventually { + assert(triggerTimes.contains(1000)) + assert(clock.isStreamWaitingAt(1500)) + assert(clock.isStreamWaitingFor(2000)) + } + + // If next trigger takes less than the trigger interval, executor should immediately execute + // another one + clockIncrementInTrigger = 1500 + clock.setTime(2000) // allow another trigger by setting clock to 2000 + eventually { + // Since the next trigger will take 1500 (which is more than trigger interval of 1000) + // executor will immediately execute another trigger + assert(triggerTimes.contains(2000) && triggerTimes.contains(3500)) + assert(clock.isStreamWaitingAt(3500)) + assert(clock.isStreamWaitingFor(4000)) + } + continueExecuting = false + clock.advance(1000) + waitForThreadJoin(executorThread) + } + test("calling nextBatchTime with the result of a previous call should return the next interval") { val intervalMS = 100 val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) @@ -54,7 +115,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) processingTimeExecutor.execute(() => { batchCounts += 1 - // If the batch termination works well, batchCounts should be 3 after `execute` + // If the batch termination works correctly, batchCounts should be 3 after `execute` batchCounts < 3 }) assert(batchCounts === 3) @@ -66,9 +127,8 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } test("notifyBatchFallingBehind") { - val clock = new ManualClock() + val clock = new StreamManualClock() @volatile var batchFallingBehindCalled = false - val latch = new CountDownLatch(1) val t = new Thread() { override def run(): Unit = { val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { @@ -77,7 +137,6 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } } processingTimeExecutor.execute(() => { - latch.countDown() clock.waitTillTime(200) false }) @@ -85,9 +144,17 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite { } t.start() // Wait until the batch is running so that we don't call `advance` too early - assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + eventually { assert(clock.isStreamWaitingFor(200)) } clock.advance(200) - t.join() + waitForThreadJoin(t) assert(batchFallingBehindCalled === true) } + + private def eventually(body: => Unit): Unit = { + Eventually.eventually(Timeout(timeout)) { body } + } + + private def waitForThreadJoin(thread: Thread): Unit = { + failAfter(timeout) { thread.join() } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index e848f74e3159f..ebb7422765ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -123,6 +123,30 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4)) } + test("filter and concurrent updates") { + val provider = newStoreProvider() + + // Verify state before starting a new set of updates + assert(provider.latestIterator.isEmpty) + val store = provider.getStore(0) + put(store, "a", 1) + put(store, "b", 2) + + // Updates should work while iterating of filtered entries + val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" } + filtered.foreach { case (keyRow, valueRow) => + store.put(keyRow, intToRow(rowToInt(valueRow) + 1)) + } + assert(get(store, "a") === Some(2)) + + // Removes should work while iterating of filtered entries + val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" } + filtered2.foreach { case (keyRow, _) => + store.remove(keyRow) + } + assert(get(store, "b") === None) + } + test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index e41c00ecec271..e6cd41e4facf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -477,9 +477,11 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue - sc.listenerBus.post(SparkListenerDriverAccumUpdates( - sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, - metrics.values.map(m => m.id -> m.value).toSeq)) + + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + metrics.values.toSeq) sc.emptyRDD } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8184d7d909f4b..e48e3f6402901 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -41,24 +41,49 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNull() + reference += false + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNotNulls(3) + (1 to 3).foreach(_ => reference += false) + assert(column.anyNullsSet() == false) + assert(column.numNulls() == 0) + + column.appendNull() + reference += true + assert(column.anyNullsSet()) + assert(column.numNulls() == 1) + + column.appendNulls(3) + (1 to 3).foreach(_ => reference += true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) + + idx = column.elementsAppended column.putNotNull(idx) reference += 
false idx += 1 - assert(column.anyNullsSet() == false) + assert(column.anyNullsSet()) + assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 1) + assert(column.anyNullsSet()) + assert(column.numNulls() == 5) column.putNulls(idx, 3) reference += true reference += true reference += true idx += 3 - assert(column.anyNullsSet() == true) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) column.putNotNulls(idx, 4) reference += false @@ -66,8 +91,8 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 - assert(column.anyNullsSet() == true) - assert(column.numNulls() == 4) + assert(column.anyNullsSet()) + assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => assert(v._1 == column.isNullAt(v._2)) @@ -85,9 +110,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Byte] val column = ColumnVector.allocate(1024, ByteType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray + column.appendBytes(2, values, 0) + reference += 10.toByte + reference += 20.toByte + + column.appendBytes(3, values, 2) + reference += 30.toByte + reference += 40.toByte + reference += 50.toByte + + column.appendBytes(6, 60.toByte) + (1 to 6).foreach(_ => reference += 60.toByte) + + column.appendByte(70.toByte) + reference += 70.toByte + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toByte).toArray column.putBytes(idx, 2, values, 0) reference += 1 reference += 2 @@ -126,9 +168,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Short] val column = ColumnVector.allocate(1024, ShortType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray + column.appendShorts(2, values, 0) + reference += 10.toShort + reference += 20.toShort + + column.appendShorts(3, values, 2) + reference += 30.toShort + reference += 40.toShort + reference += 50.toShort + + column.appendShorts(6, 60.toShort) + (1 to 6).foreach(_ => reference += 60.toShort) + + column.appendShort(70.toShort) + reference += 70.toShort + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).map(_.toShort).toArray column.putShorts(idx, 2, values, 0) reference += 1 reference += 2 @@ -189,9 +248,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Int] val column = ColumnVector.allocate(1024, IntegerType, memMode) - var idx = 0 - val values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray + var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray + column.appendInts(2, values, 0) + reference += 10 + reference += 20 + + column.appendInts(3, values, 2) + reference += 30 + reference += 40 + reference += 50 + + column.appendInts(6, 60) + (1 to 6).foreach(_ => reference += 60) + + column.appendInt(70) + reference += 70 + + var idx = column.elementsAppended + + values = (1 :: 2 :: 3 :: 4 :: 5 :: Nil).toArray column.putInts(idx, 2, values, 0) reference += 1 reference += 2 @@ -257,9 +333,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Long] val column = ColumnVector.allocate(1024, LongType, memMode) - var idx = 0 - val values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray + 
var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray + column.appendLongs(2, values, 0) + reference += 10L + reference += 20L + + column.appendLongs(3, values, 2) + reference += 30L + reference += 40L + reference += 50L + + column.appendLongs(6, 60L) + (1 to 6).foreach(_ => reference += 60L) + + column.appendLong(70L) + reference += 70L + + var idx = column.elementsAppended + + values = (1L :: 2L :: 3L :: 4L :: 5L :: Nil).toArray column.putLongs(idx, 2, values, 0) reference += 1 reference += 2 @@ -320,6 +413,97 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } + test("Float APIs") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + val seed = System.currentTimeMillis() + val random = new Random(seed) + val reference = mutable.ArrayBuffer.empty[Float] + + val column = ColumnVector.allocate(1024, FloatType, memMode) + + var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray + column.appendFloats(2, values, 0) + reference += .1f + reference += .2f + + column.appendFloats(3, values, 2) + reference += .3f + reference += .4f + reference += .5f + + column.appendFloats(6, .6f) + (1 to 6).foreach(_ => reference += .6f) + + column.appendFloat(.7f) + reference += .7f + + var idx = column.elementsAppended + + values = (1.0f :: 2.0f :: 3.0f :: 4.0f :: 5.0f :: Nil).toArray + column.putFloats(idx, 2, values, 0) + reference += 1.0f + reference += 2.0f + idx += 2 + + column.putFloats(idx, 3, values, 2) + reference += 3.0f + reference += 4.0f + reference += 5.0f + idx += 3 + + val buffer = new Array[Byte](8) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234f) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, 1.123f) + + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + // Ensure array contains Little Endian floats + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getFloat(0)) + Platform.putFloat(buffer, Platform.BYTE_ARRAY_OFFSET + 4, bb.getFloat(4)) + } + + column.putFloats(idx, 1, buffer, 4) + column.putFloats(idx + 1, 1, buffer, 0) + reference += 1.123f + reference += 2.234f + idx += 2 + + column.putFloats(idx, 2, buffer, 0) + reference += 2.234f + reference += 1.123f + idx += 2 + + while (idx < column.capacity) { + val single = random.nextBoolean() + if (single) { + val v = random.nextFloat() + column.putFloat(idx, v) + reference += v + idx += 1 + } else { + val n = math.min(random.nextInt(column.capacity / 20), column.capacity - idx) + val v = random.nextFloat() + column.putFloats(idx, n, v) + var i = 0 + while (i < n) { + reference += v + i += 1 + } + idx += n + } + } + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getFloat(v._2), "Seed = " + seed + " MemMode=" + memMode) + if (memMode == MemoryMode.OFF_HEAP) { + val addr = column.valuesNativeAddress() + assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) + } + } + column.close + }} + } + test("Double APIs") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val seed = System.currentTimeMillis() @@ -327,9 +511,26 @@ class ColumnarBatchSuite extends SparkFunSuite { val reference = mutable.ArrayBuffer.empty[Double] val column = ColumnVector.allocate(1024, DoubleType, memMode) - var idx = 0 - val values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray + var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray + column.appendDoubles(2, values, 0) + reference += .1 + reference += .2 + + column.appendDoubles(3, values, 2) + reference += .3 + reference 
+= .4 + reference += .5 + + column.appendDoubles(6, .6) + (1 to 6).foreach(_ => reference += .6) + + column.appendDouble(.7) + reference += .7 + + var idx = column.elementsAppended + + values = (1.0 :: 2.0 :: 3.0 :: 4.0 :: 5.0 :: Nil).toArray column.putDoubles(idx, 2, values, 0) reference += 1.0 reference += 2.0 @@ -346,8 +547,8 @@ class ColumnarBatchSuite extends SparkFunSuite { Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { - // Ensure array contains Liitle Endian doubles - var bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + // Ensure array contains Little Endian doubles + val bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getDouble(0)) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, bb.getDouble(8)) } @@ -400,40 +601,47 @@ class ColumnarBatchSuite extends SparkFunSuite { val column = ColumnVector.allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) - var idx = 0 + + val str = "string" + column.appendByteArray(str.getBytes(StandardCharsets.UTF_8), + 0, str.getBytes(StandardCharsets.UTF_8).length) + reference += str + assert(column.arrayData().elementsAppended == 6) + + var idx = column.elementsAppended val values = ("Hello" :: "abc" :: Nil).toArray column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 0, values(0).getBytes(StandardCharsets.UTF_8).length) reference += values(0) idx += 1 - assert(column.arrayData().elementsAppended == 5) + assert(column.arrayData().elementsAppended == 11) column.putByteArray(idx, values(1).getBytes(StandardCharsets.UTF_8), 0, values(1).getBytes(StandardCharsets.UTF_8).length) reference += values(1) idx += 1 - assert(column.arrayData().elementsAppended == 8) + assert(column.arrayData().elementsAppended == 14) // Just put llo val offset = column.putByteArray(idx, values(0).getBytes(StandardCharsets.UTF_8), 2, values(0).getBytes(StandardCharsets.UTF_8).length - 2) reference += "llo" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put the same "ll" at offset. This should not allocate more memory in the column. 
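// (putArray only records an (offset, length) pair that points into child data which was
// already written, which is why arrayData().elementsAppended is unchanged by the call below.)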
column.putArray(idx, offset, 2) reference += "ll" idx += 1 - assert(column.arrayData().elementsAppended == 11) + assert(column.arrayData().elementsAppended == 17) // Put a long string val s = "abcdefghijklmnopqrstuvwxyz" column.putByteArray(idx, (s + s).getBytes(StandardCharsets.UTF_8)) reference += (s + s) idx += 1 - assert(column.arrayData().elementsAppended == 11 + (s + s).length) + assert(column.arrayData().elementsAppended == 17 + (s + s).length) reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 75723d0abcfcc..8f9c52cb1e031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.internal import java.io.File -import java.net.URI import org.scalatest.BeforeAndAfterEach @@ -76,9 +75,10 @@ class CatalogSuite } private def createTempFunction(name: String): Unit = { - val info = new ExpressionInfo("className", name) val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction(name, info, tempFunc, ignoreIfExists = false) + val funcMeta = CatalogFunction(FunctionIdentifier(name, None), "className", Nil) + sessionCatalog.registerFunction( + funcMeta, ignoreIfExists = false, functionBuilder = Some(tempFunc)) } private def dropFunction(name: String, db: Option[String] = None): Unit = { @@ -103,6 +103,11 @@ class CatalogSuite assert(col.isPartition == tableMetadata.partitionColumnNames.contains(col.name)) assert(col.isBucket == bucketColumnNames.contains(col.name)) } + + dbName.foreach { db => + val expected = columns.collect().map(_.name).toSet + assert(spark.catalog.listColumns(s"$db.$tableName").collect().map(_.name).toSet == expected) + } } override def afterEach(): Unit = { @@ -346,6 +351,7 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.getTable(db, "tbl_y").name === "tbl_y") + assert(spark.catalog.getTable(s"$db.tbl_y").name === "tbl_y") // Find an unqualified table using the current database intercept[AnalysisException](spark.catalog.getTable("tbl_y")) @@ -379,6 +385,11 @@ class CatalogSuite assert(fn2.database === db) assert(!fn2.isTemporary) + val fn2WithQualifiedName = spark.catalog.getFunction(s"$db.fn2") + assert(fn2WithQualifiedName.name === "fn2") + assert(fn2WithQualifiedName.database === db) + assert(!fn2WithQualifiedName.isTemporary) + // Find an unqualified function using the current database intercept[AnalysisException](spark.catalog.getFunction("fn2")) spark.catalog.setCurrentDatabase(db) @@ -404,6 +415,7 @@ class CatalogSuite assert(!spark.catalog.tableExists("tbl_x")) assert(!spark.catalog.tableExists("tbl_y")) assert(!spark.catalog.tableExists(db, "tbl_y")) + assert(!spark.catalog.tableExists(s"$db.tbl_y")) // Create objects. 
createTempTable("tbl_x") @@ -414,11 +426,15 @@ class CatalogSuite // Find a qualified table assert(spark.catalog.tableExists(db, "tbl_y")) + assert(spark.catalog.tableExists(s"$db.tbl_y")) // Find an unqualified table using the current database assert(!spark.catalog.tableExists("tbl_y")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.tableExists("tbl_y")) + + // Unable to find the table, although the temp view with the given name exists + assert(!spark.catalog.tableExists(db, "tbl_x")) } } } @@ -430,6 +446,7 @@ class CatalogSuite assert(!spark.catalog.functionExists("fn1")) assert(!spark.catalog.functionExists("fn2")) assert(!spark.catalog.functionExists(db, "fn2")) + assert(!spark.catalog.functionExists(s"$db.fn2")) // Create objects. createTempFunction("fn1") @@ -440,11 +457,15 @@ class CatalogSuite // Find a qualified function assert(spark.catalog.functionExists(db, "fn2")) + assert(spark.catalog.functionExists(s"$db.fn2")) // Find an unqualified function using the current database assert(!spark.catalog.functionExists("fn2")) spark.catalog.setCurrentDatabase(db) assert(spark.catalog.functionExists("fn2")) + + // Unable to find the function, although the temp function with the given name exists + assert(!spark.catalog.functionExists(db, "fn1")) } } } @@ -459,7 +480,7 @@ class CatalogSuite options = Map("path" -> dir.getAbsolutePath)) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.EXTERNAL) - assert(table.storage.locationUri.get == dir.getAbsolutePath) + assert(table.storage.locationUri.get == makeQualifiedPath(dir.getAbsolutePath)) Seq((1)).toDF("i").write.insertInto("t") assert(dir.exists() && dir.listFiles().nonEmpty) @@ -481,7 +502,7 @@ class CatalogSuite options = Map.empty[String, String]) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.MANAGED) - val tablePath = new File(new URI(table.storage.locationUri.get)) + val tablePath = new File(table.storage.locationUri.get) assert(tablePath.exists() && tablePath.listFiles().isEmpty) Seq((1)).toDF("i").write.insertInto("t") @@ -493,6 +514,25 @@ class CatalogSuite } } - // TODO: add tests for the rest of them + test("clone Catalog") { + // need to test tempTables are cloned + assert(spark.catalog.listTables().collect().isEmpty) + createTempTable("my_temp_table") + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // inheritance + val forkedSession = spark.cloneSession() + assert(spark ne forkedSession) + assert(spark.catalog ne forkedSession.catalog) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + + // independence + dropTable("my_temp_table") // drop table in original session + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table")) + forkedSession.sessionState.catalog + .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 0e3a5ca9d71dd..f2456c7704064 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite { } assert(e2.getMessage === "The maximum size of the cache must not be negative") } + + test("clone SQLConf") { + val original = new SQLConf + val key = "spark.sql.SQLConfEntrySuite.clone" + assert(original.getConfString(key, "noentry") === "noentry") + + // inheritance + original.setConfString(key, "orig") + val clone = original.clone() + assert(original ne clone) + assert(clone.getConfString(key, "noentry") === "orig") + + // independence + clone.setConfString(key, "clone") + assert(original.getConfString(key, "noentry") === "orig") + original.setConfString(key, "dontcopyme") + assert(clone.getConfString(key, "noentry") === "clone") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5463728ca0c1d..5bd36ec25ccb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -806,10 +806,6 @@ class JDBCSuite extends SparkFunSuite sql(s"DESC FORMATTED $tableName").collect().foreach { r => assert(!r.toString().contains(password)) } - - sql(s"DESC EXTENDED $tableName").collect().foreach { r => - assert(!r.toString().contains(password)) - } } } } @@ -869,7 +865,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") assert(schema.contains("`order` TEXT")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ec7b19e666ec0..bf1fd160704fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.Properties import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, Row, SaveMode} -import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -362,4 +363,147 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(sql("select * from people_view").count() == 2) } } + + test("SPARK-10849: test schemaString - from createTableColumnTypes option values") { + def testCreateTableColDataTypes(types: Seq[String]): Unit = { + val colTypes = types.zipWithIndex.map { case (t, i) => (s"col$i", t) } + val schema = colTypes + .foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2)) + val createTableColTypes = + colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ") + val df = 
spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema) + + val expectedSchemaStr = + colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ") + + assert(JdbcUtils.schemaString(df, url1, Option(createTableColTypes)) == expectedSchemaStr) + } + + testCreateTableColDataTypes(Seq("boolean")) + testCreateTableColDataTypes(Seq("tinyint", "smallint", "int", "bigint")) + testCreateTableColDataTypes(Seq("float", "double")) + testCreateTableColDataTypes(Seq("string", "char(10)", "varchar(20)")) + testCreateTableColDataTypes(Seq("decimal(10,0)", "decimal(10,5)")) + testCreateTableColDataTypes(Seq("date", "timestamp")) + testCreateTableColDataTypes(Seq("binary")) + } + + test("SPARK-10849: create table using user specified column type and verify on target table") { + def testUserSpecifiedColTypes( + df: DataFrame, + createTableColTypes: String, + expectedTypes: Map[String, String]): Unit = { + df.write + .mode(SaveMode.Overwrite) + .option("createTableColumnTypes", createTableColTypes) + .jdbc(url1, "TEST.DBCOLTYPETEST", properties) + + // verify the data types of the created table by reading the database catalog of H2 + val query = + """ + |(SELECT column_name, type_name, character_maximum_length + | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') + """.stripMargin + val rows = spark.read.jdbc(url1, query, properties).collect() + + rows.foreach { row => + val typeName = row.getString(1) + // For CHAR and VARCHAR, we also compare the max length + if (typeName.contains("CHAR")) { + val charMaxLength = row.getInt(2) + assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)") + } else { + assert(expectedTypes(row.getString(0)) == typeName) + } + } + } + + val data = Seq[Row](Row(1, "dave", "Boston")) + val schema = StructType( + StructField("id", IntegerType) :: + StructField("first#name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + + // out-of-order + val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1) + // partial schema + val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + // should still respect the original column names + val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val schema = StructType( + StructField("id", IntegerType) :: + StructField("First#Name", StringType) :: + StructField("city", StringType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) + val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB") + testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid data type") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CLOB(2000)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("DataType clob(2000) is not 
supported.")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid syntax") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[ParseException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("no viable alternative at input")) + } + + test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "name CHAR(20), id int, NaMe VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains( + "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) + } + } + + test("SPARK-10849: jdbc CreateTableColumnTypes invalid columns") { + // schema2 has the column "id" and "name" + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "firstName CHAR(20), id int") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column firstName not found in " + + "schema struct")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val msg = intercept[AnalysisException] { + df.write.mode(SaveMode.Overwrite) + .option("createTableColumnTypes", "id int, Name VARCHAR(100)") + .jdbc(url1, "TEST.USERDBTYPETEST", properties) + }.getMessage() + assert(msg.contains("createTableColumnTypes option column Name not found in " + + "schema struct")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9b65419dba234..ba0ca666b5c14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) } // Filter could hide the bug in bucket pruning. 
Thus, skipping all the filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 9082261af7b00..93f3efe2ccc4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -92,7 +92,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { def tableDir: File = { val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") - new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier))) + new File(spark.sessionState.catalog.defaultTablePath(identifier)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala deleted file mode 100644 index 674463feca4db..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ - -package org.apache.spark.sql.sources - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -class DDLScanSource extends RelationProvider { - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan( - parameters("from").toInt, - parameters("TO").toInt, - parameters("Table"))(sqlContext.sparkSession) - } -} - -case class SimpleDDLScan( - from: Int, - to: Int, - table: String)(@transient val sparkSession: SparkSession) - extends BaseRelation with TableScan { - - override def sqlContext: SQLContext = sparkSession.sqlContext - - override def schema: StructType = - StructType(Seq( - StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), - StructField("stringType", StringType, nullable = false), - StructField("dateType", DateType, nullable = false), - StructField("timestampType", TimestampType, nullable = false), - StructField("doubleType", DoubleType, nullable = false), - StructField("bigintType", LongType, nullable = false), - StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), - StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), - StructField("binaryType", BinaryType, nullable = false), - StructField("booleanType", BooleanType, nullable = false), - StructField("smallIntType", ShortType, nullable = false), - StructField("floatType", FloatType, nullable = false), - StructField("mapType", MapType(StringType, StringType)), - StructField("arrayType", ArrayType(StringType)), - StructField("structType", - StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil - ) - ) - )) - - override def needConversion: Boolean = false - - override def buildScan(): RDD[Row] = { - // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] - sparkSession.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2) - }.asInstanceOf[RDD[Row]] - } -} - -class DDLTestSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = spark.sql _ - - override def beforeAll(): Unit = { - super.beforeAll() - sql( - """ - |CREATE OR REPLACE TEMPORARY VIEW ddlPeople - |USING org.apache.spark.sql.sources.DDLScanSource - |OPTIONS ( - | From '1', - | To '10', - | Table 'test1' - |) - """.stripMargin) - } - - sqlTest( - "describe ddlPeople", - Seq( - Row("intType", "int", "test comment test1"), - Row("stringType", "string", null), - Row("dateType", "date", null), - Row("timestampType", "timestamp", null), - Row("doubleType", "double", null), - Row("bigintType", "bigint", null), - Row("tinyintType", "tinyint", null), - Row("decimalType", "decimal(10,0)", null), - Row("fixedDecimalType", "decimal(5,1)", null), - Row("binaryType", "binary", null), - Row("booleanType", "boolean", null), - Row("smallIntType", "smallint", null), - Row("floatType", "float", null), - Row("mapType", "map", null), - Row("arrayType", "array", null), - Row("structType", "struct", null) - )) - - test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { - val attributes = sql("describe ddlPeople") - .queryExecution.executedPlan.output - assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) - 
assert(attributes.map(_.dataType).toSet === Set(StringType)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 448adcf11d656..735e07c21373a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -21,11 +21,11 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(SimpleCatalystConf(caseSensitive)) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + def cast(e: Expression, dt: DataType): Expression = { + Cast(e, dt, Option(conf.sessionLocalTimeZone)) + } + val rule = DataSourceAnalysis(conf) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { @@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { if (!caseSensitive) { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), @@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), @@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { // Test the case having a single static partition column. 
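// (Only partition column "b" is given a static value in this case, so the converted
// projection appends exactly one cast literal after the non-partitioned attributes.)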
{ val nonPartitionedAttributes = Seq('e.int, 'f.int) - val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1")), @@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { val dynamicPartitionAttributes = Seq('g.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType)) ++ + Seq(cast(Literal("1"), IntegerType)) ++ dynamicPartitionAttributes val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index cc77d3c4b91ac..80868fff897fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.sources +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] abstract class DataSourceTest extends QueryTest { @@ -28,3 +32,55 @@ private[sql] abstract class DataSourceTest extends QueryTest { } } + +class DDLScanSource extends RelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + SimpleDDLScan( + parameters("from").toInt, + parameters("TO").toInt, + parameters("Table"))(sqlContext.sparkSession) + } +} + +case class SimpleDDLScan( + from: Int, + to: Int, + table: String)(@transient val sparkSession: SparkSession) + extends BaseRelation with TableScan { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = + StructType(Seq( + StructField("intType", IntegerType, nullable = false).withComment(s"test comment $table"), + StructField("stringType", StringType, nullable = false), + StructField("dateType", DateType, nullable = false), + StructField("timestampType", TimestampType, nullable = false), + StructField("doubleType", DoubleType, nullable = false), + StructField("bigintType", LongType, nullable = false), + StructField("tinyintType", ByteType, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), + StructField("binaryType", BinaryType, nullable = false), + StructField("booleanType", BooleanType, nullable = false), + StructField("smallIntType", ShortType, nullable = false), + StructField("floatType", FloatType, nullable = false), + StructField("mapType", MapType(StringType, StringType)), + StructField("arrayType", ArrayType(StringType)), + StructField("structType", + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil + ) + ) + )) + + override def needConversion: Boolean = false + + override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] + sparkSession.sparkContext.parallelize(from to to).map { e => + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index be56c964a18f8..5a0388ec1d1db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.util.Locale + import scala.language.existentials import org.apache.spark.rdd.RDD @@ -76,7 +78,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S case "b" => (i: Int) => Seq(i * 2) case "c" => (i: Int) => val c = (i - 1 + 'a').toChar.toString - Seq(c * 5 + c.toUpperCase * 5) + Seq(c * 5 + c.toUpperCase(Locale.ROOT) * 5) } FiltersPushed.list = filters @@ -113,7 +115,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S } def eval(a: Int) = { - val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5 + val c = (a - 1 + 'a').toChar.toString * 5 + + (a - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5 filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c)) } @@ -151,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic sqlTest( "SELECT * FROM oneToTenFiltered", (1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5 - + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq) + + (i - 1 + 'a').toChar.toString.toUpperCase(Locale.ROOT) * 5)).toSeq) sqlTest( "SELECT a, b FROM oneToTenFiltered", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index f251290583c5e..a2f3afe3ce236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -142,7 +143,8 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => - df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") + .partitionBy("ts").parquet(f.getAbsolutePath) val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. 
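For readers following the PartitionedWriteSuite hunk above, a minimal sketch of the behaviour that test exercises: when a DataFrame is partitioned by a timestamp column, a per-write time zone option controls how the partition directory value is rendered. The sketch is not part of the patch; it assumes a local SparkSession, an illustrative output path, and that DateTimeUtils.TIMEZONE_OPTION resolves to the literal option key "timeZone".

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession

object TimestampPartitionWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("tz-partition-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, Timestamp.valueOf("2016-12-01 00:00:00"))).toDF("i", "ts")

    // Without the option, the partition directory value for "ts" is rendered in the
    // session-local time zone; with option("timeZone", "GMT") the same instant is
    // rendered in GMT instead, which is what the surrounding test verifies.
    df.write
      .option("timeZone", "GMT")
      .partitionBy("ts")
      .parquet("/tmp/tz_partition_sketch")

    spark.stop()
  }
}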
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index faf9afc49a2d3..6dd4847ead738 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.sources +import java.net.URI + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructType} @@ -72,28 +75,29 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path') """.stripMargin) - assert(getPathOption("src") == Some("/tmp/path")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path"))) } // should exist even path option is not specified when creating table withTable("src") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") - assert(getPathOption("src") == Some(defaultTablePath("src"))) + assert(getPathOption("src").map(makeQualifiedPath) == Some(defaultTablePath("src"))) } } test("path option also exist for write path") { withTable("src") { withTempPath { p => - val path = new Path(p.getAbsolutePath).toString sql( s""" |CREATE TABLE src |USING ${classOf[TestOptionsSource].getCanonicalName} - |OPTIONS (PATH '$path') + |OPTIONS (PATH '$p') |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == path) + assert( + spark.table("src").schema.head.metadata.getString("path") == + p.getAbsolutePath) } } @@ -105,7 +109,9 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == defaultTablePath("src")) + assert( + makeQualifiedPath(spark.table("src").schema.head.metadata.getString("path")) == + defaultTablePath("src")) } } @@ -117,13 +123,13 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |OPTIONS (PATH '/tmp/path')""".stripMargin) sql("ALTER TABLE src SET LOCATION '/tmp/path2'") - assert(getPathOption("src") == Some("/tmp/path2")) + assert(getPathOption("src").map(makeQualifiedPath) == Some(makeQualifiedPath("/tmp/path2"))) } withTable("src", "src2") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") sql("ALTER TABLE src RENAME TO src2") - assert(getPathOption("src2") == Some(defaultTablePath("src2"))) + assert(getPathOption("src2").map(makeQualifiedPath) == Some(defaultTablePath("src2"))) } } @@ -133,7 +139,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { }.head } - private def defaultTablePath(tableName: String): String = { + private def defaultTablePath(tableName: String): URI = { spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 9b5e364e512a2..0f97fd78d2ffb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -27,7 +27,8 @@ class ResolvedDataSourceSuite extends SparkFunSuite { DataSource( sparkSession = null, className = name, - options = Map("timeZone" -> DateTimeUtils.defaultTimeZone().getID)).providingClass + options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID) + ).providingClass test("jdbc") { assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index 7ea716231e5dc..a15c2cff930fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -249,4 +249,23 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { } } } + + test("SPARK-19841: watermarkPredicate should filter based on keys") { + val input = MemoryStream[(Int, Int)] + val df = input.toDS.toDF("time", "id") + .withColumn("time", $"time".cast("timestamp")) + .withWatermark("time", "1 second") + .dropDuplicates("id", "time") // Change the column positions + .select($"id") + testStream(df)( + AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), + CheckLastBatch(1, 2), + AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), + CheckLastBatch(3, 4), + AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark + CheckLastBatch(5, 6), + AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark + CheckLastBatch(7) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c34d119734cc0..fd850a7365e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.OutputMode._ @@ -217,7 +218,9 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin AddData(inputData, 25), // Evict items less than previous watermark. 
CheckLastBatch((10, 5)), StopStream, - AssertOnQuery { q => // clear the sink + AssertOnQuery { q => // purge commit and clear the sink + val commit = q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + q.batchCommitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, @@ -305,6 +308,42 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin ) } + test("delay threshold should not be negative.") { + val inputData = MemoryStream[Int].toDF() + var e = intercept[IllegalArgumentException] { + inputData.withWatermark("value", "-1 year") + } + assert(e.getMessage contains "should not be negative.") + + e = intercept[IllegalArgumentException] { + inputData.withWatermark("value", "1 year -13 months") + } + assert(e.getMessage contains "should not be negative.") + + e = intercept[IllegalArgumentException] { + inputData.withWatermark("value", "1 month -40 days") + } + assert(e.getMessage contains "should not be negative.") + + e = intercept[IllegalArgumentException] { + inputData.withWatermark("value", "-10 seconds") + } + assert(e.getMessage contains "should not be negative.") + } + + test("the new watermark should override the old one") { + val df = MemoryStream[(Long, Long)].toDF() + .withColumn("first", $"_1".cast("timestamp")) + .withColumn("second", $"_2".cast("timestamp")) + .withWatermark("first", "1 minute") + .withWatermark("second", "2 minutes") + + val eventTimeColumns = df.logicalPlan.output + .filter(_.metadata.contains(EventTimeWatermark.delayKey)) + assert(eventTimeColumns.size === 1) + assert(eventTimeColumns(0).name === "second") + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index f67444fbc49d6..1a2d3a13f3a4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.streaming +import java.util.Locale + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -143,6 +147,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", 
outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. @@ -221,7 +262,7 @@ class FileStreamSinkSuite extends StreamTest { df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) } Seq(mode, "not support").foreach { w => - assert(e.getMessage.toLowerCase.contains(w)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(w)) } } @@ -264,4 +305,22 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 1586850c77fca..2108b118bf059 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -43,8 +44,8 @@ abstract class FileStreamSourceTest import testImplicits._ /** - * A subclass [[AddData]] for adding data to files. This is meant to use the - * [[FileStreamSource]] actually being used in the execution. + * A subclass `AddData` for adding data to files. This is meant to use the + * `FileStreamSource` actually being used in the execution. 
*/ abstract class AddFileData extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { @@ -909,7 +910,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } - test("max files per trigger - incorrect values") { + testQuietly("max files per trigger - incorrect values") { val testTable = "maxFilesPerTrigger_test" withTable(testTable) { withTempDir { case src => @@ -1173,6 +1174,41 @@ class FileStreamSourceSuite extends FileStreamSourceTest { SerializedOffset(str.trim) } + private def runTwoBatchesAndVerifyResults( + src: File, + latestFirst: Boolean, + firstBatch: String, + secondBatch: String, + maxFileAge: Option[String] = None): Unit = { + val srcOptions = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1") ++ + maxFileAge.map("maxFileAge" -> _) + val fileStream = createFileStream( + "text", + src.getCanonicalPath, + options = srcOptions) + val clock = new StreamManualClock() + testStream(fileStream)( + StartStream(trigger = ProcessingTime(10), triggerClock = clock), + AssertOnQuery { _ => + // Block until the first batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(0)) + } + true + }, + CheckLastBatch(firstBatch), + AdvanceManualClock(10), + AssertOnQuery { _ => + // Block until the second batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(10)) + } + true + }, + CheckLastBatch(secondBatch) + ) + } + test("FileStreamSource - latestFirst") { withTempDir { src => // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time. @@ -1180,47 +1216,28 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val f2 = stringToFile(new File(src, "2.txt"), "2") f2.setLastModified(f1.lastModified + 1000) - def runTwoBatchesAndVerifyResults( - latestFirst: Boolean, - firstBatch: String, - secondBatch: String): Unit = { - val fileStream = createFileStream( - "text", - src.getCanonicalPath, - options = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1")) - val clock = new StreamManualClock() - testStream(fileStream)( - StartStream(trigger = ProcessingTime(10), triggerClock = clock), - AssertOnQuery { _ => - // Block until the first batch finishes. - eventually(timeout(streamingTimeout)) { - assert(clock.isStreamWaitingAt(0)) - } - true - }, - CheckLastBatch(firstBatch), - AdvanceManualClock(10), - AssertOnQuery { _ => - // Block until the second batch finishes. - eventually(timeout(streamingTimeout)) { - assert(clock.isStreamWaitingAt(10)) - } - true - }, - CheckLastBatch(secondBatch) - ) - } - // Read oldest files first, so the first batch is "1", and the second batch is "2". - runTwoBatchesAndVerifyResults(latestFirst = false, firstBatch = "1", secondBatch = "2") + runTwoBatchesAndVerifyResults(src, latestFirst = false, firstBatch = "1", secondBatch = "2") // Read latest files first, so the first batch is "2", and the second batch is "1". - runTwoBatchesAndVerifyResults(latestFirst = true, firstBatch = "2", secondBatch = "1") + runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1") + } + } + + test("SPARK-19813: Ignore maxFileAge when maxFilesPerTrigger and latestFirst is used") { + withTempDir { src => + // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time. 
+ val f1 = stringToFile(new File(src, "1.txt"), "1") + val f2 = stringToFile(new File(src, "2.txt"), "2") + f2.setLastModified(f1.lastModified + 3600 * 1000 /* 1 hour later */) + + runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1", + maxFileAge = Some("1m") /* 1 minute */) } } test("SeenFilesMap") { - val map = new SeenFilesMap(maxAgeMs = 10) + val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false) map.add("a", 5) assert(map.size == 1) @@ -1253,8 +1270,26 @@ class FileStreamSourceSuite extends FileStreamSourceTest { assert(map.isNewFile("e", 20)) } + test("SeenFilesMap with fileNameOnly = true") { + val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = true) + + map.add("file:///a/b/c/d", 5) + map.add("file:///a/b/c/e", 5) + assert(map.size === 2) + + assert(!map.isNewFile("d", 5)) + assert(!map.isNewFile("file:///d", 5)) + assert(!map.isNewFile("file:///x/d", 5)) + assert(!map.isNewFile("file:///x/y/d", 5)) + + map.add("s3:///bucket/d", 5) + map.add("s3n:///bucket/d", 5) + map.add("s3a:///bucket/d", 5) + assert(map.size === 2) + } + test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { - val map = new SeenFilesMap(maxAgeMs = 10) + val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false) map.add("a", 20) assert(map.size == 1) @@ -1292,7 +1327,7 @@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest { import testImplicits._ - test("file source stress test") { + testQuietly("file source stress test") { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala new file mode 100644 index 0000000000000..85aa7dbe9ed86 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -0,0 +1,927 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.sql.Date +import java.util.concurrent.ConcurrentHashMap + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.RDDScanExec +import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate} +import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore +import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.types.{DataType, IntegerType} + +/** Class to check custom state types */ +case class RunningCount(count: Long) + +case class Result(key: Long, count: Int) + +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + import GroupStateImpl._ + import GroupStateTimeout._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("GroupState - get, exists, update, remove") { + var state: GroupStateImpl[String] = null + + def testState( + expectedData: Option[String], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get + } + } + assert(state.getOption === expectedData) + assert(state.hasUpdated === shouldBeUpdated) + assert(state.hasRemoved === shouldBeRemoved) + } + + // Updating empty state + state = new GroupStateImpl[String](None) + testState(None) + state.update("") + testState(Some(""), shouldBeUpdated = true) + + // Updating exiting state + state = new GroupStateImpl[String](Some("2")) + testState(Some("2")) + state.update("3") + testState(Some("3"), shouldBeUpdated = true) + + // Removing state + state.remove() + testState(None, shouldBeRemoved = true, shouldBeUpdated = false) + state.remove() // should be still callable + state.update("4") + testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) + + // Updating by null throw exception + intercept[IllegalArgumentException] { + state.update(null) + } + } + + test("GroupState - setTimeout**** with NoTimeout") { + for (initState <- Seq(None, Some(5))) { + // for different initial state + implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + } + + test("GroupState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: GroupStateImpl[Int] = null + + state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + assert(state.getTimeoutTimestamp === 
NO_TIMESTAMP) + state.setTimeoutDuration(1000) + assert(state.getTimeoutTimestamp === 2000) + state.setTimeoutDuration("2 second") + assert(state.getTimeoutTimestamp === 3000) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + + test("GroupState - setTimeout**** with EventTimeTimeout") { + implicit val state = new GroupStateImpl[Int]( + None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + + state.update(5) + state.setTimeoutTimestamp(10000) + assert(state.getTimeoutTimestamp === 10000) + state.setTimeoutTimestamp(new Date(20000)) + assert(state.getTimeoutTimestamp === 20000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + } + + test("GroupState - illegal params to setTimeout****") { + var state: GroupStateImpl[Int] = null + + // Test setTimeout****() with illegal values + def testIllegalTimeout(body: => Unit): Unit = { + intercept[IllegalArgumentException] { body } + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + } + + state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutDuration(-1000) } + testIllegalTimeout { state.setTimeoutDuration(0) } + testIllegalTimeout { state.setTimeoutDuration("-2 second") } + testIllegalTimeout { state.setTimeoutDuration("-1 month") } + testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + + state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutTimestamp(-10000) } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + } + + test("GroupState - hasTimedOut") { + for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + for (initState <- Seq(None, Some(5))) { + val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + assert(state1.hasTimedOut === false) + val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + assert(state2.hasTimedOut === true) + } + } + } + + test("GroupState - primitive type") { + var intState = new GroupStateImpl[Int](None) + intercept[NoSuchElementException] { + intState.get + } + assert(intState.getOption === None) + + intState = new GroupStateImpl[Int](Some(10)) + assert(intState.get == 10) + intState.update(0) + assert(intState.get == 0) + intState.remove() + intercept[NoSuchElementException] { + intState.get + } + } + + // Values used for testing StateStoreUpdater + val 
currentBatchTimestamp = 1000 + val currentBatchWatermark = 1000 + val beforeTimeoutThreshold = 999 + val afterTimeoutThreshold = 1001 + + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + for (priorState <- Seq(None, Some(0))) { + val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" + val testName = s"NoTimeout - $priorStateStr - " + + testStateUpdateWithData( + testName + "no update", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = GroupStateTimeout.NoTimeout, + priorState = priorState, + expectedState = priorState) // should not change + + testStateUpdateWithData( + testName + "state updated", + stateUpdates = state => { state.update(5) }, + timeoutConf = GroupStateTimeout.NoTimeout, + priorState = priorState, + expectedState = Some(5)) // should change + + testStateUpdateWithData( + testName + "state removed", + stateUpdates = state => { state.remove() }, + timeoutConf = GroupStateTimeout.NoTimeout, + priorState = priorState, + expectedState = None) // should be removed + } + + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + for (priorState <- Seq(None, Some(0))) { + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + var testName = "" + if (priorState.nonEmpty) { + testName += "prior state set, " + if (priorTimeoutTimestamp == 1000) { + testName += "prior timeout set" + } else { + testName += "no prior timeout" + } + } else { + testName += "no prior state" + } + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + + testStateUpdateWithData( + s"$timeoutConf - $testName - no update", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state updated", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state removed", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state should be removed + } + + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - state and timeout duration updated", + stateUpdates = + (state: GroupState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - state and timeout timestamp updated", + stateUpdates = + (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = 5000) // timestamp should change + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - timeout timestamp updated to 
before watermark", + stateUpdates = + (state: GroupState[Int]) => { + state.update(5) + intercept[IllegalArgumentException] { + state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark + } + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update + } + } + + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() + val preTimeoutState = Some(5) + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + testStateUpdateWithTimeout( + s"$timeoutConf - should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = afterTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterTimeoutThreshold) // timestamp should not change + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - update state", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - remove state", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = NO_TIMESTAMP) + } + + testStateUpdateWithTimeout( + "ProcessingTimeTimeout - should timeout - timeout duration updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change + + testStateUpdateWithTimeout( + "ProcessingTimeTimeout - should timeout - timeout duration and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change + + testStateUpdateWithTimeout( + "EventTimeTimeout - should timeout - timeout timestamp updated", + stateUpdates = state => { state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = 5000) // timestamp should change + + testStateUpdateWithTimeout( + "EventTimeTimeout - should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = 5000) // timestamp should change + + 
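Editor's note (not part of the patch): a minimal standalone sketch of the processing-time timeout pattern that the StateStoreUpdater tests above and the streaming tests below exercise. The object name, app name, and the "counts" memory-sink table are illustrative assumptions; the state function mirrors the running-count logic used throughout this suite.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object RunningCountTimeoutSketch {
  case class RunningCount(count: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("gs-timeout").getOrCreate()
    import spark.implicits._
    implicit val sqlContext = spark.sqlContext

    // Keep a running count per key; if no data arrives for 10 seconds of processing
    // time, the group is invoked with hasTimedOut = true and its state is dropped.
    val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
      if (state.hasTimedOut) {
        state.remove()
        Iterator((key, "-1"))
      } else {
        val count = state.getOption.map(_.count).getOrElse(0L) + values.size
        state.update(RunningCount(count))
        state.setTimeoutDuration("10 seconds")
        Iterator((key, count.toString))
      }
    }

    val input = MemoryStream[String]
    val counts = input.toDS()
      .groupByKey(x => x)
      .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout())(stateFunc)

    val query = counts.writeStream
      .outputMode("update")
      .format("memory")
      .queryName("counts")
      .start()

    input.addData("a", "a", "b")
    query.processAllAvailable()
    spark.table("counts").show()

    query.stop()
    spark.stop()
  }
}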
test("StateStoreUpdater - rows are cloned before writing to StateStore") { + // function for running count + val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + state.update(state.getOption.getOrElse(0) + values.size) + Iterator.empty + } + val store = newStateStore() + val plan = newFlatMapGroupsWithStateExec(func) + val updater = new plan.StateStoreUpdater(store) + val data = Seq(1, 1, 2) + val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow)) + returnIter.size // consume the iterator to force store updates + val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet + assert(storeData === Set((1, 2), (2, 1))) + } + + test("flatMapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator.empty + } else { + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch(("b", "2")), + assertNumStateRows(total = 1, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + // Additionally, it updates state lazily as the returned iterator get consumed + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + values.flatMap { _ => + val count = state.getOption.map(_.count).getOrElse(0L) + 1 + if (count == 3) { + state.remove() + None + } else { + state.update(RunningCount(count)) + Some((key, count.toString)) + } + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + testStream(result, Update)( + AddData(inputData, "a", "a", "b"), + CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch(("b", "2")), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckLastBatch(("a", "1"), ("c", "1")) + ) + } + + test("flatMapGroupsWithState - streaming + aggregation") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 
2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator(key -> "-1") + } else { + state.update(RunningCount(count)) + Iterator(key -> count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc) + .groupByKey(_._1) + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckLastBatch(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckLastBatch(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckLastBatch(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("flatMapGroupsWithState - batch") { + // Function that returns running count only if its even, otherwise does not return + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + Iterator((key, values.size)) + } + val df = Seq("a", "a", "b").toDS + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc).toDF + checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF) + } + + test("flatMapGroupsWithState - streaming with processing time timeout") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("b", "1")), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, "b"), + AdvanceManualClock(10 * 1000), + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), + + StopStream, + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + + AddData(inputData, "c"), + AdvanceManualClock(11 * 1000), + CheckLastBatch(("b", "-1"), ("c", "1")), + assertNumStateRows(total = 1, updated = 2), + + AddData(inputData, "c"), + AdvanceManualClock(20 * 1000), + CheckLastBatch(("c", "2")), + assertNumStateRows(total = 1, updated = 1) + ) + } + + test("flatMapGroupsWithState - streaming with event 
time timeout") { + // Function to maintain the max event time + // Returns the max event time in the state, or -1 if the state was removed by timeout + val stateFunc = ( + key: String, + values: Iterator[(String, Long)], + state: GroupState[Long]) => { + val timeoutDelay = 5 + if (key != "a") { + Iterator.empty + } else { + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampMs = maxEventTime + timeoutDelay + state.update(maxEventTime) + state.setTimeoutTimestamp(timeoutTimestampMs * 1000) + Iterator((key, maxEventTime.toInt)) + } + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second")), + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... + CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckLastBatch(), // No output as data should get filtered by watermark + AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s + CheckLastBatch(), // No output as no data for "a" + AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored + CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + ) + } + + test("mapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + (key, "-1") + } else { + state.update(RunningCount(count)) + (key, count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + + testStream(result, Update)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("mapGroupsWithState - batch") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + (key, values.size) + } + + checkAnswer( + spark.createDataset(Seq("a", "a", "b")) + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) + .toDF, + spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) + } + + testQuietly("StateStore.abort on task failure handling") { + val stateFunc = (key: String, 
values: Iterator[String], state: GroupState[RunningCount]) => { + if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + (key, count) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + + def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => + FlatMapGroupsWithStateSuite.failInTask = value + true + } + + testStream(result, Update)( + setFailInTask(false), + AddData(inputData, "a"), + CheckLastBatch(("a", 1L)), + AddData(inputData, "a"), + CheckLastBatch(("a", 2L)), + setFailInTask(true), + AddData(inputData, "a"), + ExpectFailure[SparkException](), // task should fail but should not increment count + setFailInTask(false), + StartStream(), + CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + ) + } + + test("output partitioning is unknown") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => key + val inputData = MemoryStream[String] + val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) + testStream(result, Update)( + AddData(inputData, "a"), + CheckLastBatch("a"), + AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) + ) + } + + test("disallow complete mode") { + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { + Iterator[String]() + } + + var e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + OutputMode.Complete, GroupStateTimeout.NoTimeout)(stateFunc) + } + assert(e.getMessage === "The output mode of function should be append or update") + + val javaStateFunc = new FlatMapGroupsWithStateFunction[String, String, Int, String] { + import java.util.{Iterator => JIterator} + override def call( + key: String, + values: JIterator[String], + state: GroupState[Int]): JIterator[String] = { null } + } + e = intercept[IllegalArgumentException] { + MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState( + javaStateFunc, OutputMode.Complete, + implicitly[Encoder[Int]], implicitly[Encoder[String]], GroupStateTimeout.NoTimeout) + } + assert(e.getMessage === "The output mode of function should be append or update") + } + + def testStateUpdateWithData( + testName: String, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long = NO_TIMESTAMP, + expectedState: Option[Int] = None, + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { + + if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { + return // there can be no prior timestamp, when there is no prior state + } + test(s"StateStoreUpdater - updates with data - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + assert(state.hasTimedOut === false, "hasTimedOut not false") + assert(values.nonEmpty, "Some value is expected") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, + priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + } + } + + def testStateUpdateWithTimeout( + testName: String, + stateUpdates: GroupState[Int] => Unit, + timeoutConf: GroupStateTimeout, + priorTimeoutTimestamp: Long, + 
expectedState: Option[Int], + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { + + test(s"StateStoreUpdater - updates for timeout - $testName") { + val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + assert(state.hasTimedOut === true, "hasTimedOut not true") + assert(values.isEmpty, "values not empty") + stateUpdates(state) + Iterator.empty + } + testStateUpdate( + testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + } + } + + def testStateUpdate( + testTimeoutUpdates: Boolean, + mapGroupsFunc: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutConf: GroupStateTimeout, + priorState: Option[Int], + priorTimeoutTimestamp: Long, + expectedState: Option[Int], + expectedTimeoutTimestamp: Long): Unit = { + + val store = newStateStore() + val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( + mapGroupsFunc, timeoutConf, currentBatchTimestamp) + val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val key = intToRow(0) + // Prepare store with prior state configs + if (priorState.nonEmpty) { + val row = updater.getStateRow(priorState.get) + updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) + store.put(key.copy(), row.copy()) + } + + // Call updating function to update state store + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consumer the iterator to force state updates + + // Verify updated state in store + val updatedStateRow = store.get(key) + assert( + updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow.nonEmpty) { + assert( + updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } + } + + def newFlatMapGroupsWithStateExec( + func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int], + timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout, + batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { + MemoryStream[Int] + .toDS + .groupByKey(x => x) + .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func) + .logicalPlan.collectFirst { + case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => + FlatMapGroupsWithStateExec( + f, k, v, g, d, o, None, s, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) + }.get + } + + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutDuration(1000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutDuration("2 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutTimestamp(2000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } + assert(state.getTimeoutTimestamp 
=== prevTimestamp) + } + + def newStateStore(): StateStore = new MemoryStateStore() + + val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + def intToRow(i: Int): UnsafeRow = { + intProj.apply(new GenericInternalRow(Array[Any](i))).copy() + } + + def rowToInt(row: UnsafeRow): Int = row.getInt(0) +} + +object FlatMapGroupsWithStateSuite { + + var failInTask = true + + class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { + map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) } + } + + override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + iterator.filter { case (k, v) => c(k, v) } + } + + override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key)) + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue) + override def remove(key: UnsafeRow): Unit = { map.remove(key) } + override def remove(condition: (UnsafeRow) => Boolean): Unit = { + iterator.map(_._1).filter(condition).foreach(map.remove) + } + override def commit(): Long = version + 1 + override def abort(): Unit = { } + override def id: StateStoreId = null + override def version: Long = 0 + override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException } + override def numKeys(): Long = map.size + override def hasCommitted: Boolean = true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala deleted file mode 100644 index 6cf4d51f99333..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkException -import org.apache.spark.sql.KeyedState -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.StateStore - -/** Class to check custom state types */ -case class RunningCount(count: Long) - -class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { - - import testImplicits._ - - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - - test("KeyedState - get, exists, update, remove") { - var state: KeyedStateImpl[String] = null - - def testState( - expectedData: Option[String], - shouldBeUpdated: Boolean = false, - shouldBeRemoved: Boolean = false): Unit = { - if (expectedData.isDefined) { - assert(state.exists) - assert(state.get === expectedData.get) - } else { - assert(!state.exists) - intercept[NoSuchElementException] { - state.get - } - } - assert(state.getOption === expectedData) - assert(state.isUpdated === shouldBeUpdated) - assert(state.isRemoved === shouldBeRemoved) - } - - // Updating empty state - state = new KeyedStateImpl[String](None) - testState(None) - state.update("") - testState(Some(""), shouldBeUpdated = true) - - // Updating exiting state - state = new KeyedStateImpl[String](Some("2")) - testState(Some("2")) - state.update("3") - testState(Some("3"), shouldBeUpdated = true) - - // Removing state - state.remove() - testState(None, shouldBeRemoved = true, shouldBeUpdated = false) - state.remove() // should be still callable - state.update("4") - testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) - - // Updating by null throw exception - intercept[IllegalArgumentException] { - state.update(null) - } - } - - test("KeyedState - primitive type") { - var intState = new KeyedStateImpl[Int](None) - intercept[NoSuchElementException] { - intState.get - } - assert(intState.getOption === None) - - intState = new KeyedStateImpl[Int](Some(10)) - assert(intState.get == 10) - intState.update(0) - assert(intState.get == 0) - intState.remove() - intercept[NoSuchElementException] { - intState.get - } - } - - test("flatMapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - Iterator.empty - } else { - state.update(RunningCount(count)) - Iterator((key, count.toString)) - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a"), - CheckLastBatch(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), - assertNumStateRows(total = 1, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 
1 and - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - // Additionally, it updates state lazily as the returned iterator get consumed - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - values.flatMap { _ => - val count = state.getOption.map(_.count).getOrElse(0L) + 1 - if (count == 3) { - state.remove() - None - } else { - state.update(RunningCount(count)) - Some((key, count.toString)) - } - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) - ) - } - - test("flatMapGroupsWithState - batch") { - // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (state.exists) throw new IllegalArgumentException("state.exists should be false") - Iterator((key, values.size)) - } - checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, - Seq(("a", 2), ("b", 1)).toDF) - } - - test("mapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - (key, "-1") - } else { - state.update(RunningCount(count)) - (key, count.toString) - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a"), - CheckLastBatch(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), - assertNumStateRows(total = 1, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - test("mapGroupsWithState - streaming + aggregation") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val 
count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - (key, "-1") - } else { - state.update(RunningCount(count)) - (key, count.toString) - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - .groupByKey(_._1) - .count() - - testStream(result, Complete)( - AddData(inputData, "a"), - CheckLastBatch(("a", 1)), - AddData(inputData, "a", "b"), - // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), - // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; - // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), - // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; - // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) - ) - } - - test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (state.exists) throw new IllegalArgumentException("state.exists should be false") - (key, values.size) - } - - checkAnswer( - spark.createDataset(Seq("a", "a", "b")) - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) - .toDF, - spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) - } - - testQuietly("StateStore.abort on task failure handling") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - state.update(RunningCount(count)) - (key, count) - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - - def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => - MapGroupsWithStateSuite.failInTask = value - true - } - - testStream(result, Append)( - setFailInTask(false), - AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), - AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), - setFailInTask(true), - AddData(inputData, "a"), - ExpectFailure[SparkException](), // task should fail but should not increment count - setFailInTask(false), - StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count - ) - } -} - -object MapGroupsWithStateSuite { - var failInTask = true -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6dfcd8baba20e..1fc062974e185 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,19 +17,26 @@ package org.apache.spark.sql.streaming -import java.io.{InterruptedIOException, IOException} +import java.io.{File, InterruptedIOException, IOException} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import scala.reflect.ClassTag import scala.util.control.ControlThrowable +import org.apache.commons.io.FileUtils + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import 
org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils class StreamSuite extends StreamTest { @@ -64,6 +71,27 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -115,6 +143,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() @@ -152,6 +207,15 @@ class StreamSuite extends StreamTest { AssertOnQuery(_.offsetLog.getLatest().get._1 == expectedId, s"offsetLog's latest should be $expectedId") + // Check the latest batchid in the commit log + def CheckCommitLogLatestBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.batchCommitLog.getLatest().get._1 == expectedId, + s"commitLog's latest should be $expectedId") + + // Ensure that there has not been an incremental execution after restart + def CheckNoIncrementalExecutionCurrentBatchId(): AssertOnQuery = + AssertOnQuery(_.lastExecution == null, s"lastExecution not expected to run") + // For each batch, we would log the state change during the execution // This checks whether the key of the state change log is the expected batch id def CheckIncrementalExecutionCurrentBatchId(expectedId: Int): AssertOnQuery = @@ -177,6 +241,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 0 CheckAnswer(1, 2, 3), CheckIncrementalExecutionCurrentBatchId(0), + CheckCommitLogLatestBatchId(0), CheckOffsetLogLatestBatchId(0), CheckSinkLatestBatchId(0), // Add some data in batch 1 @@ -187,6 +252,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 1 CheckAnswer(1, 2, 3, 
4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -199,6 +265,7 @@ class StreamSuite extends StreamTest { // the currentId does not get logged (e.g. as 2) even if the clock has advanced many times CheckAnswer(1, 2, 3, 4, 5, 6), CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), @@ -206,14 +273,15 @@ class StreamSuite extends StreamTest { StopStream, StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), - /* -- batch 1 rerun ----------------- */ - // this batch 1 would re-run because the latest batch id logged in offset log is 1 + /* -- batch 1 no rerun ----------------- */ + // batch 1 would not re-run because the latest batch id logged in commit log is 1 AdvanceManualClock(10 * 1000), + CheckNoIncrementalExecutionCurrentBatchId(), /* -- batch 2 ----------------------- */ // Check the results of batch 1 CheckAnswer(1, 2, 3, 4, 5, 6), - CheckIncrementalExecutionCurrentBatchId(1), + CheckCommitLogLatestBatchId(1), CheckOffsetLogLatestBatchId(1), CheckSinkLatestBatchId(1), // Add some data in batch 2 @@ -224,6 +292,7 @@ class StreamSuite extends StreamTest { // Check the results of batch 2 CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8, 9), CheckIncrementalExecutionCurrentBatchId(2), + CheckCommitLogLatestBatchId(2), CheckOffsetLogLatestBatchId(2), CheckSinkLatestBatchId(2)) } @@ -389,6 +458,162 @@ class StreamSuite extends StreamTest { query.stop() assert(query.exception.isEmpty) } + + test("SPARK-19873: streaming aggregation with change in number of partitions") { + val inputData = MemoryStream[(Int, Int)] + val agg = inputData.toDS().groupBy("_1").count() + + testStream(agg, OutputMode.Complete())( + AddData(inputData, (1, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")), + CheckAnswer((1, 1), (2, 1)), + StopStream, + AddData(inputData, (3, 0), (2, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")), + CheckAnswer((1, 1), (2, 2), (3, 1)), + StopStream, + AddData(inputData, (3, 0), (1, 0)), + StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")), + CheckAnswer((1, 2), (2, 2), (3, 2))) + } + + testQuietly("recover from a Spark v2.1 checkpoint") { + var inputData: MemoryStream[Int] = null + var query: DataStreamWriter[Row] = null + + def prepareMemoryStream(): Unit = { + inputData = MemoryStream[Int] + inputData.addData(1, 2, 3, 4) + inputData.addData(3, 4, 5, 6) + inputData.addData(5, 6, 7, 8) + + query = inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .writeStream + .outputMode("complete") + .format("memory") + } + + // Get an existing checkpoint generated by Spark v2.1. + // v2.1 does not record # shuffle partitions in the offset metadata. + val resourceUri = + this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI + val checkpointDir = new File(resourceUri) + + // 1 - Test if recovery from the checkpoint is successful. + prepareMemoryStream() + val dir1 = Utils.createTempDir().getCanonicalFile // not using withTempDir {}, makes test flaky + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(checkpointDir, dir1) + // Checkpoint data was generated by a query with 10 shuffle partitions. 
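// A minimal, self-contained sketch (not from this patch) of the restart-from-checkpoint
// pattern the surrounding recovery test depends on. The "rate" source, the paths and the
// object name are assumptions made only for illustration; the point is that a query
// restarted with the same checkpointLocation resumes from the offsets (and, per the test
// here, the shuffle-partition setting) recorded in that checkpoint.
import org.apache.spark.sql.SparkSession

object CheckpointRestartSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Any streaming aggregation would do; here, a running count per bucket.
    val counts = spark.readStream.format("rate").load()
      .groupBy(($"value" % 10).as("bucket"))
      .count()

    val query = counts.writeStream
      .outputMode("complete")
      .format("memory")
      .queryName("counts")
      .option("checkpointLocation", "/tmp/sketch-checkpoint") // reused on every restart
      .start()

    query.processAllAvailable()
    query.stop()
    spark.stop()
  }
}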
+ // In order to test reading from the checkpoint, the checkpoint must have two or more batches, + // since the last batch may be rerun. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + var streamingQuery: StreamingQuery = null + try { + streamingQuery = + query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start() + streamingQuery.processAllAvailable() + inputData.addData(9) + streamingQuery.processAllAvailable() + + QueryTest.checkAnswer(spark.table("counts").toDF(), + Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) :: + Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil) + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + + // 2 - Check recovery with wrong num shuffle partitions + prepareMemoryStream() + val dir2 = Utils.createTempDir().getCanonicalFile + FileUtils.copyDirectory(checkpointDir, dir2) + // Since the number of partitions is greater than 10, should throw exception. + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") { + var streamingQuery: StreamingQuery = null + try { + intercept[StreamingQueryException] { + streamingQuery = + query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start() + streamingQuery.processAllAvailable() + } + } finally { + if (streamingQuery ne null) { + streamingQuery.stop() + } + } + } + } + + test("calling stop() on a query cancels related jobs") { + val input = MemoryStream[Int] + val query = input + .toDS() + .map { i => + while (!org.apache.spark.TaskContext.get().isInterrupted()) { + // keep looping till interrupted by query.stop() + Thread.sleep(100) + } + i + } + .writeStream + .format("console") + .start() + + input.addData(1) + // wait for jobs to start + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty) + } + + query.stop() + // make sure jobs are stopped + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().isEmpty) + } + } + + test("batch id is updated correctly in the job description") { + val queryName = "memStream" + @volatile var jobDescription: String = null + def assertDescContainsQueryNameAnd(batch: Integer): Unit = { + // wait for listener event to be processed + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch")) + } + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION) + } + }) + + val input = MemoryStream[Int] + val query = input + .toDS() + .map(_ + 1) + .writeStream + .format("memory") + .queryName(queryName) + .start() + + input.addData(1) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 0) + input.addData(2, 3) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 1) + input.addData(4) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 2) + query.stop() + } } abstract class FakeSource extends StreamSourceProvider { @@ -458,7 +683,7 @@ class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { object ThrowingIOExceptionLikeHadoop12074 { /** - * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is * called. 
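// A minimal, self-contained sketch (not from this patch) of the latch pattern described in
// the comment above: a test thread blocks on await() until the code under test reaches the
// instrumented point and calls countDown(), making the ordering deterministic. The names
// here are made up for the example.
import java.util.concurrent.{CountDownLatch, TimeUnit}

object LatchPatternSketch {
  @volatile var createSourceLatch: CountDownLatch = null

  def main(args: Array[String]): Unit = {
    createSourceLatch = new CountDownLatch(1)
    new Thread(new Runnable {
      // stands in for the code path that eventually calls createSource()
      override def run(): Unit = {
        Thread.sleep(50)
        createSourceLatch.countDown()
      }
    }).start()
    // the "test" side waits until that point has been reached, or gives up after 10 seconds
    assert(createSourceLatch.await(10, TimeUnit.SECONDS))
  }
}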
*/ @volatile var createSourceLatch: CountDownLatch = null @@ -489,7 +714,7 @@ class ThrowingInterruptedIOException extends FakeSource { object ThrowingInterruptedIOException { /** - * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is * called. */ @volatile var createSourceLatch: CountDownLatch = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 60e2375a9817d..5bc36dd30f6d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -159,7 +159,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = ProcessingTime(0), + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty) extends StreamAction @@ -214,24 +214,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { AssertOnQuery(query => { func(query); true }) } - class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { - private var waitStartTime: Option[Long] = None - - override def waitTillTime(targetTime: Long): Long = synchronized { - try { - waitStartTime = Some(getTimeMillis()) - super.waitTillTime(targetTime) - } finally { - waitStartTime = None - } - } - - def isStreamWaitingAt(time: Long): Boolean = synchronized { - waitStartTime == Some(time) - } - } - - /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. @@ -242,6 +224,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { + import org.apache.spark.sql.streaming.util.StreamManualClock + // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently // because this method assumes there is only one active query in its `StreamingQueryListener` // and it may not work correctly when multiple `testStream`s run concurrently. @@ -293,6 +277,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { def threadState = if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead" + def threadStackTrace = if (currentStream != null && currentStream.microBatchThread.isAlive) { + s"Thread stack trace: ${currentStream.microBatchThread.getStackTrace.mkString("\n")}" + } else { + "" + } def testState = s""" @@ -303,6 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState + |$threadStackTrace |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == @@ -488,8 +478,27 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AddData => try { - // Add data and get the source where it was added, and the expected offset of the - // added data. 
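// A minimal, self-contained sketch (not from this patch) of the synchronization idea behind
// StreamManualClock and the AddData handling below: the stream thread blocks in waitTillTime
// until the test advances the clock, and the test only adds data or advances once it can see
// that the stream thread is really waiting. The toy clock here is an assumption for the
// example, not Spark's ManualClock.
object ManualClockSketch {
  class ToyManualClock(private var now: Long = 0L) {
    private var waitingAt: Option[Long] = None

    def getTimeMillis(): Long = synchronized { now }

    def advance(ms: Long): Unit = synchronized { now += ms; notifyAll() }

    // Is some thread currently blocked in waitTillTime, having started waiting at `time`?
    def isWaitingAt(time: Long): Boolean = synchronized { waitingAt.contains(time) }

    def waitTillTime(target: Long): Long = synchronized {
      try {
        waitingAt = Some(now)
        while (now < target) { wait() }
        now
      } finally {
        waitingAt = None
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val clock = new ToyManualClock()
    val stream = new Thread(new Runnable {
      override def run(): Unit = clock.waitTillTime(100) // stands in for a blocked trigger
    })
    stream.start()
    // Only advance once the "stream" thread is definitely waiting on the clock.
    while (!clock.isWaitingAt(0)) { Thread.sleep(10) }
    clock.advance(100)
    stream.join()
  }
}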
+ + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") + } + } + // Add data val queryToUse = Option(currentStream).orElse(Option(lastStream)) val (source, offset) = a.addData(queryToUse) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0c8015672bab4..4345a70601c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.scalatest.BeforeAndAfterAll @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.sql.streaming.util.StreamManualClock object FailureSinglton { var firstTime = true @@ -68,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] @@ -104,7 +121,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte testStream(aggregated, Append)() } Seq("append", "not supported").foreach { m => - assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) } } @@ -272,11 +289,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by a minute i.e., 90 seconds total clock.advance(60 * 1000L) true }, StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => clock.getTimeMillis() == 90000L @@ -322,11 +341,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte StopStream, AssertOnQuery { q => // clear the sink 
q.sink.asInstanceOf[MemorySink].clear() + q.batchCommitLog.purge(3) // advance by 60 days i.e., 90 days total clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, StartStream(ProcessingTime("10 day"), triggerClock = clock), + // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)), // advance clock to 100 days, should retain keys >= 90 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index eb09b9ffcfc5d..59c6a6fade175 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter @@ -35,6 +36,7 @@ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { @@ -57,6 +59,20 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val inputData = new MemoryStream[Int](0, sqlContext) val df = inputData.toDS().as[Long].map { 10 / _ } val listener = new EventCollector + + case class AssertStreamExecThreadToWaitForClock() + extends AssertOnQuery(q => { + eventually(Timeout(streamingTimeout)) { + if (q.exception.isEmpty) { + assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + }, "") + try { // No events until started spark.streams.addListener(listener) @@ -81,6 +97,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Progress event generated when data processed AddData(inputData, 1, 2), AdvanceManualClock(100), + AssertStreamExecThreadToWaitForClock(), CheckAnswer(10, 5), AssertOnQuery { query => assert(listener.progressEvents.nonEmpty) @@ -109,8 +126,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { // Termination event generated with exception message when stopped with error StartStream(ProcessingTime(100), triggerClock = clock), + AssertStreamExecThreadToWaitForClock(), AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(100), // process bad data ExpectFailure[SparkException](), AssertOnQuery { query => eventually(Timeout(streamingTimeout)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index f05e9d1fda73f..b49efa6890236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -239,16 +239,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { } } - test("SPARK-19268: Adaptive query execution should be disallowed") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val e = 
intercept[AnalysisException] { - MemoryStream[Int].toDS.writeStream.queryName("test-query").format("memory").start() - } - assert(e.getMessage.contains(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key) && - e.getMessage.contains("not supported")) - } - } - /** Run a body of code by defining a query on each dataset */ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index a0a2b2b4c9b3b..b69536ed37463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.util.ManualClock @@ -158,52 +158,102 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi ) } + testQuietly("OneTime trigger, commit log, and exception") { + import Trigger.Once + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + StopStream, + AddData(inputData, 1, 2), + StartStream(trigger = Once), + CheckAnswer(6, 3), + StopStream, // clears out StreamTest state + AssertOnQuery { q => + // both commit log and offset log contain the same (latest) batch id + q.batchCommitLog.getLatest().map(_._1).getOrElse(-1L) == + q.offsetLog.getLatest().map(_._1).getOrElse(-2L) + }, + AssertOnQuery { q => + // blow away commit log and sink result + q.batchCommitLog.purge(1) + q.sink.asInstanceOf[MemorySink].clear() + true + }, + StartStream(trigger = Once), + CheckAnswer(6, 3), // ensure we fall back to offset log and reprocess batch + StopStream, + AddData(inputData, 3), + StartStream(trigger = Once), + CheckLastBatch(2), // commit log should be back in place + StopStream, + AddData(inputData, 0), + StartStream(trigger = Once), + ExpectFailure[SparkException](), + AssertOnQuery(_.isActive === false), + AssertOnQuery(q => { + q.exception.get.startOffset === + q.committedOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString && + q.exception.get.endOffset === + q.availableOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString + }, "incorrect start offset or end offset on exception") + ) + } + testQuietly("status, lastProgress, and recentProgress") { import StreamingQuerySuite._ clock = new StreamManualClock /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // Wait for manual clock to be 100 first time there is data + // getOffset should take 50 ms the first time it is called override def getOffset: Option[Offset] = { val offset = super.getOffset if (offset.nonEmpty) { - clock.waitTillTime(300) + clock.waitTillTime(1050) } offset } - // Wait for manual clock to be 300 first time there is data + // getBatch should take 100 ms the first time it is called override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - clock.waitTillTime(600) + if (start.isEmpty) 
clock.waitTillTime(1150) super.getBatch(start, end) } } - // This is to make sure thatquery waits for manual clock to be 600 first time there is data - val mapped = inputData.toDS().as[Long].map { x => - clock.waitTillTime(1100) + // query execution should take 350 ms the first time it is called + val mapped = inputData.toDS.coalesce(1).as[Long].map { x => + clock.waitTillTime(1500) // this will only wait the first time when clock < 1500 10 / x }.agg(count("*")).as[Long] - case class AssertStreamExecThreadToWaitForClock() + case class AssertStreamExecThreadIsWaitingForTime(targetTime: Long) extends AssertOnQuery(q => { eventually(Timeout(streamingTimeout)) { if (q.exception.isEmpty) { - assert(clock.asInstanceOf[StreamManualClock].isStreamWaitingAt(clock.getTimeMillis)) + assert(clock.isStreamWaitingFor(targetTime)) } } if (q.exception.isDefined) { throw q.exception.get } true - }, "") + }, "") { + override def toString: String = s"AssertStreamExecThreadIsWaitingForTime($targetTime)" + } + + case class AssertClockTime(time: Long) + extends AssertOnQuery(q => clock.getTimeMillis() === time, "") { + override def toString: String = s"AssertClockTime($time)" + } var lastProgressBeforeStop: StreamingQueryProgress = null testStream(mapped, OutputMode.Complete)( - StartStream(ProcessingTime(100), triggerClock = clock), - AssertStreamExecThreadToWaitForClock(), + StartStream(ProcessingTime(1000), triggerClock = clock), + AssertStreamExecThreadIsWaitingForTime(1000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -211,32 +261,37 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // Test status and progress while offset is being fetched AddData(inputData, 1, 2), - AdvanceManualClock(100), // time = 100 to start new trigger, will block on getOffset - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being fetched - AdvanceManualClock(200), // time = 300 to unblock getOffset, will block on getBatch - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(50), // time = 1050 to unblock getOffset + AssertClockTime(1050), + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being processed - AdvanceManualClock(300), // time = 600 to unblock getBatch, will block in Spark job + AdvanceManualClock(100), // time = 1150 to unblock getBatch + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - 
AdvanceManualClock(500), // time = 1100 to unblock job - AssertOnQuery { _ => clock.getTimeMillis() === 1100 }, + AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, + AdvanceManualClock(350), // time = 1500 to unblock job + AssertClockTime(1500), CheckAnswer(2), + AssertStreamExecThreadIsWaitingForTime(2000), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -249,21 +304,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.id === query.id) assert(progress.name === query.name) assert(progress.batchId === 0) - assert(progress.timestamp === "1970-01-01T00:00:00.100Z") // 100 ms in UTC + assert(progress.timestamp === "1970-01-01T00:00:01.000Z") // 100 ms in UTC assert(progress.numInputRows === 2) - assert(progress.processedRowsPerSecond === 2.0) + assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 200) - assert(progress.durationMs.get("getBatch") === 300) + assert(progress.durationMs.get("getOffset") === 50) + assert(progress.durationMs.get("getBatch") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("triggerExecution") === 1000) + assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") assert(progress.sources(0).startOffset === null) assert(progress.sources(0).endOffset !== null) - assert(progress.sources(0).processedRowsPerSecond === 2.0) + assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) assert(progress.stateOperators(0).numRowsUpdated === 1) @@ -273,8 +328,12 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi true }, + // Test whether input rate is updated after two batches + AssertStreamExecThreadIsWaitingForTime(2000), // blocked waiting for next trigger time AddData(inputData, 1, 2), - AdvanceManualClock(100), // allow another trigger + AdvanceManualClock(500), // allow another trigger + AssertClockTime(2000), + AssertStreamExecThreadIsWaitingForTime(3000), // will block waiting for next trigger time CheckAnswer(4), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), @@ -282,13 +341,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery { query => assert(query.recentProgress.last.eq(query.lastProgress)) assert(query.lastProgress.batchId === 1) - assert(query.lastProgress.sources(0).inputRowsPerSecond === 1.818) + assert(query.lastProgress.inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) true }, // Test status and progress after data is not available for a trigger - AdvanceManualClock(100), // allow another trigger - AssertStreamExecThreadToWaitForClock(), + AdvanceManualClock(1000), // allow another trigger + AssertStreamExecThreadIsWaitingForTime(4000), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -305,9 +365,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Stopped"), // Test status and progress after query terminated 
with error - StartStream(ProcessingTime(100), triggerClock = clock), + StartStream(ProcessingTime(1000), triggerClock = clock), + AdvanceManualClock(1000), // ensure initial trigger completes before AddData AddData(inputData, 0), - AdvanceManualClock(100), + AdvanceManualClock(1000), // allow another trigger ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), @@ -441,7 +502,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("StreamingQuery should be Serializable but cannot be used in executors") { + testQuietly("StreamingQuery should be Serializable but cannot be used in executors") { def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = { ds.writeStream .queryName(queryName) @@ -581,8 +642,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ case class TestAwaitTermination( @@ -606,8 +669,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi * * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * When timeoutMs is less than or equal to 0, awaitTermination() is + * tested (i.e. 
w/o timeout) + * When timeoutMs is greater than 0, awaitTermination(timeoutMs) is + * tested * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used */ def assertOnQueryCondition( @@ -632,5 +697,5 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi object StreamingQuerySuite { // Singleton reference to clock that does not get serialized in task closures - var clock: ManualClock = null + var clock: StreamManualClock = null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 0470411a0f108..dc2506a48ad00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -18,20 +18,22 @@ package org.apache.spark.sql.streaming.test import java.io.File +import java.util.Locale import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.scalatest.PrivateMethodTester.PrivateMethod +import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.{ProcessingTime => DeprecatedProcessingTime, _} +import org.apache.spark.sql.streaming.Trigger._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -107,7 +109,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester { +class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -125,7 +127,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr .save() } Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -346,7 +348,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .trigger(ProcessingTime(100, TimeUnit.SECONDS)) .start() q.stop() @@ -371,61 +373,26 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr .option("checkpointLocation", checkpointLocationURI.toString) .trigger(ProcessingTime(10.seconds)) .start() + q.processAllAvailable() q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/0", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/0"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - s"$checkpointLocationURI/sources/1", - 
None, - "org.apache.spark.sql.streaming.test", - Map.empty) + any(), + meq(s"$checkpointLocationURI/sources/1"), + meq(None), + meq("org.apache.spark.sql.streaming.test"), + meq(Map.empty)) } private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - test("supported strings in outputMode(string)") { - val outputModeMethod = PrivateMethod[OutputMode]('outputMode) - - def testMode(outputMode: String, expected: OutputMode): Unit = { - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - w.outputMode(outputMode) - val setOutputMode = w invokePrivate outputModeMethod() - assert(setOutputMode === expected) - } - - testMode("append", OutputMode.Append) - testMode("Append", OutputMode.Append) - testMode("complete", OutputMode.Complete) - testMode("Complete", OutputMode.Complete) - testMode("update", OutputMode.Update) - testMode("Update", OutputMode.Update) - } - - test("unsupported strings in outputMode(string)") { - def testMode(outputMode: String): Unit = { - val acceptedModes = Seq("append", "update", "complete") - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - val w = df.writeStream - val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) - (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) - } - } - testMode("Xyz") - } - test("check foreach() catches null writers") { val df = spark.readStream .format("org.apache.spark.sql.streaming.test") @@ -434,7 +401,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr var w = df.writeStream var e = intercept[IllegalArgumentException](w.foreach(null)) Seq("foreach", "null").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -451,7 +418,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr var w = df.writeStream.partitionBy("value") var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) Seq("foreach", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala new file mode 100644 index 0000000000000..c769a790a4168 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StreamManualClock.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.util.ManualClock + +/** + * ManualClock used for streaming tests that allows checking whether the stream is waiting + * on the clock at expected times. + */ +class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { + private var waitStartTime: Option[Long] = None + private var waitTargetTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + waitTargetTime = Some(targetTime) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + waitTargetTime = None + } + } + + /** Is the streaming thread waiting for the clock to advance when it is at the given time */ + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + + /** Is the streaming thread waiting for clock to advance to the given time */ + def isStreamWaitingFor(target: Long): Boolean = synchronized { + waitTargetTime == Some(target) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8a8ba05534529..fb15e7def6dbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.test import java.io.File +import java.util.Locale +import java.util.concurrent.ConcurrentLinkedQueue import org.scalatest.BeforeAndAfter +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ @@ -41,7 +45,6 @@ object LastOptions { } } - /** Dummy provider. 
*/ class DefaultSource extends RelationProvider @@ -107,6 +110,20 @@ class DefaultSourceWithoutUserSpecifiedSchema } } +object MessageCapturingCommitProtocol { + val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]() +} + +class MessageCapturingCommitProtocol(jobId: String, path: String) + extends HadoopMapReduceCommitProtocol(jobId, path) { + + // captures commit messages for testing + override def onTaskCommit(msg: TaskCommitMessage): Unit = { + MessageCapturingCommitProtocol.commitMessages.offer(msg) + } +} + + class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -128,7 +145,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be .start() } Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -260,13 +277,13 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be var w = df.write.partitionBy("value") var e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "partitioning").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) Seq("jdbc", "bucketing").foreach { s => - assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } @@ -291,6 +308,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) } + test("write path implements onTaskCommit API correctly") { + withSQLConf( + "spark.sql.sources.commitProtocolClass" -> + classOf[MessageCapturingCommitProtocol].getCanonicalName) { + withTempDir { dir => + val path = dir.getCanonicalPath + MessageCapturingCommitProtocol.commitMessages.clear() + spark.range(10).repartition(10).write.mode("overwrite").parquet(path) + assert(MessageCapturingCommitProtocol.commitMessages.size() == 10) + } + } + } + test("read a data source that does not extend SchemaRelationProvider") { val dfReader = spark.read .option("from", "1") @@ -356,7 +386,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be // Reader, with user specified schema, should just apply user schema on the file data val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() } - assert(e.getMessage.toLowerCase.contains("user specified schema not supported")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains( + "user specified schema not supported")) intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) } intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) } @@ -370,9 +401,11 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be val schema = df.schema // Reader, without user specified schema - intercept[IllegalArgumentException] { + val message = intercept[AnalysisException] { testRead(spark.read.csv(), Seq.empty, schema) - } + }.getMessage + assert(message.contains("Unable to infer schema for CSV. 
It must be specified manually.")) + testRead(spark.read.csv(dir), data, schema) testRead(spark.read.csv(dir, dir), data ++ data, schema) testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index d4afb9d8af6f8..f6d47734d7e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -18,14 +18,17 @@ package org.apache.spark.sql.test import java.io.File -import java.util.UUID +import java.net.URI +import java.nio.file.Files +import java.util.{Locale, UUID} +import scala.concurrent.duration._ import scala.language.implicitConversions -import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -40,15 +43,15 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} /** * Helper trait that should be extended by all SQL test suites. * - * This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data + * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. - * To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]]. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. * - * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils - extends SparkFunSuite + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData { self => @@ -64,7 +67,7 @@ private[sql] trait SQLTestUtils * A helper object for importing SQL implicits. * * Note that the alternative of importing `spark.implicits._` is not possible here. - * This is because we create the [[SQLContext]] immediately before the first test is run, + * This is because we create the `SQLContext` immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { @@ -72,7 +75,7 @@ private[sql] trait SQLTestUtils } /** - * Materialize the test data immediately after the [[SQLContext]] is set up. + * Materialize the test data immediately after the `SQLContext` is set up. * This is necessary if the data is accessed by name but not through direct reference. */ protected def setupTestData(): Unit = { @@ -122,6 +125,30 @@ private[sql] trait SQLTestUtils try f(path) finally Utils.deleteRecursively(path) } + /** + * Copy file in jar's resource to a temp file, then pass it to `f`. + * This function is used to make `f` can use the path of temp file(e.g. 
file:/), instead of + * path of jar's resource which starts with 'jar:file:/' + */ + protected def withResourceTempPath(resourcePath: String)(f: File => Unit): Unit = { + val inputStream = + Thread.currentThread().getContextClassLoader.getResourceAsStream(resourcePath) + withTempDir { dir => + val tmpFile = new File(dir, "tmp") + Files.copy(inputStream, tmpFile.toPath) + f(tmpFile) + } + } + + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -130,7 +157,11 @@ private[sql] trait SQLTestUtils */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } } /** @@ -206,12 +237,39 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } } + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name") + } + spark.sql(s"USE $DEFAULT_DATABASE") + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. @@ -234,8 +292,8 @@ private[sql] trait SQLTestUtils } /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier - * way to construct [[DataFrame]] directly out of local data without relying on implicits. + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { Dataset.ofRows(spark, plan) @@ -255,7 +313,9 @@ private[sql] trait SQLTestUtils } } - /** Run a test on a separate [[UninterruptibleThread]]. */ + /** + * Run a test on a separate `UninterruptibleThread`. + */ protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) (body: => Unit): Unit = { val timeoutMillis = 10000 @@ -294,6 +354,17 @@ private[sql] trait SQLTestUtils test(name) { runOnThread() } } } + + /** + * This method is used to make the given path qualified, when a path + * does not contain a scheme, this path will not be changed after the default + * FileSystem is changed. 
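// A minimal, self-contained sketch (not from this patch) of why this patch switches the
// assertions from toLowerCase to toLowerCase(Locale.ROOT), and what the withLocale helper
// above is useful for: under a Turkish default locale, 'I' lowercases to the dotless 'ı',
// so locale-sensitive lowercasing can make substring checks fail.
import java.util.Locale

object LocaleRootSketch {
  def main(args: Array[String]): Unit = {
    val original = Locale.getDefault
    try {
      Locale.setDefault(new Locale("tr"))
      println("TITLE".toLowerCase)              // "tıtle" under tr: 'I' -> 'ı'
      println("TITLE".toLowerCase(Locale.ROOT)) // "title" regardless of the default locale
    } finally {
      Locale.setDefault(original)
    }
  }
}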
+ */ + def makeQualifiedPath(path: String): URI = { + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) + fs.makeQualified(hadoopPath).toUri + } } private[sql] object SQLTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e122b39f6fc40..81c69a338abcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.test +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - protected val sparkConf = new SparkConf() + protected def sparkConf = { + new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + } /** * The [[TestSparkSession]] to use for all tests in this suite. @@ -50,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { - new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + new TestSparkSession(sparkConf) } /** @@ -84,6 +86,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterEach(): Unit = { super.afterEach() - DebugFilesystem.assertNoOpenStreams() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 8ab6db175da5d..959edf9a49371 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf, WithTestConf} /** - * A special [[SparkSession]] prepared for testing. + * A special `SparkSession` prepared for testing. 
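Both the `waitForTasksToFinish` helper and the `afterEach` change in `SharedSQLContext` above lean on ScalaTest's `Eventually` to retry an assertion until background work settles. A self-contained sketch of that retry pattern follows; it is illustrative only, and the object name, thread, and counter are made up for the example.

import java.util.concurrent.atomic.AtomicInteger
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

object EventuallyDemo {
  def main(args: Array[String]): Unit = {
    val runningTasks = new AtomicInteger(1) // made-up stand-in for in-flight work
    new Thread(new Runnable {
      override def run(): Unit = { Thread.sleep(500); runningTasks.set(0) }
    }).start()
    // Re-evaluates the block until it completes without throwing, or 10 seconds pass.
    eventually(timeout(10.seconds)) {
      assert(runningTasks.get() == 0)
    }
    println("all background work finished")
  }
}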
*/ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { @@ -35,17 +35,8 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } @transient - override lazy val sessionState: SessionState = new SessionState(self) { - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def clear(): Unit = { - super.clear() - // Make sure we start with the default test configs even after clear - TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) } - } - } - } + override lazy val sessionState: SessionState = { + new TestSQLSessionStateBuilder(this, None).build() } // Needed for Java tests @@ -69,3 +60,11 @@ private[sql] object TestSQLContext { // Fewer shuffle partitions to speed up testing. SQLConf.SHUFFLE_PARTITIONS.key -> "5") } + +private[sql] class TestSQLSessionStateBuilder( + session: SparkSession, + state: Option[SessionState]) + extends SessionStateBuilder(session, state) with WithTestConf { + override def overrideConfs: Map[String, String] = TestSQLContext.overrideConfs + override def newBuilder: NewBuilder = new TestSQLSessionStateBuilder(_, _) +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0d..a5a8e2640586c 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java index b95077cd62186..0d0e3e4011b5b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java @@ -49,7 +49,7 @@ enum STATE { * The transition must be from {@link STATE#NOTINITED} to {@link STATE#INITED} unless the * operation failed and an exception was raised. * - * @param config + * @param conf * the configuration of the service */ void init(HiveConf conf); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java index a2c580d6acc71..c3219aabfc23b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceOperations.java @@ -51,7 +51,7 @@ public static void ensureCurrentState(Service.STATE state, /** * Initialize a service. - *
<p/>
      + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -69,7 +69,7 @@ public static void init(Service service, HiveConf configuration) { /** * Start a service. - *
<p/>
      + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -86,7 +86,7 @@ public static void start(Service service) { /** * Initialize then start a service. - *
<p/>
      + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service that must be in the state @@ -102,9 +102,9 @@ public static void deploy(Service service, HiveConf configuration) { /** * Stop a service. - *
<p/>
- *
      Do nothing if the service is null or not - * in a state in which it can be/needs to be stopped. - *
<p/>
      + * + * Do nothing if the service is null or not in a state in which it can be/needs to be stopped. + * * The service state is checked before the operation begins. * This process is not thread safe. * @param service a service or null diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index 1e6ac4f3df475..c5ade65283045 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import javax.net.ssl.SSLServerSocket; @@ -259,12 +260,12 @@ public static TServerSocket getServerSSLSocket(String hiveHost, int portNum, Str if (thriftServerSocket.getServerSocket() instanceof SSLServerSocket) { List sslVersionBlacklistLocal = new ArrayList(); for (String sslVersion : sslVersionBlacklist) { - sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase()); + sslVersionBlacklistLocal.add(sslVersion.trim().toLowerCase(Locale.ROOT)); } SSLServerSocket sslServerSocket = (SSLServerSocket) thriftServerSocket.getServerSocket(); List enabledProtocols = new ArrayList(); for (String protocol : sslServerSocket.getEnabledProtocols()) { - if (sslVersionBlacklistLocal.contains(protocol.toLowerCase())) { + if (sslVersionBlacklistLocal.contains(protocol.toLowerCase(Locale.ROOT))) { LOG.debug("Disabling SSL Protocol: " + protocol); } else { enabledProtocols.add(protocol); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 5021528299682..f7375ee707830 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -89,7 +89,7 @@ public static String getKerberosServiceTicket(String principal, String host, * @param clientUserName Client User name. * @return An unsigned cookie token generated from input parameters. * The final cookie generated is of the following format : - * cu=&rn=&s= + * {@code cu=&rn=&s=} */ public static String createCookieToken(String clientUserName) { StringBuffer sb = new StringBuffer(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java index e2a6de165adc5..1af1c1d06e7f7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/PasswdAuthenticationProvider.java @@ -26,7 +26,7 @@ public interface PasswdAuthenticationProvider { * to authenticate users for their requests. * If a user is to be granted, return nothing/throw nothing. * When a user is to be disallowed, throw an appropriate {@link AuthenticationException}. - *
<p/>
      + * * For an example implementation, see {@link LdapAuthenticationProviderImpl}. * * @param user The username received over the connection request diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java index ab3ac6285aa02..ad4dfd75f4707 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java @@ -19,6 +19,7 @@ package org.apache.hive.service.auth; import java.util.HashMap; +import java.util.Locale; import java.util.Map; /** @@ -52,7 +53,7 @@ public String toString() { public static SaslQOP fromString(String str) { if (str != null) { - str = str.toLowerCase(); + str = str.toLowerCase(Locale.ROOT); } SaslQOP saslQOP = STR_TO_ENUM.get(str); if (saslQOP == null) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java index 645e3e2bbd4e2..9a61ad49942c8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/TSetIpAddressProcessor.java @@ -31,12 +31,9 @@ /** * This class is responsible for setting the ipAddress for operations executed via HiveServer2. - *
<p/>
- * <ul>
- * <li>IP address is only set for operations that calls listeners with hookContext</li>
- * <li>IP address is only set if the underlying transport mechanism is socket</li>
- * </ul>
- * <p/>
      + * + * - IP address is only set for operations that calls listeners with hookContext + * - IP address is only set if the underlying transport mechanism is socket * * @see org.apache.hadoop.hive.ql.hooks.ExecuteWithHookContext */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java index 9d64b102e008d..bf2380632fa6c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIServiceUtils.java @@ -38,7 +38,7 @@ public class CLIServiceUtils { * Convert a SQL search pattern into an equivalent Java Regex. * * @param pattern input which may contain '%' or '_' wildcard characters, or - * these characters escaped using {@link #getSearchStringEscape()}. + * these characters escaped using {@code getSearchStringEscape()}. * @return replace %/_ with regex search characters, also handle escaped * characters. */ diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java index a96d2ac371cd3..7752ec03a29b7 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/Type.java @@ -19,6 +19,7 @@ package org.apache.hive.service.cli; import java.sql.DatabaseMetaData; +import java.util.Locale; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hive.service.cli.thrift.TTypeId; @@ -160,7 +161,7 @@ public static Type getType(String name) { if (name.equalsIgnoreCase(type.name)) { return type; } else if (type.isQualifiedType() || type.isComplexType()) { - if (name.toUpperCase().startsWith(type.name)) { + if (name.toUpperCase(Locale.ROOT).startsWith(type.name)) { return type; } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 05a6bf938404b..af36057bdaeca 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -28,9 +28,9 @@ /** * ClassicTableTypeMapping. 
* Classic table type mapping : - * Managed Table ==> Table - * External Table ==> Table - * Virtual View ==> View + * Managed Table to Table + * External Table to Table + * Virtual View to View */ public class ClassicTableTypeMapping implements TableTypeMapping { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java index e392c459cf586..e59d19ea6be42 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java @@ -31,7 +31,7 @@ public interface TableTypeMapping { /** * Map hive's table type name to client's table type - * @param clientTypeName + * @param hiveTypeName * @return */ String mapToClientType(String hiveTypeName); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index de066dd406c7a..c1b3892f52060 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -224,7 +224,9 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str * The username passed to this method is the effective username. * If withImpersonation is true (==doAs true) we wrap all the calls in HiveSession * within a UGI.doAs, where UGI corresponds to the effective user. - * @see org.apache.hive.service.cli.thrift.ThriftCLIService#getUserName() + * + * Please see {@code org.apache.hive.service.cli.thrift.ThriftCLIService.getUserName()} for + * more details. * * @param protocol * @param username diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java index fb8141a905acb..94f8126552e9d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/ThreadFactoryWithGarbageCleanup.java @@ -30,12 +30,12 @@ * in custom cleanup code to be called before this thread is GC-ed. * Currently cleans up the following: * 1. ThreadLocal RawStore object: - * In case of an embedded metastore, HiveServer2 threads (foreground & background) + * In case of an embedded metastore, HiveServer2 threads (foreground and background) * end up caching a ThreadLocal RawStore object. The ThreadLocal RawStore object has - * an instance of PersistenceManagerFactory & PersistenceManager. + * an instance of PersistenceManagerFactory and PersistenceManager. * The PersistenceManagerFactory keeps a cache of PersistenceManager objects, * which are only removed when PersistenceManager#close method is called. - * HiveServer2 uses ExecutorService for managing thread pools for foreground & background threads. + * HiveServer2 uses ExecutorService for managing thread pools for foreground and background threads. * ExecutorService unfortunately does not provide any hooks to be called, * when a thread from the pool is terminated. * As a solution, we're using this ThreadFactory to keep a cache of RawStore objects per thread. 
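The recurring change in these files from `toLowerCase()` / `toUpperCase()` to the `Locale.ROOT` overloads guards against locale-sensitive case mapping. A small, self-contained Scala sketch of the failure mode it avoids (illustrative only, not part of the patch; the object and strings are made up):

import java.util.Locale

object LocaleRootDemo {
  def main(args: Array[String]): Unit = {
    val saved = Locale.getDefault
    try {
      // Under a Turkish default locale, uppercase 'I' lowercases to the dotless 'ı' (U+0131),
      // so a locale-dependent comparison against an ASCII keyword quietly stops matching.
      Locale.setDefault(new Locale("tr", "TR"))
      println("STRING".toLowerCase)              // prints "strıng" (dotless i), not "string"
      println("STRING".toLowerCase(Locale.ROOT)) // always prints "string", independent of the JVM locale
    } finally {
      Locale.setDefault(saved)
    }
  }
}

The same concern motivates the `Locale.ROOT` arguments added in `Type.java`, `SaslQOP.java`, and `HiveAuthFactory.java` above.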
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 13c6f11f461c6..5e4734ad3ad25 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -46,7 +46,7 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} */ object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) - var uiTab: Option[ThriftServerTab] = _ + var uiTab: Option[ThriftServerTab] = None var listener: HiveThriftServer2Listener = _ /** @@ -294,7 +294,7 @@ private[hive] class HiveThriftServer2(sqlContext: SQLContext) private def isHTTPTransportMode(hiveConf: HiveConf): Boolean = { val transportMode = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE) - transportMode.toLowerCase(Locale.ENGLISH).equals("http") + transportMode.toLowerCase(Locale.ROOT).equals("http") } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 517b01f183926..ff3784cab9e26 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -292,7 +292,7 @@ object SparkExecuteStatementOperation { def getTableSchema(structType: StructType): TableSchema = { val schema = structType.map { field => val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString - new FieldSchema(field.name, attrTypeString, "") + new FieldSchema(field.name, attrTypeString, field.getComment.getOrElse("")) } new TableSchema(schema.asJava) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 0c79b6f4211ff..33e18a8da60fb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,12 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.util.ShutdownHookManager /** @@ -46,8 +47,8 @@ import org.apache.spark.util.ShutdownHookManager * has dropped its support. 
*/ private[hive] object SparkSQLCLIDriver extends Logging { - private var prompt = "spark-sql" - private var continuedPrompt = "".padTo(prompt.length, ' ') + private val prompt = "spark-sql" + private val continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ installSignalHandler() @@ -275,6 +276,10 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) + if (sessionState.getIsSilent) { + Logger.getRootLogger.setLevel(Level.WARN) + } + private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -297,7 +302,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { override def processCmd(cmd: String): Int = { val cmd_trimmed: String = cmd.trim() - val cmd_lower = cmd_trimmed.toLowerCase(Locale.ENGLISH) + val cmd_lower = cmd_trimmed.toLowerCase(Locale.ROOT) val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() if (cmd_lower.equals("quit") || @@ -305,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { sessionState.close() System.exit(0) } - if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || + if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || cmd_trimmed.startsWith("!") || isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index c0b299411e94a..01c4eb131a564 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -22,7 +22,7 @@ import java.io.PrintStream import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils} +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. 
*/ @@ -49,10 +49,12 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext = sparkSession.sparkContext sqlContext = sparkSession.sqlContext - val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) - sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) - sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + val metadataHive = sparkSession + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + .client.newSession() + metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 49ab664009341..a0e5012633f5e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -26,7 +26,7 @@ import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.HiveSessionState +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** @@ -49,8 +49,8 @@ private[thriftserver] class SparkSQLOperationManager() val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") - val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - val runInBackground = async && sessionState.hiveThriftServerAsync + val conf = sqlContext.sessionState.conf + val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala index 32ded0d254ef8..06e3980662048 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{NullType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType} class SparkExecuteStatementOperationSuite extends SparkFunSuite { test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { @@ -30,4 +30,16 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite { 
assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) } + + test("SPARK-20146 Comment should be preserved") { + val field1 = StructField("column1", StringType).withComment("comment 1") + val field2 = StructField("column2", IntegerType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE) + assert(columns.get(0).getComment() == "comment 1") + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE) + assert(columns.get(1).getComment() == "") + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f78660f7c14b6..0a53aaca404e6 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc + private val originalConvertMetastoreOrc = TestHive.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 5119eeb2a2ecd..a2cd07f86a069 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5393c57c9a28f..02a5117f005e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -48,10 +48,6 @@ class HiveContext private[hive](_sparkSession: SparkSession) new HiveContext(sparkSession.newSession()) } - protected[sql] override def sessionState: HiveSessionState = { - sparkSession.sessionState.asInstanceOf[HiveSessionState] - } - /** * Invalidate and refresh all the cached the metadata of the given table. 
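The SPARK-20146 change above forwards a column's comment from `StructField` metadata into the Hive `FieldSchema`, and the new test exercises it. A brief sketch of the `StructField` comment API it relies on (spark-shell style, illustrative only, reusing the column names from the test):

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// withComment stores the comment in the field's metadata; getComment reads it back.
val schema = StructType(Seq(
  StructField("column1", StringType).withComment("comment 1"),
  StructField("column2", IntegerType)))

assert(schema("column1").getComment() == Some("comment 1"))
assert(schema("column2").getComment().isEmpty)
// getTableSchema(schema) would now surface "comment 1" on the first Hive column
// instead of the empty string used before this change.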
For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index d09160f4e4da9..ba48facff2933 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.hive import java.io.IOException -import java.net.URI +import java.lang.reflect.InvocationTargetException import java.util +import java.util.Locale import scala.collection.mutable import scala.util.control.NonFatal @@ -35,10 +36,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient @@ -61,14 +62,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat /** * A Hive client used to interact with the metastore. */ - val client: HiveClient = { + lazy val client: HiveClient = { HiveUtils.newClientForMetadata(conf, hadoopConf) } // Exceptions thrown by the hive client that we would like to wrap private val clientExceptions = Set( classOf[HiveException].getCanonicalName, - classOf[TException].getCanonicalName) + classOf[TException].getCanonicalName, + classOf[InvocationTargetException].getCanonicalName) /** * Whether this is an exception thrown by the hive client that should be wrapped. @@ -94,7 +96,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat try { body } catch { - case NonFatal(e) if isClientException(e) => + case NonFatal(exception) if isClientException(exception) => + val e = exception match { + // Since we are using shim, the exceptions thrown by the underlying method of + // Method.invoke() are wrapped by InvocationTargetException + case i: InvocationTargetException => i.getCause + case o => o + } throw new AnalysisException( e.getClass.getCanonicalName + ": " + e.getMessage, cause = Some(e)) } @@ -129,17 +137,33 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + /** + * Checks the validity of column names. Hive metastore disallows the table to use comma in + * data column names. Partition columns do not have such a restriction. Views do not have such + * a restriction. + */ + private def verifyColumnNames(table: CatalogTable): Unit = { + if (table.tableType != VIEW) { + table.dataSchema.map(_.name).foreach { colName => + if (colName.contains(",")) { + throw new AnalysisException("Cannot create a table having a column whose name contains " + + s"commas in Hive metastore. 
Table: ${table.identifier}; Column: $colName") + } + } + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -186,7 +210,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -194,6 +218,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) + verifyColumnNames(tableDefinition) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -210,7 +235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat tableDefinition.storage.locationUri.isEmpty val tableLocation = if (needDefaultTableLocation) { - Some(defaultTablePath(tableDefinition.identifier)) + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { tableDefinition.storage.locationUri } @@ -260,7 +285,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, in older version of Spark we already store table location in storage properties // with key "path". Here we keep this behaviour for backward compatibility. val storagePropsWithLocation = table.storage.properties ++ - table.storage.locationUri.map("path" -> _) + table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and // bucket specification to empty. Note that partition columns are retained, so that we can @@ -285,7 +310,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. require(table.storage.locationUri.isDefined, "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(table.location).toUri.toString) + Some(table.location) } else { None } @@ -432,13 +457,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // // Please refer to https://issues.apache.org/jira/browse/SPARK-15269 for more details. 
val tempPath = { - val dbLocation = getDatabase(tableDefinition.database).locationUri + val dbLocation = new Path(getDatabase(tableDefinition.database).locationUri) new Path(dbLocation, tableDefinition.identifier.table + "-__PLACEHOLDER__") } try { client.createTable( - tableDefinition.withNewStorage(locationUri = Some(tempPath.toString)), + tableDefinition.withNewStorage(locationUri = Some(tempPath.toUri)), ignoreIfExists) } finally { FileSystem.get(tempPath.toUri, hadoopConf).delete(tempPath, true) @@ -448,7 +473,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -457,7 +482,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = withClient { val rawTable = getRawTable(db, oldName) // Note that Hive serde tables don't use path option in storage properties to store the value @@ -492,7 +520,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // We can't use `filterKeys` here, as the map returned by `filterKeys` is not serializable, // while `CatalogTable` should be serializable. val propsWithoutPath = table.storage.properties.filter { - case (k, v) => k.toLowerCase != "path" + case (k, v) => k.toLowerCase(Locale.ROOT) != "path" } table.storage.copy(properties = propsWithoutPath ++ newPath.map("path" -> _)) } @@ -518,8 +546,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } + val colNameTypeMap: Map[String, DataType] = + tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap.foreach { case (k, v) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => statsProperties += (columnStatKeyPropName(colName, k) -> v) } } @@ -563,7 +593,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // want to alter the table location to a file path, we will fail. This should be fixed // in the future. - val newLocation = tableDefinition.storage.locationUri + val newLocation = tableDefinition.storage.locationUri.map(CatalogUtils.URIToString(_)) val storageWithPathOption = tableDefinition.storage.copy( properties = tableDefinition.storage.properties ++ newLocation.map("path" -> _)) @@ -597,6 +627,26 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { + requireTableExists(db, table) + val rawTable = getRawTable(db, table) + val withNewSchema = rawTable.copy(schema = schema) + verifyColumnNames(withNewSchema) + // Add table metadata such as table schema, partition columns, etc. to table properties. + val updatedTable = withNewSchema.copy( + properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) + try { + client.alterTable(updatedTable) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + "compatible way. 
Updating Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema)) + } + } + override def getTable(db: String, table: String): CatalogTable = withClient { restoreTableMetadata(getRawTable(db, table)) } @@ -690,10 +740,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat "different from the schema when this table was created by Spark SQL" + s"(${schemaFromTableProps.simpleString}). We have to fall back to the table schema " + "from Hive metastore which is not case preserving.") - hiveTable + hiveTable.copy(schemaPreservesCase = false) } } else { - hiveTable + hiveTable.copy(schemaPreservesCase = false) } } @@ -704,7 +754,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val storageWithLocation = { val tableLocation = getLocationFromStorageProps(table) // We pass None as `newPath` here, to remove the path option in storage properties. - updateLocationInStorageProps(table, newPath = None).copy(locationUri = tableLocation) + updateLocationInStorageProps(table, newPath = None).copy( + locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) @@ -848,10 +899,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, Hive metastore is not case preserving and will generate wrong partition location // with lower cased partition column names. Here we set the default partition location // manually to avoid this problem. - val partitionPath = p.storage.locationUri.map(uri => new Path(new URI(uri))).getOrElse { + val partitionPath = p.storage.locationUri.map(uri => new Path(uri)).getOrElse { ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) } - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri.toString))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri))) } val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) @@ -890,7 +941,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val newParts = newSpecs.map { spec => val rightPath = renamePartitionDirectory(fs, tablePath, partitionColumnNames, spec) val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec)) - partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString))) + partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toUri))) } alterPartitions(db, table, newParts) } @@ -984,8 +1035,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val partColNameMap = buildLowerCasePartColNameMap(catalogTable).mapValues(escapePathName) val clientPartitionNames = client.getPartitionNames(catalogTable, partialSpec.map(lowerCasePartitionSpec)) - clientPartitionNames.map { partName => - val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partName) + clientPartitionNames.map { partitionPath => + val partSpec = PartitioningUtils.parsePathFragmentAsSeq(partitionPath) partSpec.map { case (partName, partValue) => partColNameMap(partName.toLowerCase) + "=" + escapePathName(partValue) }.mkString("/") @@ -1012,62 +1063,42 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient 
{ val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) - val partitionColumnNames = catalogTable.partitionColumnNames.toSet - val nonPartitionPruningPredicates = predicates.filterNot { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - if (nonPartitionPruningPredicates.nonEmpty) { - sys.error("Expected only partition pruning predicates: " + - predicates.reduceLeft(And)) - } - - val partitionSchema = catalogTable.partitionSchema - val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) + val partColNameMap = buildLowerCasePartColNameMap(catalogTable) - if (predicates.nonEmpty) { - val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part => + val clientPrunedPartitions = + client.getPartitionsByFilter(rawTable, predicates).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } - val boundPredicate = - InterpretedPredicate.create(predicates.reduce(And).transform { - case att: AttributeReference => - val index = partitionSchema.indexWhere(_.name == att.name) - BoundReference(index, partitionSchema(index).dataType, nullable = true) - }) - clientPrunedPartitions.filter { p => - boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId)) - } - } else { - client.getPartitions(catalogTable).map { part => - part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) - } - } + prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId) } // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- - override def createFunction( + override protected def doCreateFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) // Hive's metastore is case insensitive. However, Hive's createFunction does // not normalize the function name (unlike the getFunction part). So, // we are normalizing the function name. 
- val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) requireFunctionNotExists(db, functionName) val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override def dropFunction(db: String, name: String): Unit = withClient { + override protected def doDropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = withClient { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 6f5b923cd4f9e..4dec2f71b8a50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -53,8 +53,8 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Date * java.sql.Timestamp * Complex Types => - * Map: [[MapData]] - * List: [[ArrayData]] + * Map: `MapData` + * List: `ArrayData` * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 151a69aebf1d8..6b98066cb76c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,21 +17,19 @@ package org.apache.spark.sql.hive -import java.net.URI +import scala.util.control.NonFatal import com.google.common.util.concurrent.Striped import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} -import org.apache.spark.sql.hive.orc.OrcFileFormat +import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._ import org.apache.spark.sql.types._ /** @@ -41,16 +39,10 @@ import org.apache.spark.sql.types._ * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. 
*/ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { - private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] - private lazy val tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache - - private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase - - def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { - QualifiedTableName( - tableIdent.database.getOrElse(getCurrentDatabase).toLowerCase, - tableIdent.table.toLowerCase) - } + // these are def_s and not val/lazy val since the latter would introduce circular references + private def sessionState = sparkSession.sessionState + private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + import HiveMetastoreCatalog._ /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */ private val tableCreationLocks = Striped.lazyWeakLock(100) @@ -64,11 +56,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { - // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) - val QualifiedTableName(dbName, tblName) = getQualifiedTableName(tableIdent) - val dbLocation = sparkSession.sharedState.externalCatalog.getDatabase(dbName).locationUri - new Path(new Path(dbLocation), tblName).toString + // For testing only + private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + val key = QualifiedTableName( + table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, + table.table.toLowerCase) + tableRelationCache.getIfPresent(key) } private def getCached( @@ -118,7 +111,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } - private def convertToLogicalRelation( + def convertToLogicalRelation( relation: CatalogRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], @@ -128,7 +121,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions - val tablePath = new Path(new URI(relation.tableMeta.location)) + val tablePath = new Path(relation.tableMeta.location) + val fileFormat = fileFormatClass.newInstance() + val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { @@ -141,7 +136,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // locations,_omitting_ the table's base path. val paths = sparkSession.sharedState.externalCatalog .listPartitions(tableIdentifier.database, tableIdentifier.name) - .map(p => new Path(new URI(p.storage.locationUri.get))) + .map(p => new Path(p.storage.locationUri.get)) if (paths.isEmpty) { Seq(tablePath) @@ -169,16 +164,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } } + val (dataSchema, updatedTable) = + inferIfNeeded(relation, options, fileFormat, Option(fileIndex)) + val fsRelation = HadoopFsRelation( location = fileIndex, partitionSchema = partitionSchema, - dataSchema = relation.tableMeta.dataSchema, + dataSchema = dataSchema, // We don't support hive bucketed tables, only ones we write out. 
bucketSpec = None, - fileFormat = fileFormatClass.newInstance(), + fileFormat = fileFormat, options = options)(sparkSession = sparkSession) - - val created = LogicalRelation(fsRelation, catalogTable = Some(relation.tableMeta)) + val created = LogicalRelation(fsRelation, updatedTable) tableRelationCache.put(tableIdentifier, created) created } @@ -195,17 +192,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormatClass, None) val logicalRelation = cached.getOrElse { + val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat) val created = LogicalRelation( DataSource( sparkSession = sparkSession, paths = rootPath.toString :: Nil, - userSpecifiedSchema = Some(metastoreSchema), + userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, options = options, className = fileType).resolveRelation(), - catalogTable = Some(relation.tableMeta)) + table = updatedTable) tableRelationCache.put(tableIdentifier, created) created @@ -214,75 +212,89 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log logicalRelation }) } - result.copy(expectedOutputAttributes = Some(relation.output)) - } - - /** - * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet - * data source relations for better performance. - */ - object ParquetConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreParquet(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("parquet") && - sessionState.convertMetastoreParquet - } - - private def convertToParquetRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[ParquetFileFormat] - val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging - val options = Map(ParquetOptions.MERGE_SCHEMA -> mergeSchema.toString) - - convertToLogicalRelation(relation, options, fileFormatClass, "parquet") + // The inferred schema may have different filed names as the table schema, we should respect + // it, but also respect the exprId in table relation output. + assert(result.output.length == relation.output.length && + result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType }) + val newOutput = result.output.zip(relation.output).map { + case (a1, a2) => a1.withExprId(a2.exprId) } + result.copy(output = newOutput) + } - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Parquet data source (yet). 
- if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, query, overwrite, ifNotExists) + private def inferIfNeeded( + relation: CatalogRelation, + options: Map[String, String], + fileFormat: FileFormat, + fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { + val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode + val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase + val tableName = relation.tableMeta.identifier.unquotedString + if (shouldInfer) { + logInfo(s"Inferring case-sensitive schema for table $tableName (inference mode: " + + s"$inferenceMode)") + val fileIndex = fileIndexOpt.getOrElse { + val rootPath = new Path(relation.tableMeta.location) + new InMemoryFileIndex(sparkSession, Seq(rootPath), options, None) + } - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreParquet(relation) => - convertToParquetRelation(relation) + val inferredSchema = fileFormat + .inferSchema( + sparkSession, + options, + fileIndex.listFiles(Nil, Nil).flatMap(_.files)) + .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _)) + + inferredSchema match { + case Some(schema) => + if (inferenceMode == INFER_AND_SAVE) { + updateCatalogSchema(relation.tableMeta.identifier, schema) + } + (schema, relation.tableMeta.copy(schema = schema)) + case None => + logWarning(s"Unable to infer schema for table $tableName from file format " + + s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.") + (relation.tableMeta.schema, relation.tableMeta) } + } else { + (relation.tableMeta.schema, relation.tableMeta) } } - /** - * When scanning Metastore ORC tables, convert them to ORC data source relations - * for better performance. - */ - object OrcConversions extends Rule[LogicalPlan] { - private def shouldConvertMetastoreOrc(relation: CatalogRelation): Boolean = { - relation.tableMeta.storage.serde.getOrElse("").toLowerCase.contains("orc") && - sessionState.convertMetastoreOrc - } - - private def convertToOrcRelation(relation: CatalogRelation): LogicalRelation = { - val fileFormatClass = classOf[OrcFileFormat] - val options = Map[String, String]() - - convertToLogicalRelation(relation, options, fileFormatClass, "orc") - } + private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try { + val db = identifier.database.get + logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}") + sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema) + } catch { + case NonFatal(ex) => + logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex) + } +} - override def apply(plan: LogicalPlan): LogicalPlan = { - plan transformUp { - // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) - // Inserting into partitioned table is not supported in Orc data source (yet). 
- if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, query, overwrite, ifNotExists) - // Read path - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && - shouldConvertMetastoreOrc(relation) => - convertToOrcRelation(relation) - } - } +private[hive] object HiveMetastoreCatalog { + def mergeWithMetastoreSchema( + metastoreSchema: StructType, + inferredSchema: StructType): StructType = try { + // Find any nullable fields in mestastore schema that are missing from the inferred schema. + val metastoreFields = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingNullables = metastoreFields + .filterKeys(!inferredSchema.map(_.name.toLowerCase).contains(_)) + .values + .filter(_.nullable) + // Merge missing nullable fields to inferred schema and build a case-insensitive field map. + val inferredFields = StructType(inferredSchema ++ missingNullables) + .map(f => f.name.toLowerCase -> f).toMap + StructType(metastoreSchema.map(f => f.copy(name = inferredFields(f.name).name))) + } catch { + case NonFatal(_) => + val msg = s"""Detected conflicting schemas when merging the schema obtained from the Hive + | Metastore with the one inferred from the file format. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Inferred schema: + |${inferredSchema.prettyJson} + """.stripMargin + throw new SparkException(msg) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c9be1b9d100b0..377d4f2473c58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.util.Locale + import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal @@ -25,15 +27,13 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -43,45 +43,20 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, globalTempViewManager: GlobalTempViewManager, - 
sparkSession: SparkSession, - functionResourceLoader: FunctionResourceLoader, + val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, - parser: ParserInterface) + parser: ParserInterface, + functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, - functionResourceLoader, - functionRegistry, - conf, - hadoopConf, - parser) { - - // ---------------------------------------------------------------- - // | Methods and fields for interacting with HiveMetastoreCatalog | - // ---------------------------------------------------------------- - - // Catalog for handling data source tables. TODO: This really doesn't belong here since it is - // essentially a cache for metastore tables. However, it relies on a lot of session-specific - // things so it would be a lot of work to split its functionality between HiveSessionCatalog - // and HiveCatalog. We should still do it at some point... - private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession) - - // These 2 rules must be run before all other DDL post-hoc resolution rules, i.e. - // `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. - val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions - val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions - - def hiveDefaultTableFilePath(name: TableIdentifier): String = { - metastoreCatalog.hiveDefaultTableFilePath(name) - } - - // For testing only - private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { - val key = metastoreCatalog.getQualifiedTableName(table) - sparkSession.sessionState.catalog.tableRelationCache.getIfPresent(key) - } + externalCatalog, + globalTempViewManager, + functionRegistry, + conf, + hadoopConf, + parser, + functionResourceLoader) { override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { makeFunctionBuilder(funcName, Utils.classForName(className)) @@ -149,13 +124,6 @@ private[sql] class HiveSessionCatalog( } private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = { - // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to - // if (super.functionExists(name)) { - // super.lookupFunction(name, children) - // } else { - // // This function is a Hive builtin function. - // ... - // } val database = name.database.map(formatDatabaseName) val funcName = name.copy(database = database) Try(super.lookupFunction(funcName, children)) match { @@ -170,7 +138,7 @@ private[sql] class HiveSessionCatalog( // This function is not in functionRegistry, let's try to load it as a Hive's // built-in function. // Hive is case insensitive. - val functionName = funcName.unquotedString.toLowerCase + val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { failFunctionLookup(funcName.unquotedString) } @@ -189,16 +157,22 @@ private[sql] class HiveSessionCatalog( } } val className = functionInfo.getFunctionClass.getName - val builder = makeFunctionBuilder(functionName, className) + val functionIdentifier = + FunctionIdentifier(functionName.toLowerCase(Locale.ROOT), database) + val func = CatalogFunction(functionIdentifier, className, Nil) // Put this Hive built-in function to our function registry. 
- val info = new ExpressionInfo(className, functionName) - createTempFunction(functionName, info, builder, ignoreIfExists = false) + registerFunction(func, ignoreIfExists = false) // Now, we need to create the Expression. functionRegistry.lookupFunction(functionName, children) } } } + // TODO Removes this method after implementing Spark native "histogram_numeric". + override def functionExists(name: FunctionIdentifier): Boolean = { + super.functionExists(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala deleted file mode 100644 index 5a08a6bc66f6b..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Analyzer -import org.apache.spark.sql.execution.SparkPlanner -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.internal.SessionState - - -/** - * A class that holds all session-specific state in a given [[SparkSession]] backed by Hive. - */ -private[hive] class HiveSessionState(sparkSession: SparkSession) - extends SessionState(sparkSession) { - - self => - - /** - * A Hive client used for interacting with the metastore. - */ - lazy val metadataHive: HiveClient = - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.newSession() - - /** - * Internal catalog for managing table and database states. - */ - override lazy val catalog = { - new HiveSessionCatalog( - sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog], - sparkSession.sharedState.globalTempViewManager, - sparkSession, - functionResourceLoader, - functionRegistry, - conf, - newHadoopConf(), - sqlParser) - } - - /** - * An analyzer that uses the Hive metastore. 
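A user-visible consequence of the function-lookup changes above: Hive built-ins that Spark has no native implementation for (the hiveFunctions allow-list, e.g. histogram_numeric mentioned in the TODO) are now registered as a CatalogFunction on first use and reported by functionExists. Illustrative usage only, assuming spark is a Hive-enabled SparkSession and events/latency are a hypothetical table and column:

// Resolution falls through Spark's own registry first and only then wraps the Hive built-in,
// registering it in the session's function registry on first use.
spark.sql("SELECT histogram_numeric(latency, 10) FROM events").show()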
- */ - override lazy val analyzer: Analyzer = { - new Analyzer(catalog, conf) { - override val extendedResolutionRules = - new ResolveHiveSerdeTable(sparkSession) :: - new FindDataSourceTable(sparkSession) :: - new ResolveSQLOnFile(sparkSession) :: Nil - - override val postHocResolutionRules = - new DetermineTableStats(sparkSession) :: - catalog.ParquetConversions :: - catalog.OrcConversions :: - PreprocessTableCreation(sparkSession) :: - PreprocessTableInsertion(conf) :: - DataSourceAnalysis(conf) :: - HiveAnalysis :: Nil - - override val extendedCheckRules = Seq(PreWriteCheck) - } - } - - /** - * Planner that takes into account Hive-specific strategies. - */ - override def planner: SparkPlanner = { - new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies) - with HiveStrategies { - override val sparkSession: SparkSession = self.sparkSession - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy, - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } - } - } - - - // ------------------------------------------------------ - // Helper methods, partially leftover from pre-2.0 days - // ------------------------------------------------------ - - override def addJar(path: String): Unit = { - metadataHive.addJar(path) - super.addJar(path) - } - - /** - * When true, enables an experimental feature where metastore tables that use the parquet SerDe - * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive - * SerDe. - */ - def convertMetastoreParquet: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) - } - - /** - * When true, also tries to merge possibly different but compatible Parquet schemas in different - * Parquet data files. - * - * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. - */ - def convertMetastoreParquetWithSchemaMerging: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - } - - /** - * When true, enables an experimental feature where metastore tables that use the Orc SerDe - * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive - * SerDe. - */ - def convertMetastoreOrc: Boolean = { - conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) - } - - /** - * When true, Hive Thrift server will execute SQL queries asynchronously using a thread pool." - */ - def hiveThriftServerAsync: Boolean = { - conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) - } - -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala new file mode 100644 index 0000000000000..e16c9e46b7723 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlanner +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} + +/** + * Builder that produces a Hive-aware `SessionState`. + */ +@Experimental +@InterfaceStability.Unstable +class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) + extends BaseSessionStateBuilder(session, parentState) { + + private def externalCatalog: HiveExternalCatalog = + session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + + /** + * Create a Hive aware resource loader. + */ + override protected lazy val resourceLoader: HiveSessionResourceLoader = { + val client: HiveClient = externalCatalog.client.newSession() + new HiveSessionResourceLoader(session, client) + } + + /** + * Create a [[HiveSessionCatalog]]. + */ + override protected lazy val catalog: HiveSessionCatalog = { + val catalog = new HiveSessionCatalog( + externalCatalog, + session.sharedState.globalTempViewManager, + new HiveMetastoreCatalog(session), + functionRegistry, + conf, + SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), + sqlParser, + resourceLoader) + parentState.foreach(_.catalog.copyStateTo(catalog)) + catalog + } + + /** + * A logical query plan `Analyzer` with rules specific to Hive. + */ + override protected def analyzer: Analyzer = new Analyzer(catalog, conf) { + override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = + new ResolveHiveSerdeTable(session) +: + new FindDataSourceTable(session) +: + new ResolveSQLOnFile(session) +: + customResolutionRules + + override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = + new DetermineTableStats(session) +: + RelationConversions(conf, catalog) +: + PreprocessTableCreation(session) +: + PreprocessTableInsertion(conf) +: + DataSourceAnalysis(conf) +: + HiveAnalysis +: + customPostHocResolutionRules + + override val extendedCheckRules: Seq[LogicalPlan => Unit] = + PreWriteCheck +: + customCheckRules + } + + /** + * Planner that takes into account Hive-specific strategies. 
+ */ + override protected def planner: SparkPlanner = { + new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies { + override val sparkSession: SparkSession = session + + override def extraPlanningStrategies: Seq[Strategy] = + super.extraPlanningStrategies ++ customPlanningStrategies + + override def strategies: Seq[Strategy] = { + experimentalMethods.extraStrategies ++ + extraPlanningStrategies ++ Seq( + FileSourceStrategy, + DataSourceStrategy(conf), + SpecialLimits, + InMemoryScans, + HiveTableScans, + Scripts, + Aggregation, + JoinSelection, + BasicOperators + ) + } + } + } + + override protected def newBuilder: NewBuilder = new HiveSessionStateBuilder(_, _) +} + +class HiveSessionResourceLoader( + session: SparkSession, + client: HiveClient) + extends SessionResourceLoader(session) { + override def addJar(path: String): Unit = { + client.addJar(path) + super.addJar(path) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 624cfa206eeb2..09a5eda6e543f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.io.IOException -import java.net.URI +import java.util.Locale import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst @@ -31,9 +31,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.hive.orc.OrcFileFormat +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} /** @@ -133,7 +135,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { try { val hadoopConf = session.sessionState.newHadoopConf() - val tablePath = new Path(new URI(table.location)) + val tablePath = new Path(table.location) val fs: FileSystem = tablePath.getFileSystem(hadoopConf) fs.getContentSummary(tablePath).getLength } catch { @@ -170,6 +172,55 @@ object HiveAnalysis extends Rule[LogicalPlan] { } } +/** + * Relation conversion from metastore relations to data source relations for better performance + * + * - When writing to non-partitioned Hive-serde Parquet/Orc tables + * - When scanning Hive-serde Parquet/ORC tables + * + * This rule must be run before all other DDL post-hoc resolution rules, i.e. + * `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`. 
+ */ +case class RelationConversions( + conf: SQLConf, + sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { + private def isConvertible(relation: CatalogRelation): Boolean = { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || + serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) + } + + private def convert(relation: CatalogRelation): LogicalRelation = { + val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + if (serde.contains("parquet")) { + val options = Map(ParquetOptions.MERGE_SCHEMA -> + conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") + } else { + val options = Map[String, String]() + sessionCatalog.metastoreCatalog + .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan transformUp { + // Write path + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + + // Read path + case relation: CatalogRelation + if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => + convert(relation) + } + } +} + private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index afc2bf85334d0..3de60c7fc1318 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -21,6 +21,7 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp +import java.util.Locale import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap @@ -338,7 +339,7 @@ private[spark] object HiveUtils extends Logging { logWarning(s"Hive jar path '$path' does not exist.") Nil } else { - files.filter(_.getName.toLowerCase.endsWith(".jar")) + files.filter(_.getName.toLowerCase(Locale.ROOT).endsWith(".jar")) } case path => new File(path) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7acaa9a7ab417..387ec4f967233 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} +import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -97,6 +98,7 @@ private[hive] class HiveClientImpl( case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() + case hive.v2_1 => new Shim_v2_1() } // Create an internal session state for this HiveClientImpl. 
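The RelationConversions rule introduced above keys entirely off the table's serde string and the two convertMetastore* flags, and the same check gates both the read path and the non-partitioned insert path. A minimal standalone sketch of that decision (illustrative only; the serde class name in the comment is just an example value):

import java.util.Locale

object ConversionSketch {
  // Which data source format a Hive-serde relation would be converted to, if any.
  def chosenFormat(serde: String, convertParquet: Boolean, convertOrc: Boolean): Option[String] = {
    val s = serde.toLowerCase(Locale.ROOT)
    if (s.contains("parquet") && convertParquet) Some("parquet")
    else if (s.contains("orc") && convertOrc) Some("orc")
    else None
  }
}

// e.g. ConversionSketch.chosenFormat(
//   "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", true, true) == Some("parquet")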
@@ -152,7 +154,7 @@ private[hive] class HiveClientImpl( hadoopConf.iterator().asScala.foreach { entry => val key = entry.getKey val value = entry.getValue - if (key.toLowerCase.contains("password")) { + if (key.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=xxx") } else { logDebug(s"Applying Hadoop and Hive config to Hive Conf: $key=$value") @@ -167,7 +169,7 @@ private[hive] class HiveClientImpl( hiveConf.setClassLoader(initClassLoader) // 2: we set all spark confs to this hiveConf. sparkConf.getAll.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying Spark config to Hive Conf: $k=xxx") } else { logDebug(s"Applying Spark config to Hive Conf: $k=$v") @@ -176,7 +178,7 @@ private[hive] class HiveClientImpl( } // 3: we set all entries in config to this hiveConf. extraConfig.foreach { case (k, v) => - if (k.toLowerCase.contains("password")) { + if (k.toLowerCase(Locale.ROOT).contains("password")) { logDebug(s"Applying extra config to HiveConf: $k=xxx") } else { logDebug(s"Applying extra config to HiveConf: $k=$v") @@ -206,6 +208,8 @@ private[hive] class HiveClientImpl( /** Returns the configuration for the current session. */ def conf: HiveConf = state.getConf + private val userName = state.getAuthenticator.getUserName + override def getConf(key: String, defaultValue: String): String = { conf.get(key, defaultValue) } @@ -278,6 +282,8 @@ private[hive] class HiveClientImpl( state.getConf.setClassLoader(clientLoader.classLoader) // Set the thread local metastore client to the client associated with this HiveClientImpl. Hive.set(client) + // Replace conf in the thread local Hive with current conf + Hive.get(conf) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. @@ -317,7 +323,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull), ignoreIfExists) } @@ -335,7 +341,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull)) } @@ -344,7 +350,7 @@ private[hive] class HiveClientImpl( CatalogDatabase( name = d.getName, description = d.getDescription, - locationUri = d.getLocationUri, + locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) } @@ -410,7 +416,7 @@ private[hive] class HiveClientImpl( createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( - locationUri = shim.getDataLocation(h), + locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI), // To avoid ClassNotFound exception, we try our best to not get the format class, but get // the class name directly. However, for non-native tables, there is no interface to get // the format class name, so we may still throw ClassNotFound in this case. 
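Several hunks in this patch replace bare toLowerCase/toUpperCase with the Locale.ROOT overloads (here for the password-masking checks, elsewhere for option names, codec names and serde strings). The default-locale overload is locale-sensitive, which matters for any value containing the letter I. A minimal sketch; "TEXTFILE" is just a representative value of the kind HiveOptions lower-cases:

import java.util.Locale

object LocaleSketch {
  // Under a Turkish locale, 'I' lower-cases to the dotless 'ı' (U+0131), so a default-locale
  // comparison can silently fail while the Locale.ROOT one behaves as expected.
  def demo(): Unit = {
    val tr = new Locale("tr", "TR")
    assert("TEXTFILE".toLowerCase(tr) != "textfile")           // becomes "textfıle"
    assert("TEXTFILE".toLowerCase(Locale.ROOT) == "textfile")
  }

  // The masking pattern used when copying configs into the HiveConf above, as a helper.
  def maskedForLog(key: String, value: String): String =
    if (key.toLowerCase(Locale.ROOT).contains("password")) s"$key=xxx" else s"$key=$value"
}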
@@ -438,7 +444,7 @@ private[hive] class HiveClientImpl( } override def createTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = withHiveState { - client.createTable(toHiveTable(table, Some(conf)), ignoreIfExists) + client.createTable(toHiveTable(table, Some(userName)), ignoreIfExists) } override def dropTable( @@ -450,10 +456,10 @@ private[hive] class HiveClientImpl( } override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" - client.alterTable(qualifiedTableName, hiveTable) + shim.alterTable(client, qualifiedTableName, hiveTable) } override def createPartitions( @@ -519,7 +525,7 @@ private[hive] class HiveClientImpl( newSpecs: Seq[TablePartitionSpec]): Unit = withHiveState { require(specs.size == newSpecs.size, "number of old and new partition specs differ") val catalogTable = getTable(db, table) - val hiveTable = toHiveTable(catalogTable, Some(conf)) + val hiveTable = toHiveTable(catalogTable, Some(userName)) specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => val hivePart = getPartitionOption(catalogTable, oldSpec) .map { p => toHivePartition(p.copy(spec = newSpec), hiveTable) } @@ -532,8 +538,8 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(conf)) - client.alterPartitions(table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) } /** @@ -560,7 +566,7 @@ private[hive] class HiveClientImpl( override def getPartitionOption( table: CatalogTable, spec: TablePartitionSpec): Option[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val hivePartition = client.getPartition(hiveTable, spec.asJava, false) Option(hivePartition).map(fromHivePartition) } @@ -572,7 +578,7 @@ private[hive] class HiveClientImpl( override def getPartitions( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = spec match { case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) case Some(s) => @@ -586,7 +592,7 @@ private[hive] class HiveClientImpl( override def getPartitionsByFilter( table: CatalogTable, predicates: Seq[Expression]): Seq[CatalogTablePartition] = withHiveState { - val hiveTable = toHiveTable(table, Some(conf)) + val hiveTable = toHiveTable(table, Some(userName)) val parts = shim.getPartitionsByFilter(client, hiveTable, predicates).map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts @@ -617,7 +623,7 @@ private[hive] class HiveClientImpl( */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState { logDebug(s"Running hiveql '$cmd'") - if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") } + if (cmd.toLowerCase(Locale.ROOT).startsWith("set")) { logDebug(s"Changing config: $cmd") } try { val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") 
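alterTable and alterPartitions are now routed through the shim (see the hunks above) because the underlying Hive API gained an EnvironmentContext parameter in 2.1. A minimal sketch of picking the version-appropriate overload reflectively, using the Hive classes this file already deals with (illustrative; the real shims cache the resolved Method in lazy vals):

import java.lang.reflect.Method
import org.apache.hadoop.hive.metastore.api.EnvironmentContext
import org.apache.hadoop.hive.ql.metadata.{Hive, Table}

object AlterTableDispatchSketch {
  // Hive <= 2.0 exposes alterTable(String, Table); Hive 2.1 adds an EnvironmentContext argument.
  def alterTableMethodFor(hiveVersion: String): Method =
    if (hiveVersion.startsWith("2.1")) {
      classOf[Hive].getMethod("alterTable",
        classOf[String], classOf[Table], classOf[EnvironmentContext])
    } else {
      classOf[Hive].getMethod("alterTable", classOf[String], classOf[Table])
    }
}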
@@ -814,9 +820,7 @@ private[hive] object HiveClientImpl { /** * Converts the native table metadata representation format CatalogTable to Hive's Table. */ - def toHiveTable( - table: CatalogTable, - conf: Option[HiveConf] = None): HiveTable = { + def toHiveTable(table: CatalogTable, userName: Option[String] = None): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. @@ -848,10 +852,11 @@ private[hive] object HiveClientImpl { hiveTable.setFields(schema.asJava) } hiveTable.setPartCols(partCols.asJava) - conf.foreach(c => hiveTable.setOwner(c.getUser)) + userName.foreach(hiveTable.setOwner) hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) - table.storage.locationUri.foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)} + table.storage.locationUri.map(CatalogUtils.URIToString).foreach { loc => + hiveTable.getTTable.getSd.setLocation(loc)} table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) hiveTable.setSerializationLib( @@ -885,7 +890,7 @@ private[hive] object HiveClientImpl { } val storageDesc = new StorageDescriptor val serdeInfo = new SerDeInfo - p.storage.locationUri.foreach(storageDesc.setLocation) + p.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach(storageDesc.setLocation) p.storage.inputFormat.foreach(storageDesc.setInputFormat) p.storage.outputFormat.foreach(storageDesc.setOutputFormat) p.storage.serde.foreach(serdeInfo.setSerializationLib) @@ -906,7 +911,7 @@ private[hive] object HiveClientImpl { CatalogTablePartition( spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), storage = CatalogStorageFormat( - locationUri = Option(apiPartition.getSd.getLocation), + locationUri = Option(CatalogUtils.stringToURI(apiPartition.getSd.getLocation)), inputFormat = Option(apiPartition.getSd.getInputFormat), outputFormat = Option(apiPartition.getSd.getOutputFormat), serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 7280748361d60..7abb9f06b1310 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -20,17 +20,18 @@ package org.apache.spark.sql.hive.client import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{InvocationTargetException, Method, Modifier} import java.net.URI -import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{EnvironmentContext, Function => HiveFunction, FunctionType} +import 
org.apache.hadoop.hive.metastore.api.{MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver +import org.apache.hadoop.hive.ql.io.AcidUtils import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException, Partition, Table} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} @@ -41,7 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} @@ -83,6 +84,10 @@ private[client] sealed abstract class Shim { def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def alterTable(hive: Hive, tableName: String, table: Table): Unit + + def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit + def createPartitions( hive: Hive, db: String, @@ -159,6 +164,10 @@ private[client] sealed abstract class Shim { } private[client] class Shim_v0_12 extends Shim with Logging { + // See HIVE-12224, HOLD_DDLTIME was broken as soon as it landed + protected lazy val holdDDLTime = JBoolean.FALSE + // deletes the underlying data along with metadata + protected lazy val deleteDataInDropIndex = JBoolean.TRUE private lazy val startMethod = findStaticMethod( @@ -241,6 +250,18 @@ private[client] class Shim_v0_12 extends Shim with Logging { classOf[String], classOf[String], JBoolean.TYPE) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]]) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -268,7 +289,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { val table = hive.getTable(database, tableName) parts.foreach { s => val location = s.storage.locationUri.map( - uri => new Path(table.getPath, new Path(new URI(uri)))).orNull + uri => new Path(table.getPath, new Path(uri))).orNull val params = if (s.parameters.nonEmpty) s.parameters.asJava else null val spec = s.spec.asJava if (hive.getPartition(table, spec, false) != null && ignoreIfExists) { @@ -342,7 +363,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime) } override def loadDynamicPartitions( @@ -354,11 +375,11 @@ private[client] class Shim_v0_12 extends Shim with Logging { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean) } 
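The anonymous JBoolean.FALSE/JBoolean.TRUE literals in the reflective load*/dropIndex calls are replaced by named, documented flags (holdDDLTime, deleteDataInDropIndex, and in later shims isAcid, isSkewedStoreAsSubdir, hasFollowingStatsTask). A small illustrative sketch of the pattern, with the parameter lists below summarised from the invocations in this file:

import java.lang.{Boolean => JBoolean}

object ShimFlagSketch {
  // HIVE-12224: HOLD_DDLTIME was broken as soon as it landed, so it is always false.
  val holdDDLTime: JBoolean = JBoolean.FALSE
  // Drop the underlying data together with the index metadata.
  val deleteDataInDropIndex: JBoolean = JBoolean.TRUE

  // A reflective call such as
  //   loadTableMethod.invoke(hive, loadPath, tableName, replace, holdDDLTime)
  // now documents itself instead of ending in an opaque JBoolean.FALSE.
  //
  // For reference, the loadTable invocations targeted by the shims in this file:
  //   v0.12: (loadPath, tableName, replace, holdDDLTime)
  //   v0.14: (loadPath, tableName, replace, holdDDLTime, isSrcLocal, isSkewedStoreAsSubdir, isAcid)
  //   v2.0 : (loadPath, tableName, replace, isSrcLocal, isSkewedStoreAsSubdir, isAcid)
  //   v2.1 : (loadPath, tableName, replace, isSrcLocal, isSkewedStoreAsSubdir, isAcid, hasFollowingStatsTask)
}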
override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, deleteDataInDropIndex) } override def dropTable( @@ -374,6 +395,14 @@ private[client] class Shim_v0_12 extends Shim with Logging { hive.dropTable(dbName, tableName, deleteData, ignoreIfNotExists) } + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts) + } + override def dropPartition( hive: Hive, dbName: String, @@ -463,7 +492,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) parts.zipWithIndex.foreach { case (s, i) => addPartitionDesc.addPartition( - s.spec.asJava, s.storage.locationUri.map(u => new Path(new URI(u)).toString).orNull) + s.spec.asJava, s.storage.locationUri.map(CatalogUtils.URIToString(_)).orNull) if (s.parameters.nonEmpty) { addPartitionDesc.getPartition(i).setPartParams(s.parameters.asJava) } @@ -476,8 +505,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { val resourceUris = f.resources.map { resource => - new ResourceUri( - ResourceType.valueOf(resource.resourceType.resourceType.toUpperCase()), resource.uri) + new ResourceUri(ResourceType.valueOf( + resource.resourceType.resourceType.toUpperCase(Locale.ROOT)), resource.uri) } new HiveFunction( f.identifier.funcName, @@ -521,7 +550,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { } FunctionResource(FunctionResourceType.fromString(resourceType), uri.getUri()) } - new CatalogFunction(name, hf.getClassName, resources) + CatalogFunction(name, hf.getClassName, resources) } override def getFunctionOption(hive: Hive, db: String, name: String): Option[CatalogFunction] = { @@ -555,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
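The partition-filter conversion below also gains a quoting helper so that string literals containing a double quote can still be pushed down to the metastore instead of producing a malformed filter. Expected behaviour on hypothetical values (the third case is rejected outright):

object FilterQuotingSketch {
  // Re-statement of the quoting rule added below, for illustration only.
  def quote(str: String): String =
    if (!str.contains("\"")) "\"" + str + "\""
    else if (!str.contains("'")) "'" + str + "'"
    else throw new UnsupportedOperationException(
      "Partition filter cannot have both `\"` and `'` characters")

  // quote("abc")     == "\"abc\""   -> pushed down as  part = "abc"
  // quote("ab\"c")   == "'ab\"c'"   -> pushed down as  part = 'ab"c'
  // quote("a\"b'c")  -> UnsupportedOperationException
}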
- val varcharKeys = table.getPartitionKeys.asScala + lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet @@ -567,13 +596,24 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { s"$v ${op.symbol} ${a.name}" case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" }.mkString(" and ") } + private def quoteStringLiteral(str: String): String = { + if (!str.contains("\"")) { + s""""$str"""" + } else if (!str.contains("'")) { + s"""'$str'""" + } else { + throw new UnsupportedOperationException( + """Partition filter cannot have both `"` and `'` characters""") + } + } + override def getPartitionsByFilter( hive: Hive, table: Table, @@ -582,6 +622,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { // Hive getPartitionsByFilter() takes a string that represents partition // predicates like "str_key=\"value\" and int_key=1 ..." val filter = convertFilters(table, predicates) + val partitions = if (filter.isEmpty) { getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] @@ -639,6 +680,11 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { private[client] class Shim_v0_14 extends Shim_v0_13 { + // true if this is an ACID operation + protected lazy val isAcid = JBoolean.FALSE + // true if list bucketing enabled + protected lazy val isSkewedStoreAsSubdir = JBoolean.FALSE + private lazy val loadPartitionMethod = findMethod( classOf[Hive], @@ -701,8 +747,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { isSkewedStoreAsSubdir: Boolean, isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - JBoolean.FALSE, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + holdDDLTime, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid) } override def loadTable( @@ -711,8 +757,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { tableName: String, replace: Boolean, isSrcLocal: Boolean): Unit = { - loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, JBoolean.FALSE, - isSrcLocal: JBoolean, JBoolean.FALSE, JBoolean.FALSE) + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, holdDDLTime, + isSrcLocal: JBoolean, isSkewedStoreAsSubdir, isAcid) } override def loadDynamicPartitions( @@ -724,7 +770,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid) } override def dropTable( @@ -734,12 +780,8 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { deleteData: Boolean, ignoreIfNotExists: Boolean, purge: Boolean): Unit = { - try { - dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, - ignoreIfNotExists: JBoolean, purge: JBoolean) - } catch { - case e: 
InvocationTargetException => throw e.getCause() - } + dropTableMethod.invoke(hive, dbName, tableName, deleteData: JBoolean, + ignoreIfNotExists: JBoolean, purge: JBoolean) } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -757,6 +799,9 @@ private[client] class Shim_v1_0 extends Shim_v0_14 { private[client] class Shim_v1_1 extends Shim_v1_0 { + // throws an exception if the index does not exist + protected lazy val throwExceptionInDropIndex = JBoolean.TRUE + private lazy val dropIndexMethod = findMethod( classOf[Hive], @@ -768,13 +813,17 @@ private[client] class Shim_v1_1 extends Shim_v1_0 { JBoolean.TYPE) override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { - dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + dropIndexMethod.invoke(hive, dbName, tableName, indexName, throwExceptionInDropIndex, + deleteDataInDropIndex) } } private[client] class Shim_v1_2 extends Shim_v1_1 { + // txnId can be 0 unless isAcid == true + protected lazy val txnIdInLoadDynamicPartitions: JLong = 0L + private lazy val loadDynamicPartitionsMethod = findMethod( classOf[Hive], @@ -811,8 +860,8 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, JBoolean.FALSE, listBucketingEnabled: JBoolean, JBoolean.FALSE, - 0L: JLong) + numDP: JInteger, holdDDLTime, listBucketingEnabled: JBoolean, isAcid, + txnIdInLoadDynamicPartitions) } override def dropPartition( @@ -825,11 +874,7 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { val dropOptions = dropOptionsClass.newInstance().asInstanceOf[Object] dropOptionsDeleteData.setBoolean(dropOptions, deleteData) dropOptionsPurge.setBoolean(dropOptions, purge) - try { - dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) - } catch { - case e: InvocationTargetException => throw e.getCause() - } + dropPartitionMethod.invoke(hive, dbName, tableName, part, dropOptions) } } @@ -881,7 +926,7 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { isSrcLocal: Boolean): Unit = { loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, - isSrcLocal: JBoolean, JBoolean.FALSE) + isSrcLocal: JBoolean, isAcid) } override def loadTable( @@ -891,7 +936,7 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { replace: Boolean, isSrcLocal: Boolean): Unit = { loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, - JBoolean.FALSE, JBoolean.FALSE) + isSkewedStoreAsSubdir, isAcid) } override def loadDynamicPartitions( @@ -903,7 +948,114 @@ private[client] class Shim_v2_0 extends Shim_v1_2 { numDP: Int, listBucketingEnabled: Boolean): Unit = { loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, - numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions) } } + +private[client] class Shim_v2_1 extends Shim_v2_0 { + + // true if there is any following stats task + protected lazy val hasFollowingStatsTask = JBoolean.FALSE + // TODO: Now, always set environmentContext to null. In the future, we should avoid setting + // hive-generated stats to -1 when altering tables by using environmentContext. 
See Hive-12730 + protected lazy val environmentContextInAlterTable = null + + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE, + JBoolean.TYPE, + classOf[AcidUtils.Operation]) + private lazy val alterTableMethod = + findMethod( + classOf[Hive], + "alterTable", + classOf[String], + classOf[Table], + classOf[EnvironmentContext]) + private lazy val alterPartitionsMethod = + findMethod( + classOf[Hive], + "alterPartitions", + classOf[String], + classOf[JList[Partition]], + classOf[EnvironmentContext]) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, isAcid, hasFollowingStatsTask) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + isSkewedStoreAsSubdir, isAcid, hasFollowingStatsTask) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, isAcid, txnIdInLoadDynamicPartitions, + hasFollowingStatsTask, AcidUtils.Operation.NOT_ACID) + } + + override def alterTable(hive: Hive, tableName: String, table: Table): Unit = { + alterTableMethod.invoke(hive, tableName, table, environmentContextInAlterTable) + } + + override def alterPartitions(hive: Hive, tableName: String, newParts: JList[Partition]): Unit = { + alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 9c17d173df42a..2d14cb708cae6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -95,6 +95,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 + case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 790ad74e6639e..f9635e36549e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -67,7 +67,11 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) + case object v2_1 extends HiveVersion("2.1.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 192851028031b..5c515515b9b9c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -29,7 +31,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) - val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase) + val fileFormat = parameters.get(FILE_FORMAT).map(_.toLowerCase(Locale.ROOT)) val inputFormat = parameters.get(INPUT_FORMAT) val outputFormat = parameters.get(OUTPUT_FORMAT) @@ -75,7 +77,7 @@ class HiveOptions(@transient private val parameters: CaseInsensitiveMap[String]) } def serdeProperties: Map[String, String] = parameters.filterKeys { - k => !lowerCasedOptionNames.contains(k.toLowerCase) + k => !lowerCasedOptionNames.contains(k.toLowerCase(Locale.ROOT)) }.map { case (k, v) => delimiterOptions.getOrElse(k, k) -> v } } @@ -83,7 +85,7 @@ object HiveOptions { private val lowerCasedOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { - lowerCasedOptionNames += name.toLowerCase + lowerCasedOptionNames += name.toLowerCase(Locale.ROOT) name } @@ -99,5 +101,5 @@ object HiveOptions { // The following typo is inherited from Hive... 
"collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", - "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase -> v } + "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 28f074849c0f5..666548d1a490b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ @@ -72,7 +73,7 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -80,20 +81,22 @@ case class HiveTableScanExec( BindReferences.bindReference(pred, relation.partitionCols) } - // Create a local copy of hadoopConf,so that scan specific modifications should not impact - // other queries - @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf() - - @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) - @transient private val tableDesc = new TableDesc( + @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta) + @transient private lazy val tableDesc = new TableDesc( hiveQlTable.getInputFormatClass, hiveQlTable.getOutputFormatClass, hiveQlTable.getMetadata) - // append columns ids and names before broadcast - addColumnMetadataToConf(hadoopConf) + // Create a local copy of hadoopConf,so that scan specific modifications should not impact + // other queries + @transient private lazy val hadoopConf = { + val c = sparkSession.sessionState.newHadoopConf() + // append columns ids and names before broadcast + addColumnMetadataToConf(c) + c + } - @transient private val hadoopReader = new HadoopTableReader( + @transient private lazy val hadoopReader = new HadoopTableReader( output, relation.partitionCols, tableDesc, @@ -104,7 +107,7 @@ case class HiveTableScanExec( Cast(Literal(value), dataType).eval(null) } - private def addColumnMetadataToConf(hiveConf: Configuration) { + private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { // Specifies needed column IDs for those non-partitioning columns. 
val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex) val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer) @@ -198,18 +201,13 @@ case class HiveTableScanExec( } } - override def sameResult(plan: SparkPlan): Boolean = plan match { - case other: HiveTableScanExec => - val thisPredicates = partitionPruningPred.map(cleanExpression) - val otherPredicates = other.partitionPruningPred.map(cleanExpression) - - val result = relation.sameResult(other.relation) && - output.length == other.output.length && - output.zip(other.output) - .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) && - thisPredicates.length == otherPredicates.length && - thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2)) - result - case _ => false + override lazy val canonicalized: HiveTableScanExec = { + val input: AttributeSeq = relation.output + HiveTableScanExec( + requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), + relation.canonicalized.asInstanceOf[CatalogRelation], + partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) } + + override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3c57ee4c8b8f6..3682dc850790e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -149,7 +149,7 @@ case class InsertIntoHiveTable( // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == @@ -393,8 +393,8 @@ case class InsertIntoHiveTable( logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) } - // Invalidate the cache. - sparkSession.catalog.uncacheTable(table.qualifiedName) + // un-cache this table. 
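On the HiveTableScanExec change above: the hand-written sameResult is replaced by comparing canonicalized plans, in which cosmetic differences such as expression IDs are normalized to positional ones. A toy illustration of the idea (SimplePlan is invented for this sketch and is not Spark's API):

object CanonicalizationSketch {
  case class SimplePlan(columnNames: Seq[String], exprIds: Seq[Long]) {
    // Normalize expression IDs to their ordinal position, which is conceptually what
    // QueryPlan.normalizeExprId does for real plans against their input attributes.
    def canonicalized: SimplePlan = copy(exprIds = exprIds.indices.map(_.toLong))
  }

  val a = SimplePlan(Seq("key", "value"), exprIds = Seq(17L, 18L))
  val b = SimplePlan(Seq("key", "value"), exprIds = Seq(42L, 43L))
  assert(a.canonicalized == b.canonicalized) // same result despite different expression IDs
}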
+ sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 506949cb682b0..a83ad61b204ad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -44,7 +44,7 @@ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -108,12 +108,13 @@ private[hive] case class HiveSimpleUDF( private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private val wrapper = wrapperFor(oi, dataType) private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi, dataType) + override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] } private[hive] case class HiveGenericUDF( @@ -122,7 +123,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic + override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index f496c01ce9ff7..3a34ec55c8b07 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.orc import java.net.URI import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -196,6 +198,11 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct] + // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format + private[this] val wrappers = dataSchema.zip(structOI.getAllStructFieldRefs().asScala.toSeq).map { + case (f, i) => wrapperFor(i.getFieldObjectInspector, f.dataType) + } + private[this] def wrapOrcStruct( struct: OrcStruct, oi: SettableStructObjectInspector, @@ -208,10 +215,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) oi.setStructFieldData( struct, fieldRefs.get(i), - wrap( - row.get(i, dataSchema(i).dataType), - fieldRefs.get(i).getFieldObjectInspector, - dataSchema(i).dataType)) + wrappers(i)(row.get(i, dataSchema(i).dataType)) + ) i += 1 } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 
ccaa568dcce2a..043eb69818ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.orc +import java.util.Locale + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap /** @@ -41,9 +43,9 @@ private[orc] class OrcOptions(@transient private val parameters: CaseInsensitive val codecName = parameters .get("compression") .orElse(orcCompressionConf) - .getOrElse("snappy").toLowerCase + .getOrElse("snappy").toLowerCase(Locale.ROOT) if (!shortOrcCompressionCodecNames.contains(codecName)) { - val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase) + val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) throw new IllegalArgumentException(s"Codec [$codecName] " + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index efc2f0098454b..d9bb1f8c7edcc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -24,6 +24,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -31,14 +33,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ -import org.apache.spark.sql.internal.{SharedState, SQLConf} +import org.apache.spark.sql.hive.client.HiveClient +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -57,6 +58,37 @@ object TestHive .set("spark.ui.enabled", "false"))) +case class TestHiveVersion(hiveClient: HiveClient) + extends TestHiveContext(TestHive.sparkContext, hiveClient) + + +private[hive] class TestHiveExternalCatalog( + conf: SparkConf, + hadoopConf: Configuration, + hiveClient: Option[HiveClient] = None) + extends HiveExternalCatalog(conf, hadoopConf) with Logging { + + override lazy val client: HiveClient = + hiveClient.getOrElse { + HiveUtils.newClientForMetadata(conf, hadoopConf) + } +} + + +private[hive] class TestHiveSharedState( + sc: SparkContext, + hiveClient: Option[HiveClient] = None) + extends SharedState(sc) { + + override lazy val externalCatalog: TestHiveExternalCatalog = { + new TestHiveExternalCatalog( + sc.conf, + sc.hadoopConfiguration, + hiveClient) + } +} + + /** * A locally 
running test instance of Spark's Hive execution engine. * @@ -80,12 +112,16 @@ class TestHiveContext( this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), loadTestTables)) } + def this(sc: SparkContext, hiveClient: HiveClient) { + this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), + hiveClient, + loadTestTables = false)) + } + override def newSession(): TestHiveContext = { new TestHiveContext(sparkSession.newSession()) } - override def sessionState: TestHiveSessionState = sparkSession.sessionState - def setCacheTables(c: Boolean): Unit = { sparkSession.setCacheTables(c) } @@ -109,12 +145,14 @@ class TestHiveContext( * * @param sc SparkContext * @param existingSharedState optional [[SharedState]] + * @param parentSessionState optional parent [[SessionState]] * @param loadTestTables if true, load the test tables. They can only be loaded when running * in the JVM, i.e when calling from Python this flag has to be false. */ private[hive] class TestHiveSparkSession( @transient private val sc: SparkContext, - @transient private val existingSharedState: Option[SharedState], + @transient private val existingSharedState: Option[TestHiveSharedState], + @transient private val parentSessionState: Option[SessionState], private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => @@ -122,6 +160,15 @@ private[hive] class TestHiveSparkSession( this( sc, existingSharedState = None, + parentSessionState = None, + loadTestTables) + } + + def this(sc: SparkContext, hiveClient: HiveClient, loadTestTables: Boolean) { + this( + sc, + existingSharedState = Some(new TestHiveSharedState(sc, Some(hiveClient))), + parentSessionState = None, loadTestTables) } @@ -140,18 +187,29 @@ private[hive] class TestHiveSparkSession( assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive") @transient - override lazy val sharedState: SharedState = { - existingSharedState.getOrElse(new SharedState(sc)) + override lazy val sharedState: TestHiveSharedState = { + existingSharedState.getOrElse(new TestHiveSharedState(sc)) } - // TODO: Let's remove TestHiveSessionState. Otherwise, we are not really testing the reflection - // logic based on the setting of CATALOG_IMPLEMENTATION. @transient - override lazy val sessionState: TestHiveSessionState = - new TestHiveSessionState(self) + override lazy val sessionState: SessionState = { + new TestHiveSessionStateBuilder(this, parentSessionState).build() + } + + lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() override def newSession(): TestHiveSparkSession = { - new TestHiveSparkSession(sc, Some(sharedState), loadTestTables) + new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) + } + + override def cloneSession(): SparkSession = { + val result = new TestHiveSparkSession( + sparkContext, + Some(sharedState), + Some(sessionState), + loadTestTables) + result.sessionState // force copy of SessionState + result } private var cacheTables: Boolean = false @@ -433,23 +491,31 @@ private[hive] class TestHiveSparkSession( sessionState.catalog.clearTempTables() sessionState.catalog.tableRelationCache.invalidateAll() - sessionState.metadataHive.reset() + metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } + // HDFS root scratch dir requires the write all (733) permission. 
For each connecting user,
+ // an HDFS scratch dir: ${hive.exec.scratchdir}/<username> is created, with
+ // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to
+ // delete it. Later, it will be re-created with the right permission.
+ val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname))
+ val fs = location.getFileSystem(sc.hadoopConfiguration)
+ fs.delete(location, true)
+
// Some tests corrupt this value on purpose, which breaks the RESET call below.
sessionState.conf.setConfString("fs.defaultFS", new File(".").toURI.toString)
// It is important that we RESET first as broken hooks that might have been set could break
// other sql exec here.
- sessionState.metadataHive.runSqlHive("RESET")
+ metadataHive.runSqlHive("RESET")
// For some reason, RESET does not reset the following variables...
// https://issues.apache.org/jira/browse/HIVE-9004
- sessionState.metadataHive.runSqlHive("set hive.table.parameters.default=")
- sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections=true")
- sessionState.metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true")
+ metadataHive.runSqlHive("set hive.table.parameters.default=")
+ metadataHive.runSqlHive("set datanucleus.cache.collections=true")
+ metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true")
// Lots of tests fail if we do not change the partition whitelist from the default.
- sessionState.metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*")
+ metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*")
sessionState.catalog.setCurrentDatabase("default")
} catch {
@@ -492,26 +558,6 @@ private[hive] class TestHiveQueryExecution(
}
}
-private[hive] class TestHiveSessionState(
- sparkSession: TestHiveSparkSession)
- extends HiveSessionState(sparkSession) { self =>
-
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- override def clear(): Unit = {
- super.clear()
- TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
- }
- }
- }
-
- override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = {
- new TestHiveQueryExecution(sparkSession, plan)
- }
-}
-
private[hive] object TestHiveContext {
@@ -537,3 +583,18 @@ private[hive] object TestHiveContext {
}
}
+
+private[sql] class TestHiveSessionStateBuilder(
+ session: SparkSession,
+ state: Option[SessionState])
+ extends HiveSessionStateBuilder(session, state)
+ with WithTestConf {
+
+ override def overrideConfs: Map[String, String] = TestHiveContext.overrideConfs
+
+ override def createQueryExecution: (LogicalPlan) => QueryExecution = { plan =>
+ new TestHiveQueryExecution(session.asInstanceOf[TestHiveSparkSession], plan)
+ }
+
+ override protected def newBuilder: NewBuilder = new TestHiveSessionStateBuilder(_, _)
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 0b157a45e6e05..25bd4d0017bd8 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -72,8 +72,7 @@ public void setUp() throws IOException {
path.delete();
}
HiveSessionCatalog catalog = (HiveSessionCatalog) sqlContext.sessionState().catalog();
- hiveManagedPath = new Path(
catalog.hiveDefaultTableFilePath(new TableIdentifier("javaSavedTable"))); + hiveManagedPath = new Path(catalog.defaultTablePath(new TableIdentifier("javaSavedTable"))); fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); fs.delete(hiveManagedPath, true); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala index 197110f4912a7..73383ae4d4118 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -22,7 +22,9 @@ import scala.concurrent.duration._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.sql.hive.execution.TestingTypedCount @@ -217,9 +219,9 @@ class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingle private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) - val info = new ExpressionInfo(clazz.getName, functionName) - sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + val functionIdentifier = FunctionIdentifier(functionName, database = None) + val func = CatalogFunction(functionIdentifier, clazz.getName, resources = Nil) + sessionCatalog.registerFunction(func, ignoreIfExists = false) } private def percentile_approx( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 8ccc2b7527f24..d3cbf898e2439 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto tempPath.delete() table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") - sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") - checkAnswer( - table("refreshTable"), - table("src").collect()) + sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet") + checkAnswer(table("refreshTable"), table("src")) // Cache the table. 
sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) @@ -331,7 +329,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val plan = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val plan = LogicalRelation(relation, tableMeta) spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan)) assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined) @@ -344,7 +342,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto bucketSpec = None, fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val samePlan = LogicalRelation(sameRelation, catalogTable = Some(tableMeta)) + val samePlan = LogicalRelation(sameRelation, tableMeta) assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 6d7a1c3937a96..59cc6605a1243 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive +import java.net.URI +import java.util.Locale + import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -47,7 +50,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val e = intercept[ParseException] { parser.parsePlan(sql) } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } private def analyzeCreateTable(sql: String): CatalogTable = { @@ -70,7 +73,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText @@ -102,7 +105,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created // TODO will be SQLText assert(desc.comment == Some("This is the staging page view table")) @@ -338,7 +341,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" val (desc, _) = extractTableDesc(query) assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/path/to/nowhere")) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) } test("create table - 
if not exists") { @@ -469,7 +472,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.viewText.isEmpty) assert(desc.viewDefaultDatabase.isEmpty) assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) assert(desc.storage.inputFormat == Some("winput")) assert(desc.storage.outputFormat == Some("wowput")) assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) @@ -644,7 +647,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle .add("id", "int") .add("name", "string", nullable = true, comment = "blabla")) assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) - assert(table.storage.locationUri == Some("/tmp/file")) + assert(table.storage.locationUri == Some(new URI("/tmp/file"))) assert(table.storage.properties == Map("my_prop" -> "1")) assert(table.comment == Some("BLABLA")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala index ee632d24b717e..705d43f1f3aba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala @@ -40,7 +40,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val tempDir = Utils.createTempDir().getCanonicalFile - val tempDirUri = tempDir.toURI.toString.stripSuffix("/") + val tempDirUri = tempDir.toURI + val tempDirStr = tempDir.getAbsolutePath override def beforeEach(): Unit = { sql("CREATE DATABASE test_db") @@ -59,9 +60,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest } private def defaultTableURI(tableName: String): URI = { - val defaultPath = - spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) - new Path(defaultPath).toUri + spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) } // Raw table metadata that are dumped from tables created by Spark 2.0. 
Note that, all spark @@ -170,8 +169,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl7", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl7").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map( @@ -184,7 +183,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( locationUri = Some(tempDirUri), - properties = Map("path" -> tempDirUri)), + properties = Map("path" -> tempDirStr)), schema = simpleSchema, properties = Map( "spark.sql.sources.provider" -> "parquet", @@ -195,8 +194,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl9", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl9").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map("spark.sql.sources.provider" -> "json")) @@ -220,7 +219,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest if (tbl.tableType == CatalogTableType.EXTERNAL) { // trim the URI prefix - val tableLocation = new URI(readBack.storage.locationUri.get).getPath + val tableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = tempDir.toURI.getPath.stripSuffix("/") assert(tableLocation == expectedLocation) } @@ -236,7 +235,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest val readBack = getTableMetadata(tbl.identifier.table) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expected = dir.toURI.getPath.stripSuffix("/") assert(actualTableLocation == expected) } @@ -252,7 +251,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest assert(readBack.schema.sameType(expectedSchema)) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) { tempDir.toURI.getPath.stripSuffix("/") } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 4349f1aa23be0..bd54c043c6ec4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.types.StructType @@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends 
ExternalCatalogSuite { import utils._ - test("list partitions by filter") { - val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") - assert(selectedPartitions.length == 1) - assert(selectedPartitions.head.spec == part1.spec) - } - test("SPARK-18647: do not put provider in table properties for Hive serde table") { val catalog = newBasicCatalog() val hiveTable = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala new file mode 100644 index 0000000000000..285f35b0b0eac --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalogSuite} +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveSingleton { + + protected override val isHiveExternalCatalog = true + + private val externalCatalog = { + val catalog = spark.sharedState.externalCatalog + catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog + } + + protected val utils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" + override val tableOutputFormat: String = + "org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat" + override val defaultProvider: String = "hive" + override def newEmptyCatalog(): ExternalCatalog = externalCatalog + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 16cf4d7ec67f6..d8fd68b63d1eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias("vw1", _, Some(TableIdentifier("vw1", Some("default")))) => x + case x @ SubqueryAlias("vw1", _) => x } assert(aliases.size == 1) } @@ -115,7 +115,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === 
Seq("1.1\t1", "2.1\t2")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -140,14 +140,14 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.serde === Some(serde)) assert(hiveTable.tableType === CatalogTableType.EXTERNAL) - assert(hiveTable.storage.locationUri === Some(path.toString)) + assert(hiveTable.storage.locationUri === Some(makeQualifiedPath(dir.getAbsolutePath))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) checkAnswer(table("t"), testDF) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -176,7 +176,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) checkAnswer(table("t"), Row(1, "val_1")) - assert(sessionState.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala new file mode 100644 index 0000000000000..319d02613f00a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.util.Random + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.execution.datasources.FileStatusCache +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +class HiveSchemaInferenceSuite + extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach { + + import HiveSchemaInferenceSuite._ + import HiveExternalCatalog.DATASOURCE_SCHEMA_PREFIX + + override def beforeEach(): Unit = { + super.beforeEach() + FileStatusCache.resetForTesting() + } + + override def afterEach(): Unit = { + super.afterEach() + spark.sessionState.catalog.tableRelationCache.invalidateAll() + FileStatusCache.resetForTesting() + } + + private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val client = externalCatalog.client + + // Return a copy of the given schema with all field names converted to lower case. + private def lowerCaseSchema(schema: StructType): StructType = { + StructType(schema.map(f => f.copy(name = f.name.toLowerCase))) + } + + // Create a Hive external test table containing the given field and partition column names. + // Returns a case-sensitive schema for the table. + private def setupExternalTable( + fileType: String, + fields: Seq[String], + partitionCols: Seq[String], + dir: File): StructType = { + // Treat all table fields as bigints... 
+ val structFields = fields.map { field => + StructField( + name = field, + dataType = LongType, + nullable = true, + metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "bigint").build()) + } + // and all partition columns as ints + val partitionStructFields = partitionCols.map { field => + StructField( + // Partition column case isn't preserved + name = field.toLowerCase, + dataType = IntegerType, + nullable = true, + metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "int").build()) + } + val schema = StructType(structFields ++ partitionStructFields) + + // Write some test data (partitioned if specified) + val writer = spark.range(NUM_RECORDS) + .selectExpr((fields ++ partitionCols).map("id as " + _): _*) + .write + .partitionBy(partitionCols: _*) + .mode("overwrite") + fileType match { + case ORC_FILE_TYPE => + writer.orc(dir.getAbsolutePath) + case PARQUET_FILE_TYPE => + writer.parquet(dir.getAbsolutePath) + } + + // Create Hive external table with lowercased schema + val serde = HiveSerDe.serdeMap(fileType) + client.createTable( + CatalogTable( + identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = Option(new java.net.URI(dir.getAbsolutePath)), + inputFormat = serde.inputFormat, + outputFormat = serde.outputFormat, + serde = serde.serde, + compressed = false, + properties = Map("serialization.format" -> "1")), + schema = schema, + provider = Option("hive"), + partitionColumnNames = partitionCols.map(_.toLowerCase), + properties = Map.empty), + true) + + // Add partition records (if specified) + if (!partitionCols.isEmpty) { + spark.catalog.recoverPartitions(TEST_TABLE_NAME) + } + + // Check that the table returned by HiveExternalCatalog has schemaPreservesCase set to false + // and that the raw table returned by the Hive client doesn't have any Spark SQL properties + // set (table needs to be obtained from client since HiveExternalCatalog filters these + // properties out). 
+ assert(!externalCatalog.getTable(DATABASE, TEST_TABLE_NAME).schemaPreservesCase) + val rawTable = client.getTable(DATABASE, TEST_TABLE_NAME) + assert(rawTable.properties.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)) == Map.empty) + schema + } + + private def withTestTables( + fileType: String)(f: (Seq[String], Seq[String], StructType) => Unit): Unit = { + // Test both a partitioned and unpartitioned Hive table + val tableFields = Seq( + (Seq("fieldOne"), Seq("partCol1", "partCol2")), + (Seq("fieldOne", "fieldTwo"), Seq.empty[String])) + + tableFields.foreach { case (fields, partCols) => + withTempDir { dir => + val schema = setupExternalTable(fileType, fields, partCols, dir) + withTable(TEST_TABLE_NAME) { f(fields, partCols, schema) } + } + } + } + + private def withFileTypes(f: (String) => Unit): Unit + = Seq(ORC_FILE_TYPE, PARQUET_FILE_TYPE).foreach(f) + + private def withInferenceMode(mode: InferenceMode)(f: => Unit): Unit = { + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> "true", + SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key -> mode.toString)(f) + } + + private val inferenceKey = SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key + + private def testFieldQuery(fields: Seq[String]): Unit = { + if (!fields.isEmpty) { + val query = s"SELECT * FROM ${TEST_TABLE_NAME} WHERE ${Random.shuffle(fields).head} >= 0" + assert(spark.sql(query).count == NUM_RECORDS) + } + } + + private def testTableSchema(expectedSchema: StructType): Unit + = assert(spark.table(TEST_TABLE_NAME).schema == expectedSchema) + + withFileTypes { fileType => + test(s"$fileType: schema should be inferred and saved when INFER_AND_SAVE is specified") { + withInferenceMode(INFER_AND_SAVE) { + withTestTables(fileType) { (fields, partCols, schema) => + testFieldQuery(fields) + testFieldQuery(partCols) + testTableSchema(schema) + + // Verify the catalog table now contains the updated schema and properties + val catalogTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + assert(catalogTable.schemaPreservesCase) + assert(catalogTable.schema == schema) + assert(catalogTable.partitionColumnNames == partCols.map(_.toLowerCase)) + } + } + } + } + + withFileTypes { fileType => + test(s"$fileType: schema should be inferred but not stored when INFER_ONLY is specified") { + withInferenceMode(INFER_ONLY) { + withTestTables(fileType) { (fields, partCols, schema) => + val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + testFieldQuery(fields) + testFieldQuery(partCols) + testTableSchema(schema) + // Catalog table shouldn't be altered + assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable) + } + } + } + } + + withFileTypes { fileType => + test(s"$fileType: schema should not be inferred when NEVER_INFER is specified") { + withInferenceMode(NEVER_INFER) { + withTestTables(fileType) { (fields, partCols, schema) => + val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) + // Only check the table schema as the test queries will break + testTableSchema(lowerCaseSchema(schema)) + assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable) + } + } + } + } + + test("mergeWithMetastoreSchema() should return expected results") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), 
+ + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("Detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + + // Check that merging missing nullable fields works as expected. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("Detected conflicting schemas")) + + // Schema merge should maintain metastore order. 
+ assertResult( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true)))) { + HiveMetastoreCatalog.mergeWithMetastoreSchema( + StructType(Seq( + StructField("first_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("fourth_field", StringType, nullable = true), + StructField("fifth_field", StringType, nullable = true))), + StructType(Seq( + StructField("fifth_field", StringType, nullable = true), + StructField("third_field", StringType, nullable = true), + StructField("second_field", StringType, nullable = true)))) + } + } +} + +object HiveSchemaInferenceSuite { + private val NUM_RECORDS = 10 + private val DATABASE = "default" + private val TEST_TABLE_NAME = "test_table" + private val ORC_FILE_TYPE = "orc" + private val PARQUET_FILE_TYPE = "parquet" +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala new file mode 100644 index 0000000000000..958ad3e1c3ce8 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton + +/** + * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. 
+ */ +class HiveSessionStateSuite extends SessionStateSuite + with TestHiveSingleton with BeforeAndAfterEach { + + override def beforeAll(): Unit = { + // Reuse the singleton session + activeSession = spark + } + + override def afterAll(): Unit = { + // Set activeSession to null to avoid stopping the singleton session + activeSession = null + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index fdb5dd2309a08..661b9a4f3fbc1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -487,7 +487,7 @@ object SetWarehouseLocationTest extends Logging { val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocation" + CatalogUtils.stringToURI(s"file:${expectedWarehouseLocation.toString}/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -502,8 +502,8 @@ object SetWarehouseLocationTest extends Logging { sparkSession.sql("create table testLocation (a int)") val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) - val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" + val expectedLocation = CatalogUtils.stringToURI( + s"file:${expectedWarehouseLocation.toString}/testlocationdb.db/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -870,14 +870,16 @@ object SPARK_18360 { val rawTable = hiveClient.getTable("default", "test_tbl") // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table // location for tables in default database. - assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + assert(rawTable.storage.locationUri.map( + CatalogUtils.URIToString(_)).get.contains(newWarehousePath)) hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") // Spark SQL will use the location of default database to generate default table // location for tables in default database. 
- assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + assert(readBack.storage.locationUri.map(CatalogUtils.URIToString(_)) + .get.contains(defaultDbLocation)) } finally { hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 03ea0c8c77682..b554694815571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -379,8 +379,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - val expectedPath = - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) + val expectedPath = sessionState.catalog.defaultTablePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.delete(filesystemPath, true) @@ -486,7 +485,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("DROP TABLE savedJsonTable") intercept[AnalysisException] { read.json( - sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("savedJsonTable"))) + sessionState.catalog.defaultTablePath(TableIdentifier("savedJsonTable")).toString) } } @@ -756,7 +755,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv serde = None, compressed = false, properties = Map( - "path" -> sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) + "path" -> sessionState.catalog.defaultTablePath(TableIdentifier(tableName)).toString) ), properties = Map( DATASOURCE_PROVIDER -> "json", @@ -768,9 +767,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sessionState.refreshTable(tableName) val actualSchema = table(tableName).schema assert(schema === actualSchema) - - // Checks the DESCRIBE output. 
- checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", null) :: Nil) } } @@ -1011,7 +1007,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv identifier = TableIdentifier("not_skip_hive_metadata"), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempPath.getCanonicalPath), + locationUri = Some(tempPath.toURI), properties = Map("skipHiveMetadata" -> "false") ), schema = schema, @@ -1382,7 +1378,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv checkAnswer(spark.table("old"), Row(1, "a")) - checkAnswer(sql("DESC old"), Row("i", "int", null) :: Row("j", "string", null) :: Nil) + val expectedSchema = StructType(Seq( + StructField("i", IntegerType, nullable = true), + StructField("j", StringType, nullable = true))) + assert(table("old").schema === expectedSchema) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 47ee4dd4d952c..4aea6d14efb0e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.hive +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -26,8 +30,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = spark.sharedState.externalCatalog.getTable(dbName, tableName) - val expectedPath = - spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName + val expectedPath = new Path(new Path( + spark.sharedState.externalCatalog.getDatabase(dbName).locationUri), tableName).toUri assert(metastoreTable.location === expectedPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 96385961c9a52..9440a17677ebf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.metrics.source.HiveCatalogMetrics -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -316,6 +316,28 @@ class PartitionProviderCompatibilitySuite } } } + + test(s"SPARK-19887 partition value is null - partition management $enabled") { + withTable("test") { + Seq((1, "p", 1), (2, null, 2)).toDF("a", "b", "c") + .write.partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Nil) + + Seq((3, null: String, 3)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Nil) + // make sure partition pruning also 
works. + checkAnswer(spark.table("test").filter($"b".isNotNull), Row(1, "p", 1)) + + // empty string is an invalid partition value and we treat it as null when read back. + Seq((4, "", 4)).toDF("a", "b", "c") + .write.mode("append").partitionBy("b", "c").saveAsTable("test") + checkAnswer(spark.table("test"), + Row(1, "p", 1) :: Row(2, null, 2) :: Row(3, null, 3) :: Row(4, null, 4) :: Nil) + } + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 962998ea6fb68..3191b9975fbf9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -413,7 +413,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } // Table lookup will make the table cached. spark.table(tableIndent) - statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsBeforeUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get sql(s"INSERT INTO $tableName SELECT 2") @@ -423,7 +423,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") } spark.table(tableIndent) - statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent) + statsAfterUpdate = catalog.metastoreCatalog.getCachedDataSourceTable(tableIndent) .asInstanceOf[LogicalRelation].catalogTable.get.stats.get } (statsBeforeUpdate, statsAfterUpdate) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index cd96c85f3e209..031c1a5ec0ec3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -65,6 +65,11 @@ class FiltersSuite extends SparkFunSuite with Logging { (Literal("") === a("varchar", StringType)) :: Nil, "") + filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning", + (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) :: + (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil, + """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""") + private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { val converted = shim.convertFilters(testTable, filters) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d61d10bf869e2..7aff49c0fc3b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -18,23 +18,24 @@ package org.apache.spark.sql.hive.client import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.net.URI import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat +import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import 
org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPermanentFunctionException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveVersion import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.types.StructType import org.apache.spark.tags.ExtendedHiveTest @@ -47,14 +48,34 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} * is not fully tested. */ @ExtendedHiveTest -class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private val clientBuilder = new HiveClientBuilder import clientBuilder.buildClient + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + versionSpark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + test("success sanity check") { val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) - val db = new CatalogDatabase("default", "desc", "loc", Map()) + val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) badClient.createDatabase(db, ignoreIfExists = true) } @@ -88,22 +109,30 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") private var client: HiveClient = null + private var versionSpark: TestHiveVersion = null + versions.foreach { version => test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. 
val hadoopConf = new Configuration() hadoopConf.set("test", "success") - // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 - // For details, see the JIRA HIVE-6113 - if (version == "2.0") { + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and + // hive.metastore.schema.verification from false to true since 2.0 + // For details, see the JIRA HIVE-6113 and HIVE-12463 + if (version == "2.0" || version == "2.1") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") } client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + if (versionSpark != null) versionSpark.reset() + versionSpark = TestHiveVersion(client) + assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { @@ -125,10 +154,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w // Database related API /////////////////////////////////////////////////////////////////////////// - val tempDatabasePath = Utils.createTempDir().getCanonicalPath + val tempDatabasePath = Utils.createTempDir().toURI test(s"$version: createDatabase") { - val defaultDB = CatalogDatabase("default", "desc", "loc", Map()) + val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map()) client.createDatabase(defaultDB, ignoreIfExists = true) val tempDB = CatalogDatabase( "temporary", description = "test create", tempDatabasePath, Map()) @@ -346,7 +375,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") - val newLocation = Utils.createTempDir().getPath() + val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions @@ -544,22 +573,30 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: CREATE TABLE AS SELECT") { withTable("tbl") { - spark.sql("CREATE TABLE tbl AS SELECT 1 AS a") - assert(spark.table("tbl").collect().toSeq == Seq(Row(1))) + versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) + val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) + val totalSize = tableMeta.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty) + } else { + assert(totalSize.nonEmpty && totalSize.get > 0) + } } } test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { - spark.sql( + versionSpark.sql( s""" |CREATE TABLE tab(c1 string) |location '${tmpDir.toURI.toString}' """.stripMargin) (1 to 3).map { i => - spark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") + versionSpark.sql(s"INSERT OVERWRITE TABLE tab SELECT '$i'") } def listFiles(path: File): List[String] = { val dir = path.listFiles() @@ -568,7 +605,9 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w folders.flatMap(listFiles) ++: filePaths } // expect 2 files left: `.part-00000-random-uuid.crc` and `part-00000-random-uuid` - 
assert(listFiles(tmpDir).length == 2) + // 0.12, 0.13, 1.0 and 1.1 also has another two more files ._SUCCESS.crc and _SUCCESS + val metadataFiles = Seq("._SUCCESS.crc", "_SUCCESS") + assert(listFiles(tmpDir).filterNot(metadataFiles.contains).length == 2) } } } @@ -608,7 +647,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w withTable(tableName, tempTableName) { // Creates the external partitioned Avro table to be tested. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tableName |PARTITIONED BY (ds STRING) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' @@ -621,7 +660,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Creates an temporary Avro table used to prepare testing Avro file. - sql( + versionSpark.sql( s"""CREATE EXTERNAL TABLE $tempTableName |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' |STORED AS @@ -633,45 +672,29 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w ) // Generates Avro data. - sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") + versionSpark.sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") // Adds generated Avro data as a new partition to the testing table. - sql(s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") + versionSpark.sql( + s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") // The following query fails before SPARK-13709 is fixed. This is because when reading // data from table partitions, Avro deserializer needs the Avro schema, which is defined // in table property "avro.schema.literal". However, we only initializes the deserializer // using partition properties, which doesn't include the wanted property entry. Merging // two sets of properties solves the problem. - checkAnswer( - sql(s"SELECT * FROM $tableName"), - Row(1, Row(2, 2.5D), "foo") - ) + assert(versionSpark.sql(s"SELECT * FROM $tableName").collect() === + Array(Row(1, Row(2, 2.5D), "foo"))) } } } test(s"$version: CTAS for managed data source tables") { withTable("t", "t1") { - import spark.implicits._ - - val tPath = new Path(spark.sessionState.conf.warehousePath, "t") - Seq("1").toDF("a").write.saveAsTable("t") - val expectedPath = s"file:${tPath.toUri.getPath.stripSuffix("/")}" - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - - assert(table.location.stripSuffix("/") == expectedPath) - assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) - checkAnswer(spark.table("t"), Row("1") :: Nil) - - val t1Path = new Path(spark.sessionState.conf.warehousePath, "t1") - spark.sql("create table t1 using parquet as select 2 as a") - val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - val expectedPath1 = s"file:${t1Path.toUri.getPath.stripSuffix("/")}" - - assert(table1.location.stripSuffix("/") == expectedPath1) - assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) - checkAnswer(spark.table("t1"), Row(2) :: Nil) + versionSpark.range(1).write.saveAsTable("t") + assert(versionSpark.table("t").collect() === Array(Row(0))) + versionSpark.sql("create table t1 using parquet as select 2 as a") + assert(versionSpark.table("t1").collect() === Array(Row(2))) } } // TODO: add more tests. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4a8086d7e5400..84f915977bd88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -509,6 +509,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null, 110.0, null, null, 10.0) :: Nil) } + test("non-deterministic children expressions of UDAF") { + val e = intercept[AnalysisException] { + spark.sql( + """ + |SELECT mydoublesum(value + 1.5 * key + rand()) + |FROM agg1 + |GROUP BY key + """.stripMargin) + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e.contains)) + } + test("interpreted aggregate function") { checkAnswer( spark.sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 536ca8fd9d45d..abe5d835719b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets import java.util +import java.util.Locale import scala.util.control.NonFatal @@ -207,6 +208,7 @@ abstract class HiveComparisonTest // This list contains indicators for those lines which do not have actual results and we // want to ignore. lazy val ignoredLineIndicators = Seq( + "# Detailed Table Information", "# Partition Information", "# col_name" ) @@ -298,10 +300,11 @@ abstract class HiveComparisonTest // thus the tables referenced in those DDL commands cannot be extracted for use by our // test table auto-loading mechanism. In addition, the tests which use the SHOW TABLES // command expect these tables to exist. 
- val hasShowTableCommand = queryList.exists(_.toLowerCase.contains("show tables")) + val hasShowTableCommand = + queryList.exists(_.toLowerCase(Locale.ROOT).contains("show tables")) for (table <- Seq("src", "srcpart")) { val hasMatchingQuery = queryList.exists { query => - val normalizedQuery = query.toLowerCase.stripSuffix(";") + val normalizedQuery = query.toLowerCase(Locale.ROOT).stripSuffix(";") normalizedQuery.endsWith(table) || normalizedQuery.contains(s"from $table") || normalizedQuery.contains(s"from default.$table") @@ -358,7 +361,7 @@ abstract class HiveComparisonTest stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) fail(errorMessage) } - }.toSeq + } (queryList, hiveResults, catalystResults).zipped.foreach { case (query, hive, (hiveQuery, catalyst)) => @@ -369,6 +372,7 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) && (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) && + (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) && preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive @@ -442,7 +446,7 @@ abstract class HiveComparisonTest "create table", "drop index" ) - !queryList.map(_.toLowerCase).exists { query => + !queryList.map(_.toLowerCase(Locale.ROOT)).exists { query => excludedSubstrings.exists(s => query.contains(s)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 81ae5b7bdb672..c3d734e5a0366 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.net.URI import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -25,20 +26,146 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ + +// TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite +class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { + override def afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + spark.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } + + protected override def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: 
Boolean): CatalogTable = { + val storage = + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL, + storage = storage, + schema = new StructType() + .add("col1", "int", nullable = true, metadata = metadata) + .add("col2", "string") + .add("a", "int") + .add("b", "int"), + provider = if (isDataSource) Some("parquet") else Some("hive"), + partitionColumnNames = Seq("a", "b"), + createTime = 0L, + tracksPartitionsInCatalog = true) + } + + protected override def normalizeCatalogTable(table: CatalogTable): CatalogTable = { + val nondeterministicProps = Set( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_by", + "last_modified_time", + "Owner:", + "COLUMN_STATS_ACCURATE", + // The following are hive specific schema parameters which we do not need to match exactly. + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize" + ) + + table.copy( + createTime = 0L, + lastAccessTime = 0L, + owner = "", + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + // View texts are checked separately + viewText = None + ) + } + + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + +} class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import testImplicits._ + val hiveFormats = Seq("PARQUET", "ORC", "TEXTFILE", "SEQUENCEFILE", "RCFILE", "AVRO") override def afterEach(): Unit = { try { @@ -54,11 +181,11 @@ class HiveDDLSuite dbPath: Option[String] = None): Boolean = { val expectedTablePath = if (dbPath.isEmpty) { - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), 
tableIdentifier.table).toString + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } - val filesystemPath = new Path(expectedTablePath) + val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.exists(filesystemPath) } @@ -634,23 +761,6 @@ class HiveDDLSuite } } - test("desc table for Hive table") { - withTable("tab1") { - val tabName = "tab1" - sql(s"CREATE TABLE $tabName(c1 int)") - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("desc table for Hive table - partitioned table") { withTable("tbl") { sql("CREATE TABLE tbl(a int) PARTITIONED BY (b int)") @@ -667,23 +777,6 @@ class HiveDDLSuite } } - test("desc formatted table for permanent view") { - withTable("tbl") { - withView("view1") { - sql("CREATE TABLE tbl(a int)") - sql("CREATE VIEW view1 AS SELECT * FROM tbl") - assert(sql("DESC FORMATTED view1").collect().containsSlice( - Seq( - Row("# View Information", "", ""), - Row("View Text:", "SELECT * FROM tbl", ""), - Row("View Default Database:", "default", ""), - Row("View Query Output Columns:", "[a]", "") - ) - )) - } - } - } - test("desc table for data source table using Hive Metastore") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") val tabName = "tab1" @@ -692,7 +785,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("a", "int", "test") + Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil ) } } @@ -710,7 +803,7 @@ class HiveDDLSuite } sql(s"CREATE DATABASE $dbName Location '${tmpDir.toURI.getPath.stripSuffix("/")}'") val db1 = catalog.getDatabaseMetadata(dbName) - val dbPath = tmpDir.toURI.toString.stripSuffix("/") + val dbPath = new URI(tmpDir.toURI.toString.stripSuffix("/")) assert(db1 == CatalogDatabase(dbName, "", dbPath, Map.empty)) sql("USE db1") @@ -747,11 +840,12 @@ class HiveDDLSuite sql(s"CREATE DATABASE $dbName") val catalog = spark.sessionState.catalog val expectedDBLocation = s"file:${dbPath.toUri.getPath.stripSuffix("/")}/$dbName.db" + val expectedDBUri = CatalogUtils.stringToURI(expectedDBLocation) val db1 = catalog.getDatabaseMetadata(dbName) assert(db1 == CatalogDatabase( dbName, "", - expectedDBLocation, + expectedDBUri, Map.empty)) // the database directory was created assert(fs.exists(dbPath) && fs.isDirectory(dbPath)) @@ -1143,23 +1237,6 @@ class HiveDDLSuite sql(s"SELECT * FROM ${targetTable.identifier}")) } - test("desc table for data source table") { - withTable("tab1") { - val tabName = "tab1" - spark.range(1).write.format("json").saveAsTable(tabName) - - assert(sql(s"DESC $tabName").collect().length == 1) - - assert( - sql(s"DESC FORMATTED $tabName").collect() - .exists(_.getString(0) == "# Storage Information")) - - assert( - sql(s"DESC EXTENDED $tabName").collect() - .exists(_.getString(0) == "# Detailed Table Information")) - } - } - test("create table with the same name as an index table") { val tabName = "tab1" val indexName = tabName + "_index" @@ -1173,6 +1250,14 @@ class HiveDDLSuite s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD") val indexTabName = spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table + + // Even if index tables exist, listTables and 
getTable APIs should still work + checkAnswer( + spark.catalog.listTables().toDF(), + Row(indexTabName, "default", null, null, false) :: + Row(tabName, "default", null, "MANAGED", false) :: Nil) + assert(spark.catalog.getTable("default", indexTabName).name === indexTabName) + intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int)") } @@ -1245,46 +1330,6 @@ class HiveDDLSuite } } - test("desc table for data source table - partitioned bucketed table") { - withTable("t1") { - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .bucketBy(2, "b").sortBy("c").partitionBy("d") - .saveAsTable("t1") - - val formattedDesc = sql("DESC FORMATTED t1").collect() - - assert(formattedDesc.containsSlice( - Seq( - Row("a", "bigint", null), - Row("b", "bigint", null), - Row("c", "bigint", null), - Row("d", "bigint", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("d", "bigint", null), - Row("", "", ""), - Row("# Detailed Table Information", "", ""), - Row("Database:", "default", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Table Type:", "MANAGED", "") - ) - )) - - assert(formattedDesc.containsSlice( - Seq( - Row("Num Buckets:", "2", ""), - Row("Bucket Columns:", "[b]", ""), - Row("Sort Columns:", "[c]", "") - ) - )) - } - } - test("datasource and statistics table property keys are not allowed") { import org.apache.spark.sql.hive.HiveExternalCatalog.DATASOURCE_PREFIX import org.apache.spark.sql.hive.HiveExternalCatalog.STATISTICS_PREFIX @@ -1588,99 +1633,296 @@ class HiveDDLSuite } } + test("create hive table with a non-existing location") { + withTable("t", "t1") { + withTempPath { dir => + spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'") + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t SELECT 1, 2") + assert(dir.exists()) + + checkAnswer(spark.table("t"), Row(1, 2)) + } + // partition table + withTempPath { dir => + spark.sql( + s""" + |CREATE TABLE t1(a int, b int) + |USING hive + |PARTITIONED BY(a) + |LOCATION '$dir' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) + + spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2") + + val partDir = new File(dir, "a=1") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(2, 1)) + } + } + } + Seq(true, false).foreach { shouldDelete => - val tcName = if (shouldDelete) "non-existent" else "existed" - test(s"CTAS for external data source table with a $tcName location") { + val tcName = if (shouldDelete) "non-existing" else "existed" + + test(s"CTAS for external hive table with a $tcName location") { withTable("t", "t1") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t - |USING parquet + |USING hive |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) - } - // partition table - withTempDir { - dir => - if (shouldDelete) { - 
dir.delete() - } + } + // partition table + withTempDir { dir => + if (shouldDelete) dir.delete() spark.sql( s""" |CREATE TABLE t1 - |USING parquet + |USING hive |PARTITIONED BY(a, b) |LOCATION '$dir' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(table.location == dir.getAbsolutePath) + assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) val partDir = new File(dir, "a=3") assert(partDir.exists()) checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } } } } + } - test(s"CTAS for external hive table with a $tcName location") { - withTable("t", "t1") { - withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t - |USING hive - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val dirPath = new Path(dir.getAbsolutePath) - val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(new Path(table.location) == fs.makeQualified(dirPath)) - - checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + Seq("parquet", "hive").foreach { datasource => + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"partition column name of $datasource table containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING $datasource + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4") + val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4" + val partFile1 = new File(dir, partEscaped1) + assert(partFile1.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil) + } } - // partition table - withTempDir { - dir => - if (shouldDelete) { - dir.delete() - } - spark.sql( - s""" - |CREATE TABLE t1 - |USING hive - |PARTITIONED BY(a, b) - |LOCATION '$dir' - |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d - """.stripMargin) - val dirPath = new Path(dir.getAbsolutePath) - val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) - val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) - assert(new Path(table.location) == fs.makeQualified(dirPath)) - - val partDir = new File(dir, "a=3") - assert(partDir.exists()) - - checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"hive table: location uri contains $specialChars") { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING hive + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + 
assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } else { + val e = intercept[AnalysisException] { + spark.sql("INSERT INTO TABLE t SELECT 1") + }.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING hive + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == makeQualifiedPath(loc.getAbsolutePath)) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), + Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } else { + val e = intercept[AnalysisException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + }.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + + val e1 = intercept[AnalysisException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + }.getMessage + assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + } + } + } + + test("SPARK-19905: Hive SerDe table input paths") { + withTable("spark_19905") { + withTempView("spark_19905_view") { + spark.range(10).createOrReplaceTempView("spark_19905_view") + sql("CREATE TABLE spark_19905 STORED AS RCFILE AS SELECT * FROM spark_19905_view") + assert(spark.table("spark_19905").inputFiles.nonEmpty) + assert(sql("SELECT input_file_name() FROM spark_19905").count() > 0) + } + } + } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- partitioned - $tableType") { + withTable("tab") { + sql( + s""" + |CREATE TABLE tab (c1 int, c2 int) + |PARTITIONED BY (c3 int) STORED AS $tableType + """.stripMargin) + + sql("INSERT INTO tab PARTITION (c3=1) VALUES (1, 2)") + sql("ALTER TABLE tab ADD COLUMNS (c4 int)") + + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 1"), + Seq(Row(1, 2, null, 1)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab PARTITION (c3=2) VALUES (2, 3, 4)") + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null, 1), Row(2, 3, 4, 2)) + ) + checkAnswer( + sql("SELECT * FROM tab WHERE c3 = 2 AND c4 IS NOT NULL"), + Seq(Row(2, 3, 4, 2)) + ) + + sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") + assert(spark.table("tab").schema.find(_.name == "c5") + .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + } + } + } + + hiveFormats.foreach { tableType => + test(s"alter hive serde table add columns -- with predicate - $tableType ") { + withTable("tab") { + sql(s"CREATE TABLE tab (c1 int, c2 int) STORED AS $tableType") + sql("INSERT INTO tab VALUES (1, 2)") + sql("ALTER 
TABLE tab ADD COLUMNS (c4 int)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 IS NULL"), + Seq(Row(1, 2, null)) + ) + assert(spark.table("tab").schema + .contains(StructField("c4", IntegerType))) + sql("INSERT INTO tab VALUES (2, 3, 4)") + checkAnswer( + sql("SELECT * FROM tab WHERE c4 = 4 "), + Seq(Row(2, 3, 4)) + ) + checkAnswer( + spark.table("tab"), + Seq(Row(1, 2, null), Row(2, 3, 4)) + ) + } + } + } + + Seq(true, false).foreach { caseSensitive => + test(s"alter add columns with existing column name - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withTable("tab") { + sql("CREATE TABLE tab (c1 int) PARTITIONED BY (c2 int) STORED AS PARQUET") + if (!caseSensitive) { + // duplicating partitioning column name + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("Found duplicate column(s)")) + + // duplicating data column name + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("Found duplicate column(s)")) + } else { + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C2 string)") + }.getMessage + assert(e1.contains("HiveException")) + + // hive catalog will still complains that c1 is duplicate column name because hive + // identifiers are case insensitive. + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE tab ADD COLUMNS (C1 string)") + }.getMessage + assert(e2.contains("HiveException")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8a37bc3665d32..aa1ca2909074f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -43,11 +43,29 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), - "== Physical Plan ==") + "== Physical Plan ==", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") + "== Optimized Logical Plan ==", + "Owner", + "Database", + "Created", + "Last Access", + "Type", + "Provider", + "Properties", + "Statistics", + "Location", + "Serde Library", + "InputFormat", + "OutputFormat", + "Partition Provider", + "Schema" + ) + checkKeywordsExist(sql(" explain extended select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala deleted file mode 100644 index 0e89e990e564e..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton - -/** - * A set of tests that validates commands can also be queried by like a table - */ -class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { - import spark._ - - test("SPARK-5324 query result of describe command") { - hiveContext.loadTestTable("src") - - // Creates a temporary view with the output of a describe command - sql("desc src").createOrReplaceTempView("mydesc") - checkAnswer( - sql("desc mydesc"), - Seq( - Row("col_name", "string", "name of the column"), - Row("data_type", "string", "data type of the column"), - Row("comment", "string", "comment of the column"))) - - checkAnswer( - sql("select * from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - - checkAnswer( - sql("select col_name, data_type, comment from mydesc"), - Seq( - Row("key", "int", null), - Row("value", "string", null))) - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index e772324a57ab8..bb4ce6d3aa3f1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.util._ /** * A framework for running the query tests that are listed as a set of text files. * - * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles + * TestSuites that derive from this class must provide a map of testCaseName to testCaseFiles * that should be included. Additionally, there is support for whitelisting and blacklisting * tests as development progresses. 
*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index dd278f683a3cd..cf33760360724 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -80,7 +80,7 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd private def assertUnsupportedFeature(body: => Unit): Unit = { val e = intercept[ParseException] { body } - assert(e.getMessage.toLowerCase.contains("operation not allowed")) + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) } // Testing the Broadcast based join for cartesian join (cross join) @@ -789,62 +789,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(Try(q0.count()).isSuccess) } - test("DESCRIBE commands") { - sql(s"CREATE TABLE test_describe_commands1 (key INT, value STRING) PARTITIONED BY (dt STRING)") - - sql( - """FROM src INSERT OVERWRITE TABLE test_describe_commands1 PARTITION (dt='2008-06-08') - |SELECT key, value - """.stripMargin) - - // Describe a table - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a table with a fully qualified table name - assertResult( - Array( - Row("key", "int", null), - Row("value", "string", null), - Row("dt", "string", null), - Row("# Partition Information", "", ""), - Row("# col_name", "data_type", "comment"), - Row("dt", "string", null)) - ) { - sql("DESCRIBE default.test_describe_commands1") - .select('col_name, 'data_type, 'comment) - .collect() - } - - // Describe a temporary view. 
- val testData = - TestHive.sparkContext.parallelize( - TestData(1, "str1") :: - TestData(1, "str2") :: Nil) - testData.toDF().createOrReplaceTempView("test_describe_commands2") - - assertResult( - Array( - Row("a", "int", null), - Row("b", "string", null)) - ) { - sql("DESCRIBE test_describe_commands2") - .select('col_name, 'data_type, 'comment) - .collect() - } - } - test("SPARK-2263: Insert Map<K, V> values") { sql("CREATE TABLE m(value MAP<INT, STRING>)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index c9ef72ee112cf..479ca1e8def56 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ +import org.apache.hadoop.hive.ql.udf.UDAFPercentile import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDAFEvaluator, GenericUDAFMax} import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode} import org.apache.hadoop.hive.ql.util.JavaDataModel @@ -26,7 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -84,6 +85,21 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { Row(1, Row(1, 1)) )) } + + test("non-deterministic children expressions of UDAF") { + withTempView("view1") { + spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") + withUserDefinedFunction("testUDAFPercentile" -> true) { + // non-deterministic children of Hive UDAF + sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") + val e1 = intercept[AnalysisException] { + sql("SELECT testUDAFPercentile(x, rand()) from view1 group by y") + }.getMessage + assert(Seq("nondeterministic expression", + "should not appear in the arguments of an aggregate function").forall(e1.contains)) + } + } + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ef6883839d437..4446af2e75e00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -387,6 +388,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { +
test("non-deterministic children of UDF") { + withUserDefinedFunction("testStringStringUDF" -> true, "testGenericUDFHash" -> true) { + // HiveSimpleUDF + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + val df1 = sql("SELECT testStringStringUDF(rand(), \"hello\")") + assert(!df1.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + + // HiveGenericUDF + sql(s"CREATE TEMPORARY FUNCTION testGenericUDFHash AS '${classOf[GenericUDFHash].getName}'") + val df2 = sql("SELECT testGenericUDFHash(rand())") + assert(!df2.logicalPlan.asInstanceOf[Project].projectList.forall(_.deterministic)) + } + } + test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") @@ -558,6 +573,23 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) } } + + test("Show persistent functions") { + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + withTempView("inputTable") { + testData.createOrReplaceTempView("inputTable") + withUserDefinedFunction("testUDFToListInt" -> false) { + val numFunc = spark.catalog.listFunctions().count() + sql(s"CREATE FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") + assert(spark.catalog.listFunctions().count() == numFunc + 1) + checkAnswer( + sql("SELECT testUDFToListInt(s) FROM inputTable"), + Seq(Row(Seq(1, 2, 3)))) + assert(sql("show functions").count() == numFunc + 1) + assert(spark.catalog.listFunctions().count() == numFunc + 1) + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index cd8f94b1cc4f0..f818e29555468 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -58,7 +58,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te fileFormat = new ParquetFileFormat(), options = Map.empty)(sparkSession = spark) - val logicalRelation = LogicalRelation(relation, catalogTable = Some(tableMeta)) + val logicalRelation = LogicalRelation(relation, tableMeta) val query = Project(Seq('i, 'p), Filter('p === 1, logicalRelation)).analyze val optimized = Optimize.execute(query) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b936f16f80673..159386e2f0c2a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -28,7 +29,7 @@ import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import 
org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -363,79 +364,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("describe partition") { - withTable("partitioned_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name") - - checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "Detailed Partition Information CatalogPartition(", - "Partition Values: [c=Us, d=1]", - "Storage(Location:", - "Partition Parameters") - - checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), - "# Partition Information", - "# col_name", - "# Detailed Partition Information", - "Partition Value:", - "Database:", - "Table:", - "Location:", - "Partition Parameters:", - "# Storage Information") - } - } - - test("describe partition - error handling") { - withTable("partitioned_table", "datasource_table") { - sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") - sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") - - val m = intercept[NoSuchPartitionException] { - sql("DESC partitioned_table PARTITION (c='Us', d=2)") - }.getMessage() - assert(m.contains("Partition not found in table")) - - val m2 = intercept[AnalysisException] { - sql("DESC partitioned_table PARTITION (c='Us')") - }.getMessage() - assert(m2.contains("Partition spec is invalid")) - - val m3 = intercept[ParseException] { - sql("DESC partitioned_table PARTITION (c='Us', d)") - }.getMessage() - assert(m3.contains("PARTITION specification is incomplete: `d`")) - - spark - .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write - .partitionBy("d") - .saveAsTable("datasource_table") - - sql("DESC datasource_table PARTITION (d=0)") - - val m5 = intercept[AnalysisException] { - spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") - sql("DESC view1 PARTITION (c='Us', d=1)") - }.getMessage() - assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) - - withView("permanent_view") { - val m = intercept[AnalysisException] { - sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") - sql("DESC permanent_view PARTITION (c='Us', d=1)") - }.getMessage() - assert(m.contains("DESC PARTITION is not allowed on a view")) - } - } - } - test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") @@ -544,17 +472,17 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } userSpecifiedLocation match { case Some(location) => - assert(r.tableMeta.location === location) + assert(r.tableMeta.location === CatalogUtils.stringToURI(location)) case None => // OK. } // Also make sure that the format and serde are as desired. 
- assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) - assert(catalogTable.storage.outputFormat.get.toLowerCase.contains(format)) + assert(catalogTable.storage.inputFormat.get.toLowerCase(Locale.ROOT).contains(format)) + assert(catalogTable.storage.outputFormat.get.toLowerCase(Locale.ROOT).contains(format)) val serde = catalogTable.storage.serde.get format match { case "sequence" | "text" => assert(serde.contains("LazySimpleSerDe")) case "rcfile" => assert(serde.contains("LazyBinaryColumnarSerDe")) - case _ => assert(serde.toLowerCase.contains(format)) + case _ => assert(serde.toLowerCase(Locale.ROOT).contains(format)) } } @@ -676,7 +604,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS with serde") { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql( """CREATE TABLE ctas2 | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" @@ -686,86 +614,76 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + + val storageCtas2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas2")).storage + assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + sql( """CREATE TABLE ctas3 | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012' | STORED AS textfile AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) // the table schema may like (key: integer, value: string) sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin).collect() + | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin) // do nothing cause the table ctas4 already existed. sql( """CREATE TABLE IF NOT EXISTS ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) checkAnswer( sql("SELECT k, value FROM ctas1 ORDER BY k, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) checkAnswer( sql("SELECT key, value FROM ctas2 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) checkAnswer( sql("SELECT key, value FROM ctas3 ORDER BY key, value"), sql( """ SELECT key, value FROM src - ORDER BY key, value""").collect().toSeq) + ORDER BY key, value""")) intercept[AnalysisException] { sql( """CREATE TABLE ctas4 AS - | SELECT key, value FROM src ORDER BY key, value""".stripMargin).collect() + | SELECT key, value FROM src ORDER BY key, value""".stripMargin) } checkAnswer( sql("SELECT key, value FROM ctas4 ORDER BY key, value"), sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq) - /* - Disabled because our describe table does not output the serde information right now. 
- checkKeywordsExist(sql("DESC EXTENDED ctas2"), - "name:key", "type:string", "name:value", "ctas2", - "org.apache.hadoop.hive.ql.io.RCFileInputFormat", - "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", - "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" - ) - */ - sql( """CREATE TABLE ctas5 | STORED AS parquet AS | SELECT key, value | FROM src - | ORDER BY key, value""".stripMargin).collect() + | ORDER BY key, value""".stripMargin) + val storageCtas5 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas5")).storage + assert(storageCtas5.inputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat")) + assert(storageCtas5.outputFormat == + Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + assert(storageCtas5.serde == + Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - /* - Disabled because our describe table does not output the serde information right now. - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { - checkKeywordsExist(sql("DESC EXTENDED ctas5"), - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - } - */ // use the Hive SerDe for parquet tables withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM ctas5 ORDER BY key, value"), - sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) + sql("SELECT key, value FROM src ORDER BY key, value")) } } @@ -1030,7 +948,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: CatalogRelation, _) => // OK + case SubqueryAlias(_, r: CatalogRelation) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } @@ -2057,4 +1975,44 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("Auto alias construction of get_json_object") { + val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring") + val expectedMsg = "Cannot create a table having a column whose name contains commas " + + "in Hive metastore. 
Table: `default`.`t`; Column: get_json_object(jstring, $.f1)" + + withTable("t") { + val e = intercept[AnalysisException] { + df.select($"key", functions.get_json_object($"jstring", "$.f1")) + .write.format("hive").saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + + withTempView("tempView") { + withTable("t") { + df.createTempView("tempView") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } + + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { + withTable("spark_19912") { + Seq( + (1, "p1", "q1"), + (2, "'", "q2"), + (3, "\"", "q3"), + (4, "p1\" and q=\"q1", "q4") + ).toDF("a", "p", "q").write.partitionBy("p", "q").saveAsTable("spark_19912") + + val table = spark.table("spark_19912") + checkAnswer(table.filter($"p" === "'").select($"a"), Row(2)) + checkAnswer(table.filter($"p" === "\"").select($"a"), Row(3)) + checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 4f771caa1db27..ba0a7605da71c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ @@ -42,12 +42,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 38a5477796a4a..8c855730c31f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.Utils case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -284,7 +285,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - 
assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -293,7 +294,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -302,7 +303,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withOrcTable(data, "t") { checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) @@ -611,4 +612,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("read from multiple orc input paths") { + val path1 = Utils.createTempDir() + val path2 = Utils.createTempDir() + makeOrcFile((1 to 10).map(Tuple1.apply), path1) + makeOrcFile((1 to 10).map(Tuple1.apply), path2) + assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count()) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 11dda5425cf94..6bfb88c0c1af5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -157,19 +157,21 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA val location = Utils.createTempDir() val uri = location.toURI try { + hiveClient.runSqlHive("USE default") hiveClient.runSqlHive( """ - |CREATE EXTERNAL TABLE hive_orc( - | a STRING, - | b CHAR(10), - | c VARCHAR(10), - | d ARRAY) - |STORED AS orc""".stripMargin) + |CREATE EXTERNAL TABLE hive_orc( + | a STRING, + | b CHAR(10), + | c VARCHAR(10), + | d ARRAY) + |STORED AS orc""".stripMargin) // Hive throws an exception if I assign the location in the create table statement. hiveClient.runSqlHive( s"ALTER TABLE hive_orc SET LOCATION '$uri'") hiveClient.runSqlHive( - """INSERT INTO TABLE hive_orc + """ + |INSERT INTO TABLE hive_orc |SELECT 'a', 'b', 'c', ARRAY(CAST('d' AS CHAR(3))) |FROM (SELECT 1) t""".stripMargin) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 7226ed521ef32..a2f08c5ba72c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -43,7 +43,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * Writes `data` to a Orc file and reads it back as a `DataFrame`, * which is then passed to `f`. The Orc file will be deleted after `f` returns. 
*/ protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] @@ -53,7 +53,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { } /** - * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * Writes `data` to a Orc file, reads it back as a `DataFrame` and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the * Orc file will be dropped/deleted after `f` returns. */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index d528d8c58841c..eaec06eb91480 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.execution.HiveTableScanExec @@ -448,12 +449,17 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } + private def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { + sessionState.catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog + .getCachedDataSourceTable(table) + } + test("Caching converted data source Parquet Relations") { def checkCached(tableIdentifier: TableIdentifier): Unit = { // Converted test_parquet should be cached. - sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { + getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + @@ -479,14 +485,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { var tableIdentifier = TableIdentifier("test_insert_parquet", Some("default")) // First, make sure the converted test_parquet is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet @@ -499,7 +505,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select a, b from jt").collect()) // Invalidate the cache. sessionState.refreshTable("test_insert_parquet") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. 
sql( @@ -517,7 +523,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) tableIdentifier = TableIdentifier("test_parquet_partitioned_cache_test", Some("default")) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -526,14 +532,14 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (`date`='2015-04-02') |select a, b from jt """.stripMargin) - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") @@ -549,7 +555,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { """.stripMargin).collect()) sessionState.refreshTable("test_parquet_partitioned_cache_test") - assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) + assert(getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index d79edee5b1a4c..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -21,8 +21,8 @@ import java.math.BigDecimal import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -38,12 +38,9 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""") .saveAsTextFile(partitionDir.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index f676e24fcf972..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,8 +23,8 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetOutputFormat -import 
org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -44,12 +44,9 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index a47a2246ddc3c..2ec593b95c9b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.types._ @@ -45,12 +45,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + val partitionDir = new Path( + CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2") sparkContext .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") .saveAsTextFile(partitionDir.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index b04ef6f21d69f..60a4638f610b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.sql.{sources, Row, SparkSession} +import org.apache.spark.sql.{sources, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05f..fea882ad11230 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ 
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.3.0-SNAPSHOT</version>
 <relativePath>../pom.xml</relativePath>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index ed9305875cb77..905b1c52afa69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -230,7 +230,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * - It must pass the user-provided file filter. * - It must be newer than the ignore threshold. It is assumed that files older than the ignore * threshold have already been considered or are existing files before start - * (when newFileOnly = true). + * (when newFilesOnly = true). * - It must not be present in the recently selected files that this class remembers. * - It must not be newer than the time of the batch (i.e. `currentTime` for which this * file is being tested. This can occur if the driver was recovered, and the missing batches
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9a760e2947d0b..931f015f03b6f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import java.util.Locale + import scala.reflect.ClassTag import org.apache.spark.SparkContext @@ -60,7 +62,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) .split("(?=[A-Z])") .filter(_.nonEmpty) .mkString(" ") - .toLowerCase + .toLowerCase(Locale.ROOT) .capitalize s"$newName [$id]" } @@ -74,7 +76,7 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) protected[streaming] override val baseScope: Option[String] = { val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } - .getOrElse(name.toLowerCase) + .getOrElse(name.toLowerCase(Locale.ROOT)) Some(new RDDOperationScope(scopeName).toJson) }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 58b7031d5ea6a..15d3c7e54b8dd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils /** - * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a `StateMap` and a * sequence of records returned by the mapping function of `mapWithState`. */ private[streaming] case class MapWithStateRDDRecord[K, S, E]( @@ -111,7 +111,7 @@ private[streaming] class MapWithStateRDDPartition( /** * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]].
This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * `StateMap` (containing the keyed-states) and the sequence of records returned by the mapping * function of `mapWithState`. * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD * will be created diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index d0864fd3678b2..844760ab61d2e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -158,16 +158,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel, - encrypt = true) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } serializerManager .dataDeserializeStream( blockId, - new ChunkedByteBuffer(dataRead).toInputStream(), - maybeEncrypted = false)(elementClassTag) + new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 2b488038f0620..80c07958b41f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -87,8 +87,7 @@ private[streaming] class BlockManagerBasedBlockHandler( putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes( - blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true, - encrypt = true) + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -176,11 +175,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler( val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => numRecords = Some(arrayBuffer.size.toLong) - serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false) + serializerManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => val countIterator = new CountingIterator(iterator) - val serializedBlock = serializerManager.dataSerialize(blockId, countIterator, - allowEncryption = false) + val serializedBlock = serializerManager.dataSerialize(blockId, countIterator) numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => @@ -195,8 +193,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( blockId, serializedBlock, effectiveStorageLevel, - tellMaster = true, - encrypt = true) + tellMaster = true) if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 8e1a090618433..639ac6de4f5d3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -66,7 +66,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging new mutable.HashMap[Int, StreamInputInfo]()) if (inputInfos.contains(inputInfo.inputStreamId)) { - throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " + s"$batchTime is already added into InputInfoTracker, this is an illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala index a73e6cc2cd9c1..dc02062b9eb44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging * case of Spark Streaming the error is the difference between the measured processing * rate (number of elements/processing delay) and the previous rate. * - * @see https://en.wikipedia.org/wiki/PID_controller + * @see <a href="https://en.wikipedia.org/wiki/PID_controller">PID controller (Wikipedia)</a> * * @param batchIntervalMillis the batch duration, in milliseconds * @param proportional how much the correction should depend on the current
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index 7b2ef6881d6f7..e4b9dffee04f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.Duration * A component that estimates the rate at which an `InputDStream` should ingest * records, based on updates at every batch completion. * - * @see [[org.apache.spark.streaming.scheduler.RateController]] + * Please see `org.apache.spark.streaming.scheduler.RateController` for more details.
*/ private[streaming] trait RateEstimator extends Serializable {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 1352ca1c4c95f..70b4bb466c46b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -97,7 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) completed = batch.numCompletedOutputOp, failed = batch.numFailedOutputOp, skipped = 0, - killed = 0, + reasonToNumKilled = Map.empty, total = batch.outputOperations.size) }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1a87fc790f91b..f55af6a5cc358 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -146,7 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { completed = sparkJob.numCompletedTasks, failed = sparkJob.numFailedTasks, skipped = sparkJob.numSkippedTasks, - killed = sparkJob.numKilledTasks, + reasonToNumKilled = sparkJob.reasonToNumKilled, total = sparkJob.numTasks - sparkJob.numSkippedTasks) }
diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java index 80513de4ee117..90d1f8c5035b3 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/Java8APISuite.java @@ -101,7 +101,7 @@ public void testMapPartitions() { JavaDStream<String> mapped = stream.mapPartitions(in -> { String out = ""; while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out = out + in.next().toUpperCase(Locale.ROOT); } return Arrays.asList(out).iterator(); }); @@ -806,7 +806,8 @@ public void testMapValues() { ssc, inputData, 1); JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream<String, String> mapped = pairStream.mapValues(String::toUpperCase); + JavaPairDStream<String, String> mapped = + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index 96f8d9593d630..6c86cacec8279 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -267,7 +267,7 @@ public void testMapPartitions() { JavaDStream<String> mapped = stream.mapPartitions(in -> { StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out.append(in.next().toUpperCase(Locale.ENGLISH)); + out.append(in.next().toUpperCase(Locale.ROOT)); } return Arrays.asList(out.toString()).iterator(); }); @@ -1315,7 +1315,7 @@ public void testMapValues() { JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream<String, String> mapped = - pairStream.mapValues(s -> s.toUpperCase(Locale.ENGLISH)); + pairStream.mapValues(s -> s.toUpperCase(Locale.ROOT)); JavaTestUtils.attachTestOutputStream(mapped); List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
diff --git
a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c2b0389b8c6f0..3c4a2716caf90 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -175,8 +175,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) reader.close() serializerManager.dataDeserializeStream( generateBlockId(), - new ChunkedByteBuffer(bytes).toInputStream(), - maybeEncrypted = false)(ClassTag.Any).toList + new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } @@ -357,7 +356,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) } def dataToByteBuffer(b: Seq[String]) = - serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false) + serializerManager.dataSerialize(generateBlockId, b.iterator) val blocks = data.grouped(10).toSeq
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5645996de5a69..eb996c93ff381 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} +import java.util.Locale import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger @@ -745,7 +746,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val ex = intercept[IllegalStateException] { body } - assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(expectedErrorMsg)) } }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 2ac0dc96916c5..aa69be7ca9939 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -250,8 +250,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false) - .toByteBuffer) + writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments
diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac201..7ba4dc9842f1b 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.3.0-SNAPSHOT</version>
 <relativePath>../pom.xml</relativePath>